Fix ACK collection logic for injected reliable messages

This commit is contained in:
Salad Dais
2021-12-08 22:14:36 +00:00
parent d486aa130d
commit eb6406bca4
4 changed files with 43 additions and 24 deletions

View File

@@ -31,7 +31,7 @@ class Circuit:
self.serializer = UDPMessageSerializer()
self.last_packet_at = dt.datetime.now()
self.packet_id_base = 0
self.unacked_reliable: Dict[int, ReliableResendInfo] = {}
self.unacked_reliable: Dict[Tuple[Direction, int], ReliableResendInfo] = {}
self.resend_every: float = 5.0
def _send_prepared_message(self, message: Message, transport=None):
@@ -57,8 +57,10 @@ class Circuit:
raise RuntimeError(f"Trying to re-send finalized {message!r}")
message.packet_id = self.packet_id_base
self.packet_id_base += 1
if not message.acks:
message.send_flags &= PacketFlags.ACK
if message.acks:
message.send_flags |= PacketFlags.ACK
else:
message.send_flags &= ~PacketFlags.ACK
# If it was queued, it's not anymore
message.queued = False
message.finalized = True
@@ -70,16 +72,18 @@ class Circuit:
if self.prepare_message(message):
# If we injected the message then we're responsible for resends.
if message.reliable and message.injected:
self.unacked_reliable[message.packet_id] = ReliableResendInfo(
self.unacked_reliable[(message.direction, message.packet_id)] = ReliableResendInfo(
last_resent=dt.datetime.now(),
message=message,
)
return self._send_prepared_message(message, transport)
def collect_ack(self, packet_id: int):
resend_info = self.unacked_reliable.pop(packet_id, None)
if resend_info is None:
logging.warning(f"Got ack for unknown packed {packet_id}")
def collect_acks(self, message: Message):
effective_acks = list(message.acks)
if message.name == "PacketAck":
effective_acks.extend(x["ID"] for x in message["Packets"])
for ack in effective_acks:
self.unacked_reliable.pop((~message.direction, ack), None)
def resend_unacked(self):
for resend_info in list(self.unacked_reliable.values()):
@@ -92,7 +96,7 @@ class Circuit:
# We were on our last try and we never received an ack
if not resend_info.tries_left:
logging.warning(f"Giving up on unacked {msg.packet_id}")
del self.unacked_reliable[msg.packet_id]
del self.unacked_reliable[(msg.direction, msg.packet_id)]
continue
resend_info.last_resent = dt.datetime.now()
msg.send_flags |= PacketFlags.RESENT

View File

@@ -58,17 +58,13 @@ class ProxiedCircuit(Circuit):
(message.packet_id, message.name))
message.packet_id = fwd_injections.get_effective_id(message.packet_id)
fwd_injections.track_seen(message.packet_id)
new_acks = []
for ack in message.acks:
if reverse_injections.was_injected(ack):
self.collect_ack(ack)
else:
new_acks.append(reverse_injections.get_original_id(ack))
message.acks = tuple(new_acks)
message.acks = tuple(
reverse_injections.get_original_id(x) for x in message.acks
if not reverse_injections.was_injected(x)
)
if message.name == "PacketAck":
if not self._rewrite_packet_ack(message, reverse_injections):
if not self._rewrite_packet_ack(message, reverse_injections) and not message.acks:
logging.debug(f"Dropping {direction} ack for injected packets!")
# Let caller know this shouldn't be sent at all, it's strictly ACKs for
# injected packets.
@@ -76,7 +72,9 @@ class ProxiedCircuit(Circuit):
elif message.name == "StartPingCheck":
self._rewrite_start_ping_check(message, fwd_injections)
if not message.acks:
if message.acks:
message.send_flags |= PacketFlags.ACK
else:
message.send_flags &= ~PacketFlags.ACK
return True
@@ -87,7 +85,6 @@ class ProxiedCircuit(Circuit):
# This is an ACK for one the proxy injected, don't confuse
# the other side by sending through the ACK
if reverse_injections.was_injected(packet_id):
self.collect_ack(packet_id)
continue
block["ID"] = reverse_injections.get_original_id(packet_id)
new_blocks.append(block)
@@ -103,8 +100,10 @@ class ProxiedCircuit(Circuit):
new_id = fwd_injections.get_effective_id(orig_id)
if orig_id != new_id:
logging.debug("Rewrote oldest unacked %s -> %s" % (orig_id, new_id))
# Use the proxy's oldest unacked instead if it's older.
new_id = min((new_id, *self.unacked_reliable.keys()))
# Get a list of unacked IDs for the direction this StartPingCheck is heading
fwd_unacked = (a for (d, a) in self.unacked_reliable.keys() if d == message.direction)
# Use the proxy's oldest unacked ID if it's older than the client's
new_id = min((new_id, *fwd_unacked))
message["PingID"]["OldestUnacked"] = new_id
def drop_message(self, message: Message, orig_direction=None):

View File

@@ -111,6 +111,9 @@ class InterceptingLLUDPProxyProtocol(UDPProxyProtocol):
LOG.error("No circuit for %r, dropping packet!" % (packet.far_addr,))
return
# Process any ACKs for messages we injected first
region.circuit.collect_acks(message)
if message.name == "AgentMovementComplete":
self.session.main_region = region
if region.handle is None:

View File

@@ -252,11 +252,11 @@ class PacketIDTests(unittest.TestCase):
self._send_message(Message('ChatFromViewer', flags=PacketFlags.RELIABLE))
self._send_message(Message('ChatFromViewer', flags=PacketFlags.RELIABLE, packet_id=2))
# Only the first, injected message should be queued for resends
self.assertEqual({1}, set(self.circuit.unacked_reliable))
self.assertEqual({(Direction.OUT, 1)}, set(self.circuit.unacked_reliable))
def test_reliable_resend_cadence(self):
self._send_message(Message('ChatFromViewer', flags=PacketFlags.RELIABLE))
resend_info = self.circuit.unacked_reliable[1]
resend_info = self.circuit.unacked_reliable[(Direction.OUT, 1)]
self.circuit.resend_unacked()
# Should have been too soon to retry
self.assertEqual(3, resend_info.tries_left)
@@ -274,6 +274,19 @@ class PacketIDTests(unittest.TestCase):
# Should have used up all the retry attempts and been kicked out of the retry queue
self.assertEqual(set(), set(self.circuit.unacked_reliable))
def test_reliable_ack_collection(self):
msg = Message('ChatFromViewer', flags=PacketFlags.RELIABLE)
self._send_message(msg)
self.assertEqual(1, len(self.circuit.unacked_reliable))
# Shouldn't count, this is an ACK going in the wrong direction!
ack_msg = Message("PacketAck", Block("Packets", ID=msg.packet_id))
self.circuit.collect_acks(ack_msg)
self.assertEqual(1, len(self.circuit.unacked_reliable))
# But it should count if the ACK message is heading in
ack_msg.direction = Direction.IN
self.circuit.collect_acks(ack_msg)
self.assertEqual(0, len(self.circuit.unacked_reliable))
def test_start_ping_check(self):
# Should not break if no unacked
self._send_message(Message(