From 3b610fdfd11395e1da8d86df8664164ca3c41718 Mon Sep 17 00:00:00 2001 From: Salad Dais Date: Thu, 9 Dec 2021 05:30:35 +0000 Subject: [PATCH] Add awaitable send_reliable() --- hippolyzer/lib/base/message/circuit.py | 14 +++++++++++++- tests/proxy/test_messages.py | 4 +++- 2 files changed, 16 insertions(+), 2 deletions(-) diff --git a/hippolyzer/lib/base/message/circuit.py b/hippolyzer/lib/base/message/circuit.py index 3ac9bca..4f6b832 100644 --- a/hippolyzer/lib/base/message/circuit.py +++ b/hippolyzer/lib/base/message/circuit.py @@ -1,6 +1,7 @@ from __future__ import annotations import abc +import asyncio import copy import dataclasses import datetime as dt @@ -19,6 +20,7 @@ from .udpserializer import UDPMessageSerializer class ReliableResendInfo: last_resent: dt.datetime message: Message + completed: asyncio.Future = dataclasses.field(default_factory=asyncio.Future) tries_left: int = 3 @@ -81,12 +83,21 @@ class Circuit: # Temporary alias send_message = send + def send_reliable(self, message: Message, transport=None) -> asyncio.Future: + if message.packet_id and not message.injected: + raise ValueError("Not able to send non-synthetic message reliably!") + message.send_flags |= PacketFlags.RELIABLE + self.send(message, transport) + return self.unacked_reliable[(message.direction, message.packet_id)].completed + def collect_acks(self, message: Message): effective_acks = list(message.acks) if message.name == "PacketAck": effective_acks.extend(x["ID"] for x in message["Packets"]) for ack in effective_acks: - self.unacked_reliable.pop((~message.direction, ack), None) + resend_info = self.unacked_reliable.pop((~message.direction, ack), None) + if resend_info: + resend_info.completed.set_result(None) def resend_unacked(self): for resend_info in list(self.unacked_reliable.values()): @@ -100,6 +111,7 @@ class Circuit: if not resend_info.tries_left: logging.warning(f"Giving up on unacked {msg.packet_id}") del self.unacked_reliable[(msg.direction, msg.packet_id)] + resend_info.completed.set_exception(TimeoutError("Exceeded resend limit")) continue resend_info.last_resent = dt.datetime.now() msg.send_flags |= PacketFlags.RESENT diff --git a/tests/proxy/test_messages.py b/tests/proxy/test_messages.py index 4b3f533..f7f9665 100644 --- a/tests/proxy/test_messages.py +++ b/tests/proxy/test_messages.py @@ -276,16 +276,18 @@ class PacketIDTests(unittest.IsolatedAsyncioTestCase): def test_reliable_ack_collection(self): msg = Message('ChatFromViewer', flags=PacketFlags.RELIABLE) - self._send_message(msg) + fut = self.circuit.send_reliable(msg) self.assertEqual(1, len(self.circuit.unacked_reliable)) # Shouldn't count, this is an ACK going in the wrong direction! ack_msg = Message("PacketAck", Block("Packets", ID=msg.packet_id)) self.circuit.collect_acks(ack_msg) self.assertEqual(1, len(self.circuit.unacked_reliable)) + self.assertFalse(fut.done()) # But it should count if the ACK message is heading in ack_msg.direction = Direction.IN self.circuit.collect_acks(ack_msg) self.assertEqual(0, len(self.circuit.unacked_reliable)) + self.assertTrue(fut.done()) def test_start_ping_check(self): # Should not break if no unacked