Fix ACK collection logic for injected reliable messages
This commit is contained in:
@@ -31,7 +31,7 @@ class Circuit:
|
||||
self.serializer = UDPMessageSerializer()
|
||||
self.last_packet_at = dt.datetime.now()
|
||||
self.packet_id_base = 0
|
||||
self.unacked_reliable: Dict[int, ReliableResendInfo] = {}
|
||||
self.unacked_reliable: Dict[Tuple[Direction, int], ReliableResendInfo] = {}
|
||||
self.resend_every: float = 5.0
|
||||
|
||||
def _send_prepared_message(self, message: Message, transport=None):
|
||||
@@ -57,8 +57,10 @@ class Circuit:
|
||||
raise RuntimeError(f"Trying to re-send finalized {message!r}")
|
||||
message.packet_id = self.packet_id_base
|
||||
self.packet_id_base += 1
|
||||
if not message.acks:
|
||||
message.send_flags &= PacketFlags.ACK
|
||||
if message.acks:
|
||||
message.send_flags |= PacketFlags.ACK
|
||||
else:
|
||||
message.send_flags &= ~PacketFlags.ACK
|
||||
# If it was queued, it's not anymore
|
||||
message.queued = False
|
||||
message.finalized = True
|
||||
@@ -70,16 +72,18 @@ class Circuit:
|
||||
if self.prepare_message(message):
|
||||
# If we injected the message then we're responsible for resends.
|
||||
if message.reliable and message.injected:
|
||||
self.unacked_reliable[message.packet_id] = ReliableResendInfo(
|
||||
self.unacked_reliable[(message.direction, message.packet_id)] = ReliableResendInfo(
|
||||
last_resent=dt.datetime.now(),
|
||||
message=message,
|
||||
)
|
||||
return self._send_prepared_message(message, transport)
|
||||
|
||||
def collect_ack(self, packet_id: int):
|
||||
resend_info = self.unacked_reliable.pop(packet_id, None)
|
||||
if resend_info is None:
|
||||
logging.warning(f"Got ack for unknown packed {packet_id}")
|
||||
def collect_acks(self, message: Message):
|
||||
effective_acks = list(message.acks)
|
||||
if message.name == "PacketAck":
|
||||
effective_acks.extend(x["ID"] for x in message["Packets"])
|
||||
for ack in effective_acks:
|
||||
self.unacked_reliable.pop((~message.direction, ack), None)
|
||||
|
||||
def resend_unacked(self):
|
||||
for resend_info in list(self.unacked_reliable.values()):
|
||||
@@ -92,7 +96,7 @@ class Circuit:
|
||||
# We were on our last try and we never received an ack
|
||||
if not resend_info.tries_left:
|
||||
logging.warning(f"Giving up on unacked {msg.packet_id}")
|
||||
del self.unacked_reliable[msg.packet_id]
|
||||
del self.unacked_reliable[(msg.direction, msg.packet_id)]
|
||||
continue
|
||||
resend_info.last_resent = dt.datetime.now()
|
||||
msg.send_flags |= PacketFlags.RESENT
|
||||
|
||||
@@ -58,17 +58,13 @@ class ProxiedCircuit(Circuit):
|
||||
(message.packet_id, message.name))
|
||||
message.packet_id = fwd_injections.get_effective_id(message.packet_id)
|
||||
fwd_injections.track_seen(message.packet_id)
|
||||
|
||||
new_acks = []
|
||||
for ack in message.acks:
|
||||
if reverse_injections.was_injected(ack):
|
||||
self.collect_ack(ack)
|
||||
else:
|
||||
new_acks.append(reverse_injections.get_original_id(ack))
|
||||
message.acks = tuple(new_acks)
|
||||
message.acks = tuple(
|
||||
reverse_injections.get_original_id(x) for x in message.acks
|
||||
if not reverse_injections.was_injected(x)
|
||||
)
|
||||
|
||||
if message.name == "PacketAck":
|
||||
if not self._rewrite_packet_ack(message, reverse_injections):
|
||||
if not self._rewrite_packet_ack(message, reverse_injections) and not message.acks:
|
||||
logging.debug(f"Dropping {direction} ack for injected packets!")
|
||||
# Let caller know this shouldn't be sent at all, it's strictly ACKs for
|
||||
# injected packets.
|
||||
@@ -76,7 +72,9 @@ class ProxiedCircuit(Circuit):
|
||||
elif message.name == "StartPingCheck":
|
||||
self._rewrite_start_ping_check(message, fwd_injections)
|
||||
|
||||
if not message.acks:
|
||||
if message.acks:
|
||||
message.send_flags |= PacketFlags.ACK
|
||||
else:
|
||||
message.send_flags &= ~PacketFlags.ACK
|
||||
return True
|
||||
|
||||
@@ -87,7 +85,6 @@ class ProxiedCircuit(Circuit):
|
||||
# This is an ACK for one the proxy injected, don't confuse
|
||||
# the other side by sending through the ACK
|
||||
if reverse_injections.was_injected(packet_id):
|
||||
self.collect_ack(packet_id)
|
||||
continue
|
||||
block["ID"] = reverse_injections.get_original_id(packet_id)
|
||||
new_blocks.append(block)
|
||||
@@ -103,8 +100,10 @@ class ProxiedCircuit(Circuit):
|
||||
new_id = fwd_injections.get_effective_id(orig_id)
|
||||
if orig_id != new_id:
|
||||
logging.debug("Rewrote oldest unacked %s -> %s" % (orig_id, new_id))
|
||||
# Use the proxy's oldest unacked instead if it's older.
|
||||
new_id = min((new_id, *self.unacked_reliable.keys()))
|
||||
# Get a list of unacked IDs for the direction this StartPingCheck is heading
|
||||
fwd_unacked = (a for (d, a) in self.unacked_reliable.keys() if d == message.direction)
|
||||
# Use the proxy's oldest unacked ID if it's older than the client's
|
||||
new_id = min((new_id, *fwd_unacked))
|
||||
message["PingID"]["OldestUnacked"] = new_id
|
||||
|
||||
def drop_message(self, message: Message, orig_direction=None):
|
||||
|
||||
@@ -111,6 +111,9 @@ class InterceptingLLUDPProxyProtocol(UDPProxyProtocol):
|
||||
LOG.error("No circuit for %r, dropping packet!" % (packet.far_addr,))
|
||||
return
|
||||
|
||||
# Process any ACKs for messages we injected first
|
||||
region.circuit.collect_acks(message)
|
||||
|
||||
if message.name == "AgentMovementComplete":
|
||||
self.session.main_region = region
|
||||
if region.handle is None:
|
||||
|
||||
@@ -252,11 +252,11 @@ class PacketIDTests(unittest.TestCase):
|
||||
self._send_message(Message('ChatFromViewer', flags=PacketFlags.RELIABLE))
|
||||
self._send_message(Message('ChatFromViewer', flags=PacketFlags.RELIABLE, packet_id=2))
|
||||
# Only the first, injected message should be queued for resends
|
||||
self.assertEqual({1}, set(self.circuit.unacked_reliable))
|
||||
self.assertEqual({(Direction.OUT, 1)}, set(self.circuit.unacked_reliable))
|
||||
|
||||
def test_reliable_resend_cadence(self):
|
||||
self._send_message(Message('ChatFromViewer', flags=PacketFlags.RELIABLE))
|
||||
resend_info = self.circuit.unacked_reliable[1]
|
||||
resend_info = self.circuit.unacked_reliable[(Direction.OUT, 1)]
|
||||
self.circuit.resend_unacked()
|
||||
# Should have been too soon to retry
|
||||
self.assertEqual(3, resend_info.tries_left)
|
||||
@@ -274,6 +274,19 @@ class PacketIDTests(unittest.TestCase):
|
||||
# Should have used up all the retry attempts and been kicked out of the retry queue
|
||||
self.assertEqual(set(), set(self.circuit.unacked_reliable))
|
||||
|
||||
def test_reliable_ack_collection(self):
|
||||
msg = Message('ChatFromViewer', flags=PacketFlags.RELIABLE)
|
||||
self._send_message(msg)
|
||||
self.assertEqual(1, len(self.circuit.unacked_reliable))
|
||||
# Shouldn't count, this is an ACK going in the wrong direction!
|
||||
ack_msg = Message("PacketAck", Block("Packets", ID=msg.packet_id))
|
||||
self.circuit.collect_acks(ack_msg)
|
||||
self.assertEqual(1, len(self.circuit.unacked_reliable))
|
||||
# But it should count if the ACK message is heading in
|
||||
ack_msg.direction = Direction.IN
|
||||
self.circuit.collect_acks(ack_msg)
|
||||
self.assertEqual(0, len(self.circuit.unacked_reliable))
|
||||
|
||||
def test_start_ping_check(self):
|
||||
# Should not break if no unacked
|
||||
self._send_message(Message(
|
||||
|
||||
Reference in New Issue
Block a user