Initial implementation of reliable injected packets

See #17. Not yet tested for real.
This commit is contained in:
Salad Dais
2021-12-08 03:38:31 +00:00
parent d86da70eeb
commit d66d5226a2
5 changed files with 118 additions and 10 deletions

View File

@@ -1,6 +1,8 @@
from __future__ import annotations
import abc
import copy
import dataclasses
import datetime as dt
import logging
from typing import *
@@ -13,6 +15,13 @@ from .msgtypes import PacketFlags
from .udpserializer import UDPMessageSerializer
@dataclasses.dataclass
class ReliableResendInfo:
last_resent: dt.datetime
message: Message
tries_left: int = 3
class Circuit:
def __init__(self, near_host: Optional[ADDR_TUPLE], far_host: ADDR_TUPLE, transport):
self.near_host: Optional[ADDR_TUPLE] = near_host
@@ -22,6 +31,8 @@ class Circuit:
self.serializer = UDPMessageSerializer()
self.last_packet_at = dt.datetime.now()
self.packet_id_base = 0
self.unacked_reliable: Dict[int, ReliableResendInfo] = {}
self.resend_every: float = 5.0
def _send_prepared_message(self, message: Message, transport=None):
try:
@@ -51,11 +62,42 @@ class Circuit:
# If it was queued, it's not anymore
message.queued = False
message.finalized = True
# In typical client contexts all messages are "injected" in that the client
# themselves created them.
message.injected = True
def send_message(self, message: Message, transport=None):
if self.prepare_message(message):
# If we injected the message then we're responsible for resends.
if message.reliable and message.injected:
self.unacked_reliable[message.packet_id] = ReliableResendInfo(
last_resent=dt.datetime.now(),
message=message,
)
return self._send_prepared_message(message, transport)
def collect_ack(self, packet_id: int):
resend_info = self.unacked_reliable.pop(packet_id, None)
if resend_info is None:
logging.warning(f"Got ack for unknown packed {packet_id}")
def resend_unacked(self):
for resend_info in list(self.unacked_reliable.values()):
# Not time to attempt a resend yet
if dt.datetime.now() - resend_info.last_resent < dt.timedelta(seconds=self.resend_every):
continue
msg = copy.copy(resend_info.message)
resend_info.tries_left -= 1
# We were on our last try and we never received an ack
if not resend_info.tries_left:
logging.warning(f"Giving up on unacked {msg.packet_id}")
del self.unacked_reliable[msg.packet_id]
continue
resend_info.last_resent = dt.datetime.now()
msg.send_flags |= PacketFlags.RESENT
self._send_prepared_message(msg)
def send_acks(self, to_ack: Sequence[int], direction=Direction.OUT, packet_id=None):
logging.debug("%r acking %r" % (direction, to_ack))
# TODO: maybe tack this onto `.acks` for next message?

View File

@@ -43,11 +43,15 @@ class ProxiedCircuit(Circuit):
direction = direction or getattr(message, 'direction')
fwd_injections, reverse_injections = self._get_injections(direction)
message.finalized = True
# Injected, let's gen an ID
if message.packet_id is None:
message.packet_id = fwd_injections.gen_injectable_id()
message.injected = True
else:
# This message wasn't injected by the proxy so we need to rewrite packet IDs
# to account for IDs the real creator of the packet couldn't have known about.
elif not message.injected:
# was_dropped needs the unmodified packet ID
if fwd_injections.was_dropped(message.packet_id) and message.name != "PacketAck":
logging.warning("Attempting to re-send previously dropped %s:%s, did we ack?" %
@@ -55,15 +59,13 @@ class ProxiedCircuit(Circuit):
message.packet_id = fwd_injections.get_effective_id(message.packet_id)
fwd_injections.track_seen(message.packet_id)
message.finalized = True
if not message.injected:
# This message wasn't injected by the proxy so we need to rewrite packet IDs
# to account for IDs the other parties couldn't have known about.
message.acks = tuple(
reverse_injections.get_original_id(x) for x in message.acks
if not reverse_injections.was_injected(x)
)
new_acks = []
for ack in message.acks:
if reverse_injections.was_injected(ack):
self.collect_ack(ack)
else:
new_acks.append(reverse_injections.get_original_id(ack))
message.acks = tuple(new_acks)
if message.name == "PacketAck":
if not self._rewrite_packet_ack(message, reverse_injections):
@@ -85,6 +87,7 @@ class ProxiedCircuit(Circuit):
# This is an ACK for one the proxy injected, don't confuse
# the other side by sending through the ACK
if reverse_injections.was_injected(packet_id):
self.collect_ack(packet_id)
continue
block["ID"] = reverse_injections.get_original_id(packet_id)
new_blocks.append(block)
@@ -100,6 +103,8 @@ class ProxiedCircuit(Circuit):
new_id = fwd_injections.get_effective_id(orig_id)
if orig_id != new_id:
logging.debug("Rewrote oldest unacked %s -> %s" % (orig_id, new_id))
# Use the proxy's oldest unacked instead if it's older.
new_id = min((new_id, *self.unacked_reliable.keys()))
message["PingID"]["OldestUnacked"] = new_id
def drop_message(self, message: Message, orig_direction=None):

View File

@@ -1,3 +1,4 @@
import asyncio
import logging
import weakref
from typing import Optional, Tuple
@@ -35,6 +36,17 @@ class InterceptingLLUDPProxyProtocol(UDPProxyProtocol):
)
self.message_xml = MessageDotXML()
self.session: Optional[Session] = None
self.resend_task = asyncio.get_event_loop().create_task(self.attempt_resends())
async def attempt_resends(self):
while True:
await asyncio.sleep(0.1)
if self.session is None:
continue
for region in self.session.regions:
if not region.circuit or not region.circuit.is_alive:
continue
region.circuit.resend_unacked()
def _ensure_message_allowed(self, msg: Message):
if not self.message_xml.validate_udp_msg(msg.name):
@@ -156,3 +168,4 @@ class InterceptingLLUDPProxyProtocol(UDPProxyProtocol):
AddonManager.handle_session_closed(self.session)
self.session_manager.close_session(self.session)
self.session = None
self.resend_task.cancel()

View File

@@ -37,6 +37,9 @@ class BaseProxyTest(unittest.IsolatedAsyncioTestCase):
self.serializer = UDPMessageSerializer()
self.session.objects.track_region_objects(123)
def tearDown(self) -> None:
self.protocol.close()
async def _wait_drained(self):
await asyncio.sleep(0.001)

View File

@@ -247,3 +247,48 @@ class PacketIDTests(unittest.TestCase):
# ended up getting the same packet ID when injected
(2, "ChatFromViewer", Direction.OUT, True, ()),
])
def test_reliable_unacked_queueing(self):
self._send_message(Message('ChatFromViewer', flags=PacketFlags.RELIABLE))
self._send_message(Message('ChatFromViewer', flags=PacketFlags.RELIABLE, packet_id=2))
# Only the first, injected message should be queued for resends
self.assertEqual({1}, set(self.circuit.unacked_reliable))
def test_reliable_resend_cadence(self):
self._send_message(Message('ChatFromViewer', flags=PacketFlags.RELIABLE))
resend_info = self.circuit.unacked_reliable[1]
self.circuit.resend_unacked()
# Should have been too soon to retry
self.assertEqual(3, resend_info.tries_left)
# Switch to allowing resends every 0s
self.circuit.resend_every = 0.0
self.circuit.resend_unacked()
self.assertSequenceEqual(self.circuit.sent_simple, [
(1, "ChatFromViewer", Direction.OUT, True, ()),
# Should have resent
(1, "ChatFromViewer", Direction.OUT, True, ()),
])
self.assertEqual(2, resend_info.tries_left)
for _ in range(3):
self.circuit.resend_unacked()
# Should have used up all the retry attempts and been kicked out of the retry queue
self.assertEqual(set(), set(self.circuit.unacked_reliable))
def test_start_ping_check(self):
# Should not break if no unacked
self._send_message(Message(
"StartPingCheck",
Block("PingID", PingID=0, OldestUnacked=20),
packet_id=5,
))
injected_msg = Message('ChatFromViewer', flags=PacketFlags.RELIABLE)
self._send_message(injected_msg)
self._send_message(Message(
"StartPingCheck",
Block("PingID", PingID=0, OldestUnacked=20),
packet_id=8,
))
# Oldest unacked should have been replaced with the injected packet's ID, it's older!
self.assertEqual(self.circuit.sent_msgs[2]["PingID"]["OldestUnacked"], injected_msg.packet_id)