diff --git a/hippolyzer/lib/proxy/circuit.py b/hippolyzer/lib/proxy/circuit.py index e63226b..8664f3b 100644 --- a/hippolyzer/lib/proxy/circuit.py +++ b/hippolyzer/lib/proxy/circuit.py @@ -34,14 +34,13 @@ class ProxiedCircuit(Circuit): return self.out_injections, self.in_injections return self.in_injections, self.out_injections - def prepare_message(self, message: Message, direction=None): + def prepare_message(self, message: Message): if message.finalized: raise RuntimeError(f"Trying to re-send finalized {message!r}") if message.queued: # This is due to be dropped, nothing should be sending the original raise RuntimeError(f"Trying to send original of queued {message!r}") - direction = direction or getattr(message, 'direction') - fwd_injections, reverse_injections = self._get_injections(direction) + fwd_injections, reverse_injections = self._get_injections(message.direction) message.finalized = True @@ -65,7 +64,7 @@ class ProxiedCircuit(Circuit): if message.name == "PacketAck": if not self._rewrite_packet_ack(message, reverse_injections) and not message.acks: - logging.debug(f"Dropping {direction} ack for injected packets!") + logging.debug(f"Dropping {message.direction} ack for injected packets!") # Let caller know this shouldn't be sent at all, it's strictly ACKs for # injected packets. return False @@ -106,13 +105,12 @@ class ProxiedCircuit(Circuit): new_id = min((new_id, *fwd_unacked)) message["PingID"]["OldestUnacked"] = new_id - def drop_message(self, message: Message, orig_direction=None): + def drop_message(self, message: Message): if message.finalized: raise RuntimeError(f"Trying to drop finalized {message!r}") if message.packet_id is None: return - orig_direction = orig_direction or message.direction - fwd_injections, reverse_injections = self._get_injections(orig_direction) + fwd_injections, reverse_injections = self._get_injections(message.direction) fwd_injections.mark_dropped(message.packet_id) message.dropped = True @@ -120,7 +118,7 @@ class ProxiedCircuit(Circuit): # Was sent reliably, tell the other end that we saw it and to shut up. if message.reliable: - self.send_acks([message.packet_id], ~orig_direction) + self.send_acks([message.packet_id], ~message.direction) # This packet had acks for the other end, send them in a separate PacketAck effective_acks = tuple( @@ -128,7 +126,7 @@ class ProxiedCircuit(Circuit): if not reverse_injections.was_injected(x) ) if effective_acks: - self.send_acks(effective_acks, orig_direction, packet_id=message.packet_id) + self.send_acks(effective_acks, message.direction, packet_id=message.packet_id) class InjectionTracker: diff --git a/tests/proxy/test_messages.py b/tests/proxy/test_messages.py index f7f9665..c6fee88 100644 --- a/tests/proxy/test_messages.py +++ b/tests/proxy/test_messages.py @@ -178,10 +178,7 @@ class PacketIDTests(unittest.IsolatedAsyncioTestCase): def test_drop_proxied_message(self): self._send_message(Message('ChatFromViewer', packet_id=1)) - self.circuit.drop_message( - Message('ChatFromViewer', packet_id=2, flags=PacketFlags.RELIABLE), - Direction.OUT, - ) + self.circuit.drop_message(Message('ChatFromViewer', packet_id=2, flags=PacketFlags.RELIABLE)) self._send_message(Message('ChatFromViewer', packet_id=3)) self.assertSequenceEqual(self.circuit.sent_simple, [ @@ -193,10 +190,7 @@ class PacketIDTests(unittest.IsolatedAsyncioTestCase): def test_unreliable_proxied_message(self): self._send_message(Message('ChatFromViewer', packet_id=1)) - self.circuit.drop_message( - Message('ChatFromViewer', packet_id=2), - Direction.OUT, - ) + self.circuit.drop_message(Message('ChatFromViewer', packet_id=2)) self._send_message(Message('ChatFromViewer', packet_id=3)) self.assertSequenceEqual(self.circuit.sent_simple, [ @@ -209,10 +203,7 @@ class PacketIDTests(unittest.IsolatedAsyncioTestCase): self._send_message(Message('ChatFromViewer', packet_id=2)) self._send_message(Message('ChatFromViewer', packet_id=3)) self._send_message(Message('ChatFromSimulator'), outgoing=False) - self.circuit.drop_message( - Message('ChatFromViewer', packet_id=4, acks=(4,)), - Direction.OUT, - ) + self.circuit.drop_message(Message('ChatFromViewer', packet_id=4, acks=(4,))) self._send_message(Message('ChatFromViewer', packet_id=5)) self.assertSequenceEqual(self.circuit.sent_simple, [