diff --git a/hippolyzer/lib/base/message/circuit.py b/hippolyzer/lib/base/message/circuit.py index 981f88c..616562c 100644 --- a/hippolyzer/lib/base/message/circuit.py +++ b/hippolyzer/lib/base/message/circuit.py @@ -31,7 +31,7 @@ class Circuit: self.serializer = UDPMessageSerializer() self.last_packet_at = dt.datetime.now() self.packet_id_base = 0 - self.unacked_reliable: Dict[int, ReliableResendInfo] = {} + self.unacked_reliable: Dict[Tuple[Direction, int], ReliableResendInfo] = {} self.resend_every: float = 5.0 def _send_prepared_message(self, message: Message, transport=None): @@ -57,8 +57,10 @@ class Circuit: raise RuntimeError(f"Trying to re-send finalized {message!r}") message.packet_id = self.packet_id_base self.packet_id_base += 1 - if not message.acks: - message.send_flags &= PacketFlags.ACK + if message.acks: + message.send_flags |= PacketFlags.ACK + else: + message.send_flags &= ~PacketFlags.ACK # If it was queued, it's not anymore message.queued = False message.finalized = True @@ -70,16 +72,18 @@ class Circuit: if self.prepare_message(message): # If we injected the message then we're responsible for resends. if message.reliable and message.injected: - self.unacked_reliable[message.packet_id] = ReliableResendInfo( + self.unacked_reliable[(message.direction, message.packet_id)] = ReliableResendInfo( last_resent=dt.datetime.now(), message=message, ) return self._send_prepared_message(message, transport) - def collect_ack(self, packet_id: int): - resend_info = self.unacked_reliable.pop(packet_id, None) - if resend_info is None: - logging.warning(f"Got ack for unknown packed {packet_id}") + def collect_acks(self, message: Message): + effective_acks = list(message.acks) + if message.name == "PacketAck": + effective_acks.extend(x["ID"] for x in message["Packets"]) + for ack in effective_acks: + self.unacked_reliable.pop((~message.direction, ack), None) def resend_unacked(self): for resend_info in list(self.unacked_reliable.values()): @@ -92,7 +96,7 @@ class Circuit: # We were on our last try and we never received an ack if not resend_info.tries_left: logging.warning(f"Giving up on unacked {msg.packet_id}") - del self.unacked_reliable[msg.packet_id] + del self.unacked_reliable[(msg.direction, msg.packet_id)] continue resend_info.last_resent = dt.datetime.now() msg.send_flags |= PacketFlags.RESENT diff --git a/hippolyzer/lib/proxy/circuit.py b/hippolyzer/lib/proxy/circuit.py index a51c267..e63226b 100644 --- a/hippolyzer/lib/proxy/circuit.py +++ b/hippolyzer/lib/proxy/circuit.py @@ -58,17 +58,13 @@ class ProxiedCircuit(Circuit): (message.packet_id, message.name)) message.packet_id = fwd_injections.get_effective_id(message.packet_id) fwd_injections.track_seen(message.packet_id) - - new_acks = [] - for ack in message.acks: - if reverse_injections.was_injected(ack): - self.collect_ack(ack) - else: - new_acks.append(reverse_injections.get_original_id(ack)) - message.acks = tuple(new_acks) + message.acks = tuple( + reverse_injections.get_original_id(x) for x in message.acks + if not reverse_injections.was_injected(x) + ) if message.name == "PacketAck": - if not self._rewrite_packet_ack(message, reverse_injections): + if not self._rewrite_packet_ack(message, reverse_injections) and not message.acks: logging.debug(f"Dropping {direction} ack for injected packets!") # Let caller know this shouldn't be sent at all, it's strictly ACKs for # injected packets. @@ -76,7 +72,9 @@ class ProxiedCircuit(Circuit): elif message.name == "StartPingCheck": self._rewrite_start_ping_check(message, fwd_injections) - if not message.acks: + if message.acks: + message.send_flags |= PacketFlags.ACK + else: message.send_flags &= ~PacketFlags.ACK return True @@ -87,7 +85,6 @@ class ProxiedCircuit(Circuit): # This is an ACK for one the proxy injected, don't confuse # the other side by sending through the ACK if reverse_injections.was_injected(packet_id): - self.collect_ack(packet_id) continue block["ID"] = reverse_injections.get_original_id(packet_id) new_blocks.append(block) @@ -103,8 +100,10 @@ class ProxiedCircuit(Circuit): new_id = fwd_injections.get_effective_id(orig_id) if orig_id != new_id: logging.debug("Rewrote oldest unacked %s -> %s" % (orig_id, new_id)) - # Use the proxy's oldest unacked instead if it's older. - new_id = min((new_id, *self.unacked_reliable.keys())) + # Get a list of unacked IDs for the direction this StartPingCheck is heading + fwd_unacked = (a for (d, a) in self.unacked_reliable.keys() if d == message.direction) + # Use the proxy's oldest unacked ID if it's older than the client's + new_id = min((new_id, *fwd_unacked)) message["PingID"]["OldestUnacked"] = new_id def drop_message(self, message: Message, orig_direction=None): diff --git a/hippolyzer/lib/proxy/lludp_proxy.py b/hippolyzer/lib/proxy/lludp_proxy.py index a115952..44a76a9 100644 --- a/hippolyzer/lib/proxy/lludp_proxy.py +++ b/hippolyzer/lib/proxy/lludp_proxy.py @@ -111,6 +111,9 @@ class InterceptingLLUDPProxyProtocol(UDPProxyProtocol): LOG.error("No circuit for %r, dropping packet!" % (packet.far_addr,)) return + # Process any ACKs for messages we injected first + region.circuit.collect_acks(message) + if message.name == "AgentMovementComplete": self.session.main_region = region if region.handle is None: diff --git a/tests/proxy/test_messages.py b/tests/proxy/test_messages.py index 1cb0110..bbbf835 100644 --- a/tests/proxy/test_messages.py +++ b/tests/proxy/test_messages.py @@ -252,11 +252,11 @@ class PacketIDTests(unittest.TestCase): self._send_message(Message('ChatFromViewer', flags=PacketFlags.RELIABLE)) self._send_message(Message('ChatFromViewer', flags=PacketFlags.RELIABLE, packet_id=2)) # Only the first, injected message should be queued for resends - self.assertEqual({1}, set(self.circuit.unacked_reliable)) + self.assertEqual({(Direction.OUT, 1)}, set(self.circuit.unacked_reliable)) def test_reliable_resend_cadence(self): self._send_message(Message('ChatFromViewer', flags=PacketFlags.RELIABLE)) - resend_info = self.circuit.unacked_reliable[1] + resend_info = self.circuit.unacked_reliable[(Direction.OUT, 1)] self.circuit.resend_unacked() # Should have been too soon to retry self.assertEqual(3, resend_info.tries_left) @@ -274,6 +274,19 @@ class PacketIDTests(unittest.TestCase): # Should have used up all the retry attempts and been kicked out of the retry queue self.assertEqual(set(), set(self.circuit.unacked_reliable)) + def test_reliable_ack_collection(self): + msg = Message('ChatFromViewer', flags=PacketFlags.RELIABLE) + self._send_message(msg) + self.assertEqual(1, len(self.circuit.unacked_reliable)) + # Shouldn't count, this is an ACK going in the wrong direction! + ack_msg = Message("PacketAck", Block("Packets", ID=msg.packet_id)) + self.circuit.collect_acks(ack_msg) + self.assertEqual(1, len(self.circuit.unacked_reliable)) + # But it should count if the ACK message is heading in + ack_msg.direction = Direction.IN + self.circuit.collect_acks(ack_msg) + self.assertEqual(0, len(self.circuit.unacked_reliable)) + def test_start_ping_check(self): # Should not break if no unacked self._send_message(Message(