diff --git a/hippolyzer/lib/proxy/xfer_manager.py b/hippolyzer/lib/proxy/xfer_manager.py index e7f7e0d..3532d62 100644 --- a/hippolyzer/lib/proxy/xfer_manager.py +++ b/hippolyzer/lib/proxy/xfer_manager.py @@ -24,6 +24,7 @@ _XFER_MESSAGES = {"AbortXfer", "ConfirmXferPacket", "RequestXfer", "SendXferPack MAX_CHUNK_SIZE = 1150 +ACK_AHEAD_MAX = 10 class Xfer: @@ -32,13 +33,17 @@ class Xfer: xfer_id: Optional[int] = None, direction: Direction = Direction.OUT, data: Optional[bytes] = None, + turbo: bool = False, ): self.xfer_id: Optional[int] = xfer_id self.chunks: Dict[int, bytes] = {} self.expected_size: Optional[int] = None self.size_known = asyncio.Future() self.error_code: Union[int, XferError] = 0 + self.next_ackable = 0 + self.turbo = turbo self.direction: Direction = direction + self.expected_chunks: Optional[int] = None self._future: asyncio.Future[Xfer] = asyncio.Future() if data is not None: @@ -99,7 +104,8 @@ class XferManager: vfile_id: Optional[UUID] = None, vfile_type: Optional[Union[AssetType, int]] = None, use_big_packets: bool = False, - delete_on_completion: bool = False, + delete_on_completion: bool = True, + turbo: bool = False, direction: Direction = Direction.OUT, ) -> Xfer: xfer_id = xfer_id if xfer_id is not None else random.getrandbits(64) @@ -117,7 +123,7 @@ class XferManager: ), direction=direction, )) - xfer = Xfer(xfer_id, direction=direction) + xfer = Xfer(xfer_id, direction=direction, turbo=turbo) asyncio.create_task(self._pump_xfer_replies(xfer)) return xfer @@ -160,14 +166,24 @@ class XferManager: xfer.size_known.set_result(xfer.expected_size) packet_data = packet_data[4:] - self._region.circuit.send_message(ProxiedMessage( - "ConfirmXferPacket", - Block("XferID", ID=xfer.xfer_id, Packet=packet_id.PacketID), - direction=xfer.direction, - )) + to_ack = (packet_id.PacketID,) + if xfer.turbo: + # ACK the next few packets we expect to be sent, if we haven't already + ack_max = packet_id.PacketID + ACK_AHEAD_MAX + to_ack = range(xfer.next_ackable, ack_max) + xfer.next_ackable = ack_max + for ack_id in to_ack: + self._region.circuit.send_message(ProxiedMessage( + "ConfirmXferPacket", + Block("XferID", ID=xfer.xfer_id, Packet=ack_id), + direction=xfer.direction, + )) xfer.chunks[packet_id.PacketID] = packet_data - if packet_id.IsEOF and not xfer.done(): + # We may be waiting on other packets so we can't end immediately. + if packet_id.IsEOF: + xfer.expected_chunks = packet_id.PacketID + 1 + if not xfer.done() and len(xfer.chunks) == xfer.expected_chunks: xfer.mark_done() def upload_asset( @@ -240,6 +256,7 @@ class XferManager: self, xfer: Xfer, request_predicate: Callable[[ProxiedMessage], bool], + wait_for_confirm: bool = True ): message_handler = self._region.message_handler request_msg = await message_handler.wait_for( @@ -260,6 +277,7 @@ class XferManager: direction=~request_msg.direction, )) # Don't care about the value, just want to know it was confirmed. - await message_handler.wait_for( - "ConfirmXferPacket", predicate=xfer.is_our_message, timeout=5000) + if wait_for_confirm: + await message_handler.wait_for( + "ConfirmXferPacket", predicate=xfer.is_our_message, timeout=5000) packet_id += 1