""" Outbound Transfer only. """ from __future__ import annotations import asyncio import dataclasses from typing import * from hippolyzer.lib.base.datatypes import UUID from hippolyzer.lib.base.helpers import proxify from hippolyzer.lib.base.message.message import Block from hippolyzer.lib.proxy.message import ProxiedMessage from hippolyzer.lib.proxy.templates import ( TransferRequestParamsBase, TransferChannelType, TransferSourceType, TransferStatus, ) if TYPE_CHECKING: from hippolyzer.lib.proxy.region import ProxiedRegion _TRANSFER_MESSAGES = {"TransferInfo", "TransferPacket", "TransferAbort"} class Transfer: def __init__(self, transfer_id: UUID): super().__init__() self.transfer_id: UUID = transfer_id self.chunks: Dict[int, bytes] = {} self.expected_size: Optional[int] = None self.size_known = asyncio.Future() self.error_code: Union[int] = 0 self._future: asyncio.Future[Transfer] = asyncio.Future() def reassemble_chunks(self) -> bytes: assembled = bytearray() for _, data in sorted(self.chunks.items()): assembled.extend(data) return assembled def mark_done(self): self._future.set_result(self) def done(self) -> bool: return self._future.done() def cancelled(self) -> bool: return self._future.cancelled() def is_our_message(self, message: ProxiedMessage): if "TransferData" in message.blocks: transfer_block = message["TransferData"][0] else: transfer_block = message["TransferInfo"][0] return transfer_block["TransferID"] == self.transfer_id def cancel(self) -> bool: if not self.size_known.done(): self.size_known.cancel() return self._future.cancel() def set_exception(self, exc: Union[type, BaseException]) -> None: if not self.size_known.done(): self.size_known.set_exception(exc) return self._future.set_exception(exc) def __await__(self) -> Generator[Any, None, Transfer]: return self._future.__await__() class TransferManager: def __init__(self, region: ProxiedRegion): self._region: ProxiedRegion = proxify(region) def request( self, *, source_type: TransferSourceType, params: TransferRequestParamsBase, transfer_id: Optional[UUID] = None, channel_type: TransferChannelType = TransferChannelType.ASSET, priority: float = 101.0, ) -> Transfer: transfer_id = transfer_id if transfer_id is not None else UUID.random() params_dict = dataclasses.asdict(params) # Fill in any missing AgentID or SessionID attrs if the params type has them if params_dict.get("AgentID", dataclasses.MISSING) is None: params.AgentID = self._region.session().agent_id if params_dict.get("SessionID", dataclasses.MISSING) is None: params.SessionID = self._region.session().id self._region.circuit.send_message(ProxiedMessage( 'TransferRequest', Block( 'TransferInfo', TransferID=transfer_id, ChannelType=channel_type, SourceType=source_type, Priority=priority, Params_=params, ), )) transfer = Transfer(transfer_id) asyncio.create_task(self._pump_transfer_replies(transfer)) return transfer async def _pump_transfer_replies(self, transfer: Transfer): # Subscribe to message related to our transfer while we're in this block with self._region.message_handler.subscribe_async( _TRANSFER_MESSAGES, predicate=transfer.is_our_message ) as get_msg: while not transfer.done(): try: msg: ProxiedMessage = await asyncio.wait_for(get_msg(), 5.0) except TimeoutError as e: transfer.set_exception(e) return if transfer.cancelled(): # Just let any new packets drop on the floor. return if msg.name == "TransferPacket": self._handle_transfer_packet(msg, transfer) elif msg.name == "TransferInfo": self._handle_transfer_info(msg, transfer) elif msg.name == "TransferAbort": transfer.error_code = msg["TransferID"][0].deserialize_var("Result") transfer.set_exception( ConnectionAbortedError(f"Unknown failure") ) def _handle_transfer_packet(self, msg: ProxiedMessage, transfer: Transfer): transfer_block = msg["TransferData"][0] packet_id: int = transfer_block["Packet"] packet_data = transfer_block["Data"] transfer.chunks[packet_id] = packet_data if transfer_block["Status"] == TransferStatus.DONE: transfer.mark_done() def _handle_transfer_info(self, msg: ProxiedMessage, transfer: Transfer): transfer_block = msg["TransferInfo"][0] transfer.expected_size = transfer_block["Size"] # Don't re-set if we get a resend of packet 0 if not transfer.size_known.done(): transfer.size_known.set_result(transfer.expected_size) if transfer_block["Status"] != TransferStatus.OK: transfer.error_code = transfer_block['Status'] transfer.set_exception(ConnectionAbortedError( f"Transfer failed with {transfer_block['Status']}"))