159 lines
6.2 KiB
Python
159 lines
6.2 KiB
Python
from __future__ import annotations
|
|
|
|
import abc
|
|
import asyncio
|
|
import copy
|
|
import dataclasses
|
|
import datetime as dt
|
|
import logging
|
|
from collections import deque
|
|
from typing import *
|
|
from typing import Optional
|
|
|
|
from .message_handler import MessageHandler
|
|
from ..network.transport import AbstractUDPTransport, UDPPacket, Direction, ADDR_TUPLE
|
|
from .message import Block, Message
|
|
from .msgtypes import PacketFlags
|
|
from .udpserializer import UDPMessageSerializer
|
|
|
|
|
|
@dataclasses.dataclass()
class ReliableResendInfo:
    """Bookkeeping for one reliable message we sent that has not been ACKed yet."""

    # Timestamp of the most recent (re)send of `message`
    last_resent: dt.datetime
    # The message itself, kept so it can be sent again
    message: Message
    # Future for the ACK outcome; every instance gets its own fresh future
    completed: asyncio.Future = dataclasses.field(default_factory=asyncio.Future)
    # How many more send attempts remain before giving up
    tries_left: int = 10
|
|
|
|
|
|
class Circuit:
    """A UDP circuit between a near host and a far host.

    Assigns packet IDs to outgoing messages, remembers reliable messages we
    originated until they are ACKed (so they can be resent), and remembers
    recently seen reliable packet IDs so duplicate deliveries can be detected.
    """

    def __init__(
            self,
            near_host: Optional[ADDR_TUPLE],
            far_host: ADDR_TUPLE,
            transport: Optional[AbstractUDPTransport] = None,
    ):
        self.near_host: Optional[ADDR_TUPLE] = near_host
        self.host: ADDR_TUPLE = far_host
        self.is_alive = True
        self.transport = transport
        self.serializer = UDPMessageSerializer()
        self.last_packet_at = dt.datetime.now()
        # Next packet ID handed out by prepare_message()
        self.packet_id_base = 0
        # Reliable messages we're responsible for resending, keyed on (direction, packet_id)
        self.unacked_reliable: Dict[Tuple[Direction, int], ReliableResendInfo] = {}
        # Minimum number of seconds between resends of an unacked message
        self.resend_every: float = 3.0
        # Reliable messages that we've already seen and handled, for resend suppression
        self.seen_reliable: deque[int] = deque(maxlen=1_000)

    def _send_prepared_message(self, message: Message, transport=None):
        """Serialize an already-prepared message and send it as a datagram.

        Serialization failures are logged along with the message contents and
        then re-raised.
        """
        try:
            serialized = self.serializer.serialize(message)
        except Exception:
            logging.exception("Failed to serialize: %r", message.to_dict())
            raise
        return self.send_datagram(serialized, message.direction, transport=transport)

    def disconnect(self):
        """Mark the circuit as dead and abandon all pending reliable sends.

        Futures handed out by send_reliable() for still-unacked messages are
        cancelled. Otherwise anyone awaiting them would wait forever, because
        the entries are discarded and no ACK can resolve them anymore.
        """
        self.packet_id_base = 0
        for resend_info in self.unacked_reliable.values():
            if not resend_info.completed.done():
                resend_info.completed.cancel()
        self.unacked_reliable.clear()
        self.is_alive = False

    def send_datagram(self, data: bytes, direction: Direction, transport=None):
        """Wrap `data` in a UDPPacket and send it.

        Outgoing (Direction.OUT) packets go from `near_host` to `host`. Any
        other direction goes from `host` to `near_host`. The packet is sent over
        `transport` if given, else over this circuit's own transport.

        :returns: the UDPPacket that was sent
        """
        self.last_packet_at = dt.datetime.now()
        src_addr, dst_addr = self.host, self.near_host
        if direction == Direction.OUT:
            src_addr, dst_addr = self.near_host, self.host

        packet = UDPPacket(src_addr, dst_addr, data, direction)
        (transport or self.transport).send_packet(packet)
        return packet

    def prepare_message(self, message: Message):
        """Assign the next packet ID to `message` and finalize it.

        Also syncs the ACK send flag with whether the message carries acks.

        :raises RuntimeError: if the message was already finalized
        :returns: True
        """
        if message.finalized:
            raise RuntimeError(f"Trying to re-send finalized {message!r}")
        message.packet_id = self.packet_id_base
        self.packet_id_base += 1
        if message.acks:
            message.send_flags |= PacketFlags.ACK
        else:
            message.send_flags &= ~PacketFlags.ACK
        # If it was queued, it's not anymore
        message.queued = False
        message.finalized = True
        return True

    def send(self, message: Message, transport=None) -> UDPPacket:
        """Prepare `message` and send it, returning the sent UDPPacket.

        Reliable synthetic messages sent over this circuit's own transport are
        tracked in `unacked_reliable` so resend_unacked() can resend them.
        """
        if self.prepare_message(message):
            # If the message originates from us then we're responsible for resends.
            if message.reliable and message.synthetic and not transport:
                self.unacked_reliable[(message.direction, message.packet_id)] = ReliableResendInfo(
                    last_resent=dt.datetime.now(),
                    message=message,
                )
        return self._send_prepared_message(message, transport)

    def send_reliable(self, message: Message, transport=None) -> asyncio.Future:
        """send() wrapper that always sends reliably and allows `await`ing ACK receipt

        :raises ValueError: if the message isn't synthetic, or if an explicit
            `transport` is given. send() only tracks resends for messages sent
            over this circuit's own transport, so there'd be no future to return.
            The check runs before anything is sent.
        """
        if not message.synthetic:
            raise ValueError("Not able to send non-synthetic message reliably!")
        # Mirror the `not transport` condition send() uses to decide whether to track
        if transport:
            raise ValueError("Not able to track reliable sends over an alternate transport!")
        message.send_flags |= PacketFlags.RELIABLE
        self.send(message, transport)
        return self.unacked_reliable[(message.direction, message.packet_id)].completed

    def collect_acks(self, message: Message):
        """Resolve pending reliable sends acknowledged by `message`.

        Both the message's appended acks and, for PacketAck messages, the IDs
        in its Packets blocks are treated as ACKs. They apply to messages
        travelling in the opposite direction.
        """
        effective_acks = list(message.acks)
        if message.name == "PacketAck":
            effective_acks.extend(x["ID"] for x in message["Packets"])
        for ack in effective_acks:
            resend_info = self.unacked_reliable.pop((~message.direction, ack), None)
            # The future may already be done (e.g. cancelled by its awaiter);
            # set_result() on a done future would raise InvalidStateError.
            if resend_info and not resend_info.completed.done():
                resend_info.completed.set_result(None)

    def resend_unacked(self):
        """Resend tracked reliable messages whose ACK is overdue.

        Each message is resent at most once every `resend_every` seconds, and
        each resend consumes one of its `tries_left`. When no tries remain, the
        message is dropped from tracking and its future fails with TimeoutError.
        """
        resend_interval = dt.timedelta(seconds=self.resend_every)
        # Iterate over a snapshot because entries may be removed below
        for resend_info in list(self.unacked_reliable.values()):
            now = dt.datetime.now()
            # Not time to attempt a resend yet
            if now - resend_info.last_resent < resend_interval:
                continue

            # Shallow copy so the RESENT flag isn't set on the original message
            msg = copy.copy(resend_info.message)
            resend_info.tries_left -= 1
            # We were on our last try and we never received an ack.
            # `<= 0` so a non-positive starting value can't cause endless resends.
            if resend_info.tries_left <= 0:
                logging.warning("Giving up on unacked %s", msg.packet_id)
                self.unacked_reliable.pop((msg.direction, msg.packet_id), None)
                if not resend_info.completed.done():
                    resend_info.completed.set_exception(TimeoutError("Exceeded resend limit"))
                continue

            resend_info.last_resent = now
            msg.send_flags |= PacketFlags.RESENT
            self._send_prepared_message(msg)

    def send_acks(self, to_ack: Sequence[int], direction=Direction.OUT, packet_id=None):
        """Send a PacketAck message acknowledging each packet ID in `to_ack`."""
        logging.debug("%r acking %r", direction, to_ack)
        # TODO: maybe tack this onto `.acks` for next message?
        message = Message('PacketAck', *[Block('Packets', ID=x) for x in to_ack])
        # NOTE(review): send() -> prepare_message() overwrites packet_id with the
        # circuit's next ID, so this assignment has no visible effect; confirm intent.
        message.packet_id = packet_id
        message.direction = direction
        self.send(message)

    def track_reliable(self, packet_id: int) -> bool:
        """Tracks a reliable packet, returning if it's a new message

        Only the most recent `seen_reliable.maxlen` IDs are remembered.
        """
        if packet_id in self.seen_reliable:
            return False
        self.seen_reliable.append(packet_id)
        return True

    def __repr__(self):
        return "<%s %r : %r>" % (self.__class__.__name__, self.near_host, self.host)
|
|
|
|
|
|
class ConnectionHolder(abc.ABC):
    """
    Base for any object that owns both a circuit and a message handler.

    Pass one of these around instead of a bare (circuit, message handler)
    pair. A ConnectionHolder usually stands for a region or a client, and one
    region or client may go through several circuits during a session (region
    restarts and the like).
    """

    # The circuit currently associated with this holder, if any
    circuit: Optional[Circuit]
    # Message handler associated with this holder
    message_handler: MessageHandler[Message, str]
|