Initial implementation of reliable injected packets
See #17. Not yet tested for real.
This commit is contained in:
@@ -1,6 +1,8 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import abc
|
||||
import copy
|
||||
import dataclasses
|
||||
import datetime as dt
|
||||
import logging
|
||||
from typing import *
|
||||
@@ -13,6 +15,13 @@ from .msgtypes import PacketFlags
|
||||
from .udpserializer import UDPMessageSerializer
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class ReliableResendInfo:
|
||||
last_resent: dt.datetime
|
||||
message: Message
|
||||
tries_left: int = 3
|
||||
|
||||
|
||||
class Circuit:
|
||||
def __init__(self, near_host: Optional[ADDR_TUPLE], far_host: ADDR_TUPLE, transport):
|
||||
self.near_host: Optional[ADDR_TUPLE] = near_host
|
||||
@@ -22,6 +31,8 @@ class Circuit:
|
||||
self.serializer = UDPMessageSerializer()
|
||||
self.last_packet_at = dt.datetime.now()
|
||||
self.packet_id_base = 0
|
||||
self.unacked_reliable: Dict[int, ReliableResendInfo] = {}
|
||||
self.resend_every: float = 5.0
|
||||
|
||||
def _send_prepared_message(self, message: Message, transport=None):
|
||||
try:
|
||||
@@ -51,11 +62,42 @@ class Circuit:
|
||||
# If it was queued, it's not anymore
|
||||
message.queued = False
|
||||
message.finalized = True
|
||||
# In typical client contexts all messages are "injected" in that the client
|
||||
# themselves created them.
|
||||
message.injected = True
|
||||
|
||||
def send_message(self, message: Message, transport=None):
|
||||
if self.prepare_message(message):
|
||||
# If we injected the message then we're responsible for resends.
|
||||
if message.reliable and message.injected:
|
||||
self.unacked_reliable[message.packet_id] = ReliableResendInfo(
|
||||
last_resent=dt.datetime.now(),
|
||||
message=message,
|
||||
)
|
||||
return self._send_prepared_message(message, transport)
|
||||
|
||||
def collect_ack(self, packet_id: int):
|
||||
resend_info = self.unacked_reliable.pop(packet_id, None)
|
||||
if resend_info is None:
|
||||
logging.warning(f"Got ack for unknown packed {packet_id}")
|
||||
|
||||
def resend_unacked(self):
|
||||
for resend_info in list(self.unacked_reliable.values()):
|
||||
# Not time to attempt a resend yet
|
||||
if dt.datetime.now() - resend_info.last_resent < dt.timedelta(seconds=self.resend_every):
|
||||
continue
|
||||
|
||||
msg = copy.copy(resend_info.message)
|
||||
resend_info.tries_left -= 1
|
||||
# We were on our last try and we never received an ack
|
||||
if not resend_info.tries_left:
|
||||
logging.warning(f"Giving up on unacked {msg.packet_id}")
|
||||
del self.unacked_reliable[msg.packet_id]
|
||||
continue
|
||||
resend_info.last_resent = dt.datetime.now()
|
||||
msg.send_flags |= PacketFlags.RESENT
|
||||
self._send_prepared_message(msg)
|
||||
|
||||
def send_acks(self, to_ack: Sequence[int], direction=Direction.OUT, packet_id=None):
|
||||
logging.debug("%r acking %r" % (direction, to_ack))
|
||||
# TODO: maybe tack this onto `.acks` for next message?
|
||||
|
||||
@@ -43,11 +43,15 @@ class ProxiedCircuit(Circuit):
|
||||
direction = direction or getattr(message, 'direction')
|
||||
fwd_injections, reverse_injections = self._get_injections(direction)
|
||||
|
||||
message.finalized = True
|
||||
|
||||
# Injected, let's gen an ID
|
||||
if message.packet_id is None:
|
||||
message.packet_id = fwd_injections.gen_injectable_id()
|
||||
message.injected = True
|
||||
else:
|
||||
# This message wasn't injected by the proxy so we need to rewrite packet IDs
|
||||
# to account for IDs the real creator of the packet couldn't have known about.
|
||||
elif not message.injected:
|
||||
# was_dropped needs the unmodified packet ID
|
||||
if fwd_injections.was_dropped(message.packet_id) and message.name != "PacketAck":
|
||||
logging.warning("Attempting to re-send previously dropped %s:%s, did we ack?" %
|
||||
@@ -55,15 +59,13 @@ class ProxiedCircuit(Circuit):
|
||||
message.packet_id = fwd_injections.get_effective_id(message.packet_id)
|
||||
fwd_injections.track_seen(message.packet_id)
|
||||
|
||||
message.finalized = True
|
||||
|
||||
if not message.injected:
|
||||
# This message wasn't injected by the proxy so we need to rewrite packet IDs
|
||||
# to account for IDs the other parties couldn't have known about.
|
||||
message.acks = tuple(
|
||||
reverse_injections.get_original_id(x) for x in message.acks
|
||||
if not reverse_injections.was_injected(x)
|
||||
)
|
||||
new_acks = []
|
||||
for ack in message.acks:
|
||||
if reverse_injections.was_injected(ack):
|
||||
self.collect_ack(ack)
|
||||
else:
|
||||
new_acks.append(reverse_injections.get_original_id(ack))
|
||||
message.acks = tuple(new_acks)
|
||||
|
||||
if message.name == "PacketAck":
|
||||
if not self._rewrite_packet_ack(message, reverse_injections):
|
||||
@@ -85,6 +87,7 @@ class ProxiedCircuit(Circuit):
|
||||
# This is an ACK for one the proxy injected, don't confuse
|
||||
# the other side by sending through the ACK
|
||||
if reverse_injections.was_injected(packet_id):
|
||||
self.collect_ack(packet_id)
|
||||
continue
|
||||
block["ID"] = reverse_injections.get_original_id(packet_id)
|
||||
new_blocks.append(block)
|
||||
@@ -100,6 +103,8 @@ class ProxiedCircuit(Circuit):
|
||||
new_id = fwd_injections.get_effective_id(orig_id)
|
||||
if orig_id != new_id:
|
||||
logging.debug("Rewrote oldest unacked %s -> %s" % (orig_id, new_id))
|
||||
# Use the proxy's oldest unacked instead if it's older.
|
||||
new_id = min((new_id, *self.unacked_reliable.keys()))
|
||||
message["PingID"]["OldestUnacked"] = new_id
|
||||
|
||||
def drop_message(self, message: Message, orig_direction=None):
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import weakref
|
||||
from typing import Optional, Tuple
|
||||
@@ -35,6 +36,17 @@ class InterceptingLLUDPProxyProtocol(UDPProxyProtocol):
|
||||
)
|
||||
self.message_xml = MessageDotXML()
|
||||
self.session: Optional[Session] = None
|
||||
self.resend_task = asyncio.get_event_loop().create_task(self.attempt_resends())
|
||||
|
||||
async def attempt_resends(self):
|
||||
while True:
|
||||
await asyncio.sleep(0.1)
|
||||
if self.session is None:
|
||||
continue
|
||||
for region in self.session.regions:
|
||||
if not region.circuit or not region.circuit.is_alive:
|
||||
continue
|
||||
region.circuit.resend_unacked()
|
||||
|
||||
def _ensure_message_allowed(self, msg: Message):
|
||||
if not self.message_xml.validate_udp_msg(msg.name):
|
||||
@@ -156,3 +168,4 @@ class InterceptingLLUDPProxyProtocol(UDPProxyProtocol):
|
||||
AddonManager.handle_session_closed(self.session)
|
||||
self.session_manager.close_session(self.session)
|
||||
self.session = None
|
||||
self.resend_task.cancel()
|
||||
|
||||
@@ -37,6 +37,9 @@ class BaseProxyTest(unittest.IsolatedAsyncioTestCase):
|
||||
self.serializer = UDPMessageSerializer()
|
||||
self.session.objects.track_region_objects(123)
|
||||
|
||||
def tearDown(self) -> None:
|
||||
self.protocol.close()
|
||||
|
||||
async def _wait_drained(self):
|
||||
await asyncio.sleep(0.001)
|
||||
|
||||
|
||||
@@ -247,3 +247,48 @@ class PacketIDTests(unittest.TestCase):
|
||||
# ended up getting the same packet ID when injected
|
||||
(2, "ChatFromViewer", Direction.OUT, True, ()),
|
||||
])
|
||||
|
||||
def test_reliable_unacked_queueing(self):
|
||||
self._send_message(Message('ChatFromViewer', flags=PacketFlags.RELIABLE))
|
||||
self._send_message(Message('ChatFromViewer', flags=PacketFlags.RELIABLE, packet_id=2))
|
||||
# Only the first, injected message should be queued for resends
|
||||
self.assertEqual({1}, set(self.circuit.unacked_reliable))
|
||||
|
||||
def test_reliable_resend_cadence(self):
|
||||
self._send_message(Message('ChatFromViewer', flags=PacketFlags.RELIABLE))
|
||||
resend_info = self.circuit.unacked_reliable[1]
|
||||
self.circuit.resend_unacked()
|
||||
# Should have been too soon to retry
|
||||
self.assertEqual(3, resend_info.tries_left)
|
||||
# Switch to allowing resends every 0s
|
||||
self.circuit.resend_every = 0.0
|
||||
self.circuit.resend_unacked()
|
||||
self.assertSequenceEqual(self.circuit.sent_simple, [
|
||||
(1, "ChatFromViewer", Direction.OUT, True, ()),
|
||||
# Should have resent
|
||||
(1, "ChatFromViewer", Direction.OUT, True, ()),
|
||||
])
|
||||
self.assertEqual(2, resend_info.tries_left)
|
||||
for _ in range(3):
|
||||
self.circuit.resend_unacked()
|
||||
# Should have used up all the retry attempts and been kicked out of the retry queue
|
||||
self.assertEqual(set(), set(self.circuit.unacked_reliable))
|
||||
|
||||
def test_start_ping_check(self):
|
||||
# Should not break if no unacked
|
||||
self._send_message(Message(
|
||||
"StartPingCheck",
|
||||
Block("PingID", PingID=0, OldestUnacked=20),
|
||||
packet_id=5,
|
||||
))
|
||||
|
||||
injected_msg = Message('ChatFromViewer', flags=PacketFlags.RELIABLE)
|
||||
self._send_message(injected_msg)
|
||||
|
||||
self._send_message(Message(
|
||||
"StartPingCheck",
|
||||
Block("PingID", PingID=0, OldestUnacked=20),
|
||||
packet_id=8,
|
||||
))
|
||||
# Oldest unacked should have been replaced with the injected packet's ID, it's older!
|
||||
self.assertEqual(self.circuit.sent_msgs[2]["PingID"]["OldestUnacked"], injected_msg.packet_id)
|
||||
|
||||
Reference in New Issue
Block a user