Add awaitable send_reliable()

This commit is contained in:
Salad Dais
2021-12-09 05:30:35 +00:00
parent 8b93c5eefa
commit 3b610fdfd1
2 changed files with 16 additions and 2 deletions

View File

@@ -1,6 +1,7 @@
from __future__ import annotations
import abc
import asyncio
import copy
import dataclasses
import datetime as dt
@@ -19,6 +20,7 @@ from .udpserializer import UDPMessageSerializer
class ReliableResendInfo:
last_resent: dt.datetime
message: Message
completed: asyncio.Future = dataclasses.field(default_factory=asyncio.Future)
tries_left: int = 3
@@ -81,12 +83,21 @@ class Circuit:
# Temporary alias
send_message = send
def send_reliable(self, message: Message, transport=None) -> asyncio.Future:
if message.packet_id and not message.injected:
raise ValueError("Not able to send non-synthetic message reliably!")
message.send_flags |= PacketFlags.RELIABLE
self.send(message, transport)
return self.unacked_reliable[(message.direction, message.packet_id)].completed
def collect_acks(self, message: Message):
effective_acks = list(message.acks)
if message.name == "PacketAck":
effective_acks.extend(x["ID"] for x in message["Packets"])
for ack in effective_acks:
self.unacked_reliable.pop((~message.direction, ack), None)
resend_info = self.unacked_reliable.pop((~message.direction, ack), None)
if resend_info:
resend_info.completed.set_result(None)
def resend_unacked(self):
for resend_info in list(self.unacked_reliable.values()):
@@ -100,6 +111,7 @@ class Circuit:
if not resend_info.tries_left:
logging.warning(f"Giving up on unacked {msg.packet_id}")
del self.unacked_reliable[(msg.direction, msg.packet_id)]
resend_info.completed.set_exception(TimeoutError("Exceeded resend limit"))
continue
resend_info.last_resent = dt.datetime.now()
msg.send_flags |= PacketFlags.RESENT

View File

@@ -276,16 +276,18 @@ class PacketIDTests(unittest.IsolatedAsyncioTestCase):
def test_reliable_ack_collection(self):
msg = Message('ChatFromViewer', flags=PacketFlags.RELIABLE)
self._send_message(msg)
fut = self.circuit.send_reliable(msg)
self.assertEqual(1, len(self.circuit.unacked_reliable))
# Shouldn't count, this is an ACK going in the wrong direction!
ack_msg = Message("PacketAck", Block("Packets", ID=msg.packet_id))
self.circuit.collect_acks(ack_msg)
self.assertEqual(1, len(self.circuit.unacked_reliable))
self.assertFalse(fut.done())
# But it should count if the ACK message is heading in
ack_msg.direction = Direction.IN
self.circuit.collect_acks(ack_msg)
self.assertEqual(0, len(self.circuit.unacked_reliable))
self.assertTrue(fut.done())
def test_start_ping_check(self):
# Should not break if no unacked