diff --git a/hippolyzer/lib/proxy/xfer_manager.py b/hippolyzer/lib/proxy/xfer_manager.py index abcb11d..87b676b 100644 --- a/hippolyzer/lib/proxy/xfer_manager.py +++ b/hippolyzer/lib/proxy/xfer_manager.py @@ -23,9 +23,16 @@ if TYPE_CHECKING: _XFER_MESSAGES = {"AbortXfer", "ConfirmXferPacket", "RequestXfer", "SendXferPacket"} +MAX_CHUNK_SIZE = 1150 + + class Xfer: - def __init__(self, xfer_id: Optional[int] = None, direction: Direction = Direction.OUT): - super().__init__() + def __init__( + self, + xfer_id: Optional[int] = None, + direction: Direction = Direction.OUT, + data: Optional[bytes] = None, + ): self.xfer_id: Optional[int] = xfer_id self.chunks: Dict[int, bytes] = {} self.expected_size: Optional[int] = None @@ -34,6 +41,16 @@ class Xfer: self.direction: Direction = direction self._future: asyncio.Future[Xfer] = asyncio.Future() + if data is not None: + # Prepend the expected length field to the first chunk + if not isinstance(data, RawBytes): + data = TemplateDataPacker.pack(len(data), MsgType.MVT_S32) + data + chunk_num = 0 + while data: + self.chunks[chunk_num] = data[:MAX_CHUNK_SIZE] + data = data[MAX_CHUNK_SIZE:] + chunk_num += 1 + def reassemble_chunks(self) -> bytes: assembled = bytearray() for _, data in sorted(self.chunks.items()): @@ -168,7 +185,7 @@ class XferManager: # Small amounts of data can be sent inline, decide based on size if upload_strategy is None: - if len(data) >= 1150: + if len(data) >= MAX_CHUNK_SIZE: upload_strategy = UploadStrategy.XFER else: upload_strategy = UploadStrategy.ASSET_UPLOAD_REQUEST @@ -176,15 +193,7 @@ class XferManager: xfer = None inline_data = b'' if upload_strategy == UploadStrategy.XFER: - # Prepend the expected length field to the first chunk - if not isinstance(data, RawBytes): - data = TemplateDataPacker.pack(len(data), MsgType.MVT_S32) + data - xfer = Xfer() - chunk_num = 0 - while data: - xfer.chunks[chunk_num] = data[:1150] - data = data[1150:] - chunk_num += 1 + xfer = Xfer(data=data) else: inline_data = data @@ -211,28 +220,10 @@ class XferManager: try: # Only need to do this if we're using the xfer upload strategy, otherwise all the # data was already sent in the AssetUploadRequest and we don't expect a RequestXfer. + def request_predicate(request_msg: ProxiedMessage): + return request_msg["XferID"]["VFileID"] == asset_id if xfer is not None: - def request_predicate(request_msg: ProxiedMessage): - return request_msg["XferID"]["VFileID"] == asset_id - msg = await message_handler.wait_for( - 'RequestXfer', predicate=request_predicate, timeout=5000) - xfer.xfer_id = msg["XferID"]["ID"] - - packet_id = 0 - # TODO: No resend yet. If it's lost, it's lost. - while xfer.chunks: - chunk = xfer.chunks.pop(packet_id) - # EOF if there are no chunks left - packet_val = XferPacket(PacketID=packet_id, IsEOF=not bool(xfer.chunks)) - self._region.circuit.send_message(ProxiedMessage( - "SendXferPacket", - Block("XferID", ID=xfer.xfer_id, Packet_=packet_val), - Block("DataPacket", Data=chunk), - )) - # Don't care about the value, just want to know it was confirmed. - await message_handler.wait_for( - "ConfirmXferPacket", predicate=xfer.is_our_message, timeout=5000) - packet_id += 1 + await self.serve_inbound_xfer_request(xfer, request_predicate) def complete_predicate(complete_msg: ProxiedMessage): return complete_msg["AssetBlock"]["UUID"] == asset_id @@ -244,3 +235,31 @@ class XferManager: except asyncio.TimeoutError as e: fut.set_exception(e) + + async def serve_inbound_xfer_request( + self, + xfer: Xfer, + request_predicate: Callable[[ProxiedMessage], bool], + ): + message_handler = self._region.message_handler + request_msg = await message_handler.wait_for( + 'RequestXfer', predicate=request_predicate, timeout=5000) + xfer.xfer_id = request_msg["XferID"]["ID"] + + packet_id = 0 + # TODO: No resend yet. If it's lost, it's lost. + while xfer.chunks: + chunk = xfer.chunks.pop(packet_id) + # EOF if there are no chunks left + packet_val = XferPacket(PacketID=packet_id, IsEOF=not bool(xfer.chunks)) + self._region.circuit.send_message(ProxiedMessage( + "SendXferPacket", + Block("XferID", ID=xfer.xfer_id, Packet_=packet_val), + Block("DataPacket", Data=chunk), + # Send this towards the sender of the RequestXfer + direction=~request_msg.direction, + )) + # Don't care about the value, just want to know it was confirmed. + await message_handler.wait_for( + "ConfirmXferPacket", predicate=xfer.is_our_message, timeout=5000) + packet_id += 1