292 lines
11 KiB
Python
292 lines
11 KiB
Python
"""
|
|
Managers for inbound and outbound xfer as well as the AssetUploadRequest flow
|
|
"""
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import enum
|
|
import random
|
|
from typing import *
|
|
|
|
from hippolyzer.lib.base.datatypes import UUID, RawBytes
|
|
from hippolyzer.lib.base.helpers import create_logged_task
|
|
from hippolyzer.lib.base.message.data_packer import TemplateDataPacker
|
|
from hippolyzer.lib.base.message.message import Block, Message
|
|
from hippolyzer.lib.base.message.msgtypes import MsgType, PacketFlags
|
|
from hippolyzer.lib.base.network.transport import Direction
|
|
from hippolyzer.lib.base.message.circuit import ConnectionHolder
|
|
from hippolyzer.lib.base.templates import XferPacket, XferFilePath, AssetType, XferError
|
|
|
|
_XFER_MESSAGES = {"AbortXfer", "ConfirmXferPacket", "RequestXfer", "SendXferPacket"}
|
|
|
|
|
|
MAX_CHUNK_SIZE = 1150
|
|
ACK_AHEAD_MAX = 10
|
|
|
|
|
|
class Xfer:
|
|
def __init__(
|
|
self,
|
|
xfer_id: Optional[int] = None,
|
|
direction: Direction = Direction.OUT,
|
|
data: Optional[bytes] = None,
|
|
turbo: bool = False,
|
|
):
|
|
self.xfer_id: Optional[int] = xfer_id
|
|
self.chunks: Dict[int, bytes] = {}
|
|
self.expected_size: Optional[int] = None
|
|
self.size_known = asyncio.Future()
|
|
self.error_code: Union[int, XferError] = 0
|
|
self.next_ackable = 0
|
|
self.turbo = turbo
|
|
self.direction: Direction = direction
|
|
self.expected_chunks: Optional[int] = None
|
|
self._future: asyncio.Future[Xfer] = asyncio.Future()
|
|
|
|
if data is not None:
|
|
# Prepend the expected length field to the first chunk
|
|
if not isinstance(data, RawBytes):
|
|
data = TemplateDataPacker.pack(len(data), MsgType.MVT_S32) + data
|
|
chunk_num = 0
|
|
while data:
|
|
self.chunks[chunk_num] = data[:MAX_CHUNK_SIZE]
|
|
data = data[MAX_CHUNK_SIZE:]
|
|
chunk_num += 1
|
|
|
|
def reassemble_chunks(self) -> bytes:
|
|
assembled = bytearray()
|
|
for _, data in sorted(self.chunks.items()):
|
|
assembled.extend(data)
|
|
return assembled
|
|
|
|
def mark_done(self):
|
|
self._future.set_result(self)
|
|
|
|
def done(self) -> bool:
|
|
return self._future.done()
|
|
|
|
def cancelled(self) -> bool:
|
|
return self._future.cancelled()
|
|
|
|
def is_our_message(self, message):
|
|
return message["XferID"]["ID"] == self.xfer_id
|
|
|
|
def cancel(self) -> bool:
|
|
if not self.size_known.done():
|
|
self.size_known.cancel()
|
|
return self._future.cancel()
|
|
|
|
def set_exception(self, exc: Union[type, BaseException]) -> None:
|
|
if not self.size_known.done():
|
|
self.size_known.set_exception(exc)
|
|
return self._future.set_exception(exc)
|
|
|
|
def __await__(self) -> Generator[Any, None, Xfer]:
|
|
return self._future.__await__()
|
|
|
|
|
|
class UploadStrategy(enum.IntEnum):
|
|
XFER = enum.auto()
|
|
ASSET_UPLOAD_REQUEST = enum.auto()
|
|
|
|
|
|
class XferManager:
|
|
def __init__(
|
|
self,
|
|
connection_holder: ConnectionHolder,
|
|
secure_session_id: Optional[UUID] = None,
|
|
):
|
|
self._connection_holder = connection_holder
|
|
self._secure_session_id = secure_session_id
|
|
|
|
def request(
|
|
self, xfer_id: Optional[int] = None,
|
|
file_name: Union[bytes, str, None] = None,
|
|
file_path: Optional[Union[XferFilePath, int]] = None,
|
|
vfile_id: Optional[UUID] = None,
|
|
vfile_type: Optional[Union[AssetType, int]] = None,
|
|
use_big_packets: bool = False,
|
|
delete_on_completion: bool = True,
|
|
turbo: bool = False,
|
|
direction: Direction = Direction.OUT,
|
|
) -> Xfer:
|
|
xfer_id = xfer_id if xfer_id is not None else random.getrandbits(64)
|
|
self._connection_holder.circuit.send(Message(
|
|
'RequestXfer',
|
|
Block(
|
|
'XferID',
|
|
ID=xfer_id,
|
|
Filename=file_name or b'',
|
|
FilePath=file_path or XferFilePath.NONE,
|
|
DeleteOnCompletion=delete_on_completion,
|
|
UseBigPackets=use_big_packets,
|
|
VFileID=vfile_id or UUID(),
|
|
VFileType=vfile_type or AssetType.NONE,
|
|
),
|
|
direction=direction,
|
|
))
|
|
xfer = Xfer(xfer_id, direction=direction, turbo=turbo)
|
|
create_logged_task(self._pump_xfer_replies(xfer), "Xfer Pump")
|
|
return xfer
|
|
|
|
async def _pump_xfer_replies(self, xfer: Xfer):
|
|
with self._connection_holder.message_handler.subscribe_async(
|
|
_XFER_MESSAGES,
|
|
predicate=xfer.is_our_message,
|
|
) as get_msg:
|
|
while not xfer.done():
|
|
try:
|
|
msg: Message = await asyncio.wait_for(get_msg(), 5.0)
|
|
except asyncio.exceptions.TimeoutError as e:
|
|
xfer.set_exception(e)
|
|
return
|
|
|
|
if xfer.cancelled():
|
|
# AbortXfer doesn't seem to work on in-progress Xfers.
|
|
# Just let any new packets drop on the floor.
|
|
return
|
|
|
|
if msg.name == "SendXferPacket":
|
|
self._handle_send_xfer_packet(msg, xfer)
|
|
elif msg.name == "AbortXfer":
|
|
xfer.error_code = msg["XferID"][0].deserialize_var("Result")
|
|
xfer.set_exception(
|
|
ConnectionAbortedError(f"Xfer failed with {xfer.error_code!r}")
|
|
)
|
|
|
|
def _handle_send_xfer_packet(self, msg: Message, xfer: Xfer):
|
|
# Received a SendXfer for an Xfer we sent ourselves
|
|
packet_id: XferPacket = msg["XferID"][0].deserialize_var("Packet")
|
|
packet_data = msg["DataPacket"]["Data"]
|
|
# First 4 bytes are expected total data length
|
|
if packet_id.PacketID == 0:
|
|
# Yes, S32. Only used as a hint so buffers can be pre-allocated,
|
|
# EOF bit determines when the data actually ends.
|
|
xfer.expected_size = TemplateDataPacker.unpack(packet_data[:4], MsgType.MVT_S32)
|
|
# Don't re-set if we get a resend of packet 0
|
|
if not xfer.size_known.done():
|
|
xfer.size_known.set_result(xfer.expected_size)
|
|
packet_data = packet_data[4:]
|
|
|
|
to_ack = (packet_id.PacketID,)
|
|
if xfer.turbo:
|
|
# ACK the next few packets we expect to be sent, if we haven't already
|
|
ack_max = packet_id.PacketID + ACK_AHEAD_MAX
|
|
to_ack = range(xfer.next_ackable, ack_max)
|
|
xfer.next_ackable = ack_max
|
|
for ack_id in to_ack:
|
|
self._connection_holder.circuit.send_reliable(Message(
|
|
"ConfirmXferPacket",
|
|
Block("XferID", ID=xfer.xfer_id, Packet=ack_id),
|
|
direction=xfer.direction,
|
|
flags=PacketFlags.RELIABLE,
|
|
))
|
|
|
|
xfer.chunks[packet_id.PacketID] = packet_data
|
|
# We may be waiting on other packets so we can't end immediately.
|
|
if packet_id.IsEOF:
|
|
xfer.expected_chunks = packet_id.PacketID + 1
|
|
if not xfer.done() and len(xfer.chunks) == xfer.expected_chunks:
|
|
xfer.mark_done()
|
|
|
|
def upload_asset(
|
|
self,
|
|
asset_type: AssetType,
|
|
data: Union[bytes, str],
|
|
store_local: bool = False,
|
|
temp_file: bool = False,
|
|
transaction_id: Optional[UUID] = None,
|
|
upload_strategy: Optional[UploadStrategy] = None,
|
|
) -> asyncio.Future[UUID]:
|
|
"""Upload an asset through the Xfer upload path"""
|
|
if not transaction_id:
|
|
transaction_id = UUID.random()
|
|
if isinstance(data, str):
|
|
data = data.encode("utf8")
|
|
|
|
# Small amounts of data can be sent inline, decide based on size
|
|
if upload_strategy is None:
|
|
if len(data) >= MAX_CHUNK_SIZE:
|
|
upload_strategy = UploadStrategy.XFER
|
|
else:
|
|
upload_strategy = UploadStrategy.ASSET_UPLOAD_REQUEST
|
|
|
|
xfer = None
|
|
inline_data = b''
|
|
if upload_strategy == UploadStrategy.XFER:
|
|
xfer = Xfer(data=data)
|
|
else:
|
|
inline_data = data
|
|
|
|
self._connection_holder.circuit.send(Message(
|
|
"AssetUploadRequest",
|
|
Block(
|
|
"AssetBlock",
|
|
TransactionID=transaction_id,
|
|
Type=asset_type,
|
|
Tempfile=temp_file,
|
|
StoreLocal=store_local,
|
|
AssetData=inline_data,
|
|
),
|
|
flags=PacketFlags.RELIABLE
|
|
))
|
|
fut = asyncio.Future()
|
|
asyncio.create_task(self._pump_asset_upload(xfer, transaction_id, fut))
|
|
return fut
|
|
|
|
async def _pump_asset_upload(self, xfer: Optional[Xfer], transaction_id: UUID, fut: asyncio.Future):
|
|
message_handler = self._connection_holder.message_handler
|
|
# We'll receive an Xfer request for the asset we're uploading.
|
|
# asset ID is determined by hashing secure session ID with chosen transaction ID.
|
|
asset_id: UUID = UUID.combine(transaction_id, self._secure_session_id)
|
|
try:
|
|
# Only need to do this if we're using the xfer upload strategy, otherwise all the
|
|
# data was already sent in the AssetUploadRequest and we don't expect a RequestXfer.
|
|
def request_predicate(request_msg: Message):
|
|
return request_msg["XferID"]["VFileID"] == asset_id
|
|
if xfer is not None:
|
|
await self.serve_inbound_xfer_request(xfer, request_predicate)
|
|
|
|
def complete_predicate(complete_msg: Message):
|
|
return complete_msg["AssetBlock"]["UUID"] == asset_id
|
|
msg = await message_handler.wait_for(('AssetUploadComplete',), predicate=complete_predicate)
|
|
if msg["AssetBlock"]["Success"] == 1:
|
|
fut.set_result(asset_id)
|
|
else:
|
|
fut.set_exception(RuntimeError(f"Xfer for transaction {transaction_id} failed"))
|
|
|
|
except asyncio.TimeoutError as e:
|
|
fut.set_exception(e)
|
|
|
|
async def serve_inbound_xfer_request(
|
|
self,
|
|
xfer: Xfer,
|
|
request_predicate: Callable[[Message], bool],
|
|
wait_for_confirm: bool = True
|
|
):
|
|
message_handler = self._connection_holder.message_handler
|
|
request_msg = await message_handler.wait_for(
|
|
('RequestXfer',), predicate=request_predicate, timeout=5.0)
|
|
xfer.xfer_id = request_msg["XferID"]["ID"]
|
|
|
|
packet_id = 0
|
|
while xfer.chunks:
|
|
chunk = xfer.chunks.pop(packet_id)
|
|
# EOF if there are no chunks left
|
|
packet_val = XferPacket(PacketID=packet_id, IsEOF=not bool(xfer.chunks))
|
|
# We just send reliably since I don't care to implement the Xfer-specific
|
|
# resend-on-unacked nastiness
|
|
_ = self._connection_holder.circuit.send_reliable(Message(
|
|
"SendXferPacket",
|
|
Block("XferID", ID=xfer.xfer_id, Packet_=packet_val),
|
|
Block("DataPacket", Data=chunk),
|
|
# Send this towards the sender of the RequestXfer
|
|
direction=~request_msg.direction,
|
|
flags=PacketFlags.RELIABLE,
|
|
))
|
|
# Don't care about the value, just want to know it was confirmed.
|
|
if wait_for_confirm:
|
|
await message_handler.wait_for(
|
|
("ConfirmXferPacket",), predicate=xfer.is_our_message, timeout=5.0)
|
|
packet_id += 1
|