Add awaitable send_reliable()
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import abc
|
||||
import asyncio
|
||||
import copy
|
||||
import dataclasses
|
||||
import datetime as dt
|
||||
@@ -19,6 +20,7 @@ from .udpserializer import UDPMessageSerializer
|
||||
class ReliableResendInfo:
|
||||
last_resent: dt.datetime
|
||||
message: Message
|
||||
completed: asyncio.Future = dataclasses.field(default_factory=asyncio.Future)
|
||||
tries_left: int = 3
|
||||
|
||||
|
||||
@@ -81,12 +83,21 @@ class Circuit:
|
||||
# Temporary alias
|
||||
send_message = send
|
||||
|
||||
def send_reliable(self, message: Message, transport=None) -> asyncio.Future:
|
||||
if message.packet_id and not message.injected:
|
||||
raise ValueError("Not able to send non-synthetic message reliably!")
|
||||
message.send_flags |= PacketFlags.RELIABLE
|
||||
self.send(message, transport)
|
||||
return self.unacked_reliable[(message.direction, message.packet_id)].completed
|
||||
|
||||
def collect_acks(self, message: Message):
|
||||
effective_acks = list(message.acks)
|
||||
if message.name == "PacketAck":
|
||||
effective_acks.extend(x["ID"] for x in message["Packets"])
|
||||
for ack in effective_acks:
|
||||
self.unacked_reliable.pop((~message.direction, ack), None)
|
||||
resend_info = self.unacked_reliable.pop((~message.direction, ack), None)
|
||||
if resend_info:
|
||||
resend_info.completed.set_result(None)
|
||||
|
||||
def resend_unacked(self):
|
||||
for resend_info in list(self.unacked_reliable.values()):
|
||||
@@ -100,6 +111,7 @@ class Circuit:
|
||||
if not resend_info.tries_left:
|
||||
logging.warning(f"Giving up on unacked {msg.packet_id}")
|
||||
del self.unacked_reliable[(msg.direction, msg.packet_id)]
|
||||
resend_info.completed.set_exception(TimeoutError("Exceeded resend limit"))
|
||||
continue
|
||||
resend_info.last_resent = dt.datetime.now()
|
||||
msg.send_flags |= PacketFlags.RESENT
|
||||
|
||||
@@ -276,16 +276,18 @@ class PacketIDTests(unittest.IsolatedAsyncioTestCase):
|
||||
|
||||
def test_reliable_ack_collection(self):
|
||||
msg = Message('ChatFromViewer', flags=PacketFlags.RELIABLE)
|
||||
self._send_message(msg)
|
||||
fut = self.circuit.send_reliable(msg)
|
||||
self.assertEqual(1, len(self.circuit.unacked_reliable))
|
||||
# Shouldn't count, this is an ACK going in the wrong direction!
|
||||
ack_msg = Message("PacketAck", Block("Packets", ID=msg.packet_id))
|
||||
self.circuit.collect_acks(ack_msg)
|
||||
self.assertEqual(1, len(self.circuit.unacked_reliable))
|
||||
self.assertFalse(fut.done())
|
||||
# But it should count if the ACK message is heading in
|
||||
ack_msg.direction = Direction.IN
|
||||
self.circuit.collect_acks(ack_msg)
|
||||
self.assertEqual(0, len(self.circuit.unacked_reliable))
|
||||
self.assertTrue(fut.done())
|
||||
|
||||
def test_start_ping_check(self):
|
||||
# Should not break if no unacked
|
||||
|
||||
Reference in New Issue
Block a user