diff --git a/hippolyzer/apps/proxy_leapreceiver.py b/hippolyzer/apps/proxy_leapreceiver.py index 1038e02..0a1d502 100644 --- a/hippolyzer/apps/proxy_leapreceiver.py +++ b/hippolyzer/apps/proxy_leapreceiver.py @@ -81,7 +81,7 @@ async def client_connected(client: LEAPClient): # Clear out the textbox, note that this does _not_ work when path is specified! # TODO: clearing a textbox isn't so nice. CTL+A doesn't work as expected even without a path, # it leaves a capital "A" in the text editor. We get rid of it by doing backspace right after. - window_api.key_press(keysym="a", mask=["CTL"]) + window_api.key_press(mask=["CTL"], keysym="a") window_api.key_press(keysym="Backsp") # Type some text diff --git a/hippolyzer/lib/proxy/leap.py b/hippolyzer/lib/proxy/leap.py index 9436ce7..c29f978 100644 --- a/hippolyzer/lib/proxy/leap.py +++ b/hippolyzer/lib/proxy/leap.py @@ -24,50 +24,30 @@ import weakref from hippolyzer.lib.base import llsd -class ConnectionStatus(enum.Enum): - READY = enum.auto() - CONNECTING = enum.auto() - CONNECTED = enum.auto() - DISCONNECTED = enum.auto() - - -@dataclasses.dataclass -class ListenerDetails: - listener: Optional[str] = dataclasses.field(default_factory=lambda: "PythonListener-%s" % uuid.uuid4()) - queues: Set[asyncio.Queue] = dataclasses.field(default_factory=set) - - class LEAPClient: - # TODO: better listener creation support - def __init__(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter): - self._reader = reader - self._writer = writer + """Client for script -> viewer communication over the LEAP protocol""" + def __init__(self, protocol: AbstractLEAPProtocol): + self._protocol = protocol # Pump used for receiving replies self._reply_pump: Optional[str] = None # Pump used for sending leap meta-commands to the viewer (getAPIs, etc.) self.cmd_pump: Optional[str] = None # Map of req id -> future held by requester to send responses to - # TODO: LRU dict with cancel on evict. self._reply_futs: Dict[uuid.UUID, asyncio.Future] = {} self._pump_listeners: Dict[str, ListenerDetails] = collections.defaultdict(ListenerDetails) self._connection_status = ConnectionStatus.READY - self._drain_task = None @property def connected(self) -> bool: return self._connection_status == ConnectionStatus.CONNECTED - @property - def address(self) -> Optional[Tuple]: - return self._writer.get_extra_info('peername', None) - async def connect(self) -> None: """Receive the "hello" message from the viewer and start the message pump""" assert self._connection_status == ConnectionStatus.READY self._connection_status = ConnectionStatus.CONNECTING try: - welcome_message = await self._read_message() + welcome_message = await self._protocol.read_message() self._reply_pump = welcome_message['pump'] self.cmd_pump = welcome_message['data']['command'] @@ -81,8 +61,8 @@ class LEAPClient: """Read and handle inbound messages in a background task""" async def _pump_messages_forever(): try: - while not self._writer.is_closing() and not self._reader.at_eof(): - self.handle_message(await self._read_message()) + while not self._protocol.closed: + self.handle_message(await self._protocol.read_message()) except asyncio.IncompleteReadError: pass finally: @@ -95,9 +75,9 @@ class LEAPClient: def disconnect(self) -> None: """Close the connection and clean up any pending request futures""" if self.connected: - logging.info('closing LEAP connection from %r' % (self.address,)) + logging.info('closing LEAP connection') self._connection_status = ConnectionStatus.DISCONNECTED - self._writer.close() + self._protocol.close() # Clean up any pending request futures for fut in list(self._reply_futs.values()): @@ -131,58 +111,31 @@ class LEAPClient: Post the event is done synchronously, only waiting for the reply is done async. """ assert self.connected - # If you don't pass in a dict for data, we have nowhere to stuff `reqid`. - # That means no reply tracking, meaning no future. fut = None - if isinstance(data, dict): - # Store some state so we can track replies + # If we expect a reply to this event, we need to do some extra bookkeeping. + # There are apparently some commands for which we can never expect to get a reply. + # Don't add a reqid or reply fut map entry in that case, since it will never be resolved. + if expect_reply: + # If you don't pass in a dict for data, we have nowhere to stuff `reqid`. + # That means no reply tracking, meaning no future. + if not isinstance(data, dict): + raise ValueError(f"Must send a dict in `data` if you want a reply, you sent {data!r}") + # We need to mutate the dict, make a copy so that we don't mess with the caller's version. data = data.copy() + # Tell the viewer the pump to send replies to + data["reply"] = self._reply_pump - # There are apparently some commands for which we can never expect to get a reply. - # Don't add a reqid or reply fut map entry in that case, since it will never be resolved. - if expect_reply: - # Tell the viewer the pump to send replies to - data["reply"] = self._reply_pump + req_id = uuid.uuid4() + data["reqid"] = req_id - req_id = uuid.uuid4() - data["reqid"] = req_id + fut = asyncio.Future() + # The future will be cleaned up when the Future is done. + fut.add_done_callback(self._cleanup_request_future) + self._reply_futs[req_id] = fut - fut = asyncio.Future() - # The future will be cleaned up when the Future is done. - fut.add_done_callback(self._cleanup_request_future) - self._reply_futs[req_id] = fut - elif expect_reply: - raise ValueError(f"Must send a dict in `data` if you want a reply, you sent {data!r}") - - self._write_message(pump, data) + self._protocol.write_message(pump, data) return fut - def _write_message(self, pump: str, data: Any) -> None: - assert self.connected - ser = llsd.format_notation({"pump": pump, "data": data}) - payload = bytearray(str(len(ser)).encode("utf8")) - payload.extend(b":") - payload.extend(ser) - self._writer.write(payload) - # We're in sync context, we need to schedule draining the socket, which is async. - # If a drain is already scheduled then we don't need to reschedule. - if not self._drain_task: - self._drain_task = asyncio.create_task(self._drain_soon()) - - async def _drain_soon(self) -> None: - self._drain_task = None - await self._writer.drain() - - async def _read_message(self) -> Any: - """Read a single inbound LEAP message""" - assert self._connection_status in (ConnectionStatus.CONNECTED, ConnectionStatus.CONNECTING) - - length = int((await self._reader.readuntil(b':')).decode("utf8").strip()[:-1]) - if length > 0xffFFff: - raise ValueError(f"Unreasonable LEAP payload length of {length}") - parsed = llsd.parse_notation((await self._reader.readexactly(length)).strip()) - return parsed - @contextlib.asynccontextmanager async def listen_scoped(self, source_pump: str) -> AsyncContextManager[Callable[[], Awaitable[Any]]]: """Subscribe to events published on source_pump, allow awaiting them""" @@ -239,11 +192,7 @@ class LEAPClient: raise KeyError(f"Couldn't find {msg_queue!r} in pump listeners") def handle_message(self, message: Any) -> bool: - """ - Handle an inbound message and try to route it to the right recipient - - TODO: Events, somehow. Maybe a catch-all event as well? - """ + """Handle an inbound message and try to route it to the right recipient""" if not isinstance(message, dict): logging.warning(f"Received a non-map message: {message!r}") return False @@ -287,6 +236,85 @@ class LEAPClient: return +class ConnectionStatus(enum.Enum): + READY = enum.auto() + CONNECTING = enum.auto() + CONNECTED = enum.auto() + DISCONNECTED = enum.auto() + + +@dataclasses.dataclass +class ListenerDetails: + # We can only have one listener with a given name active at a time. Give each listener a unique name. + listener: Optional[str] = dataclasses.field(default_factory=lambda: "PythonListener-%s" % uuid.uuid4()) + queues: Set[asyncio.Queue] = dataclasses.field(default_factory=set) + + +class AbstractLEAPProtocol(abc.ABC): + """Interface for a class representing communication with a LEAP peer""" + closed: bool + + @abc.abstractmethod + def close(self) -> None: + pass + + @abc.abstractmethod + def write_message(self, pump: str, data: Any) -> None: + pass + + @abc.abstractmethod + async def read_message(self) -> Dict: + pass + + +class LEAPProtocol(AbstractLEAPProtocol): + """Wrapper for communication with a LEAP peer over an asyncio reader/writer pair""" + def __init__(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter): + self._reader = reader + self._writer = writer + self._parser = llsd.LLSDNotationParser() + self._formatter = llsd.LLSDNotationFormatter() + self._drain_task = None + + @property + def closed(self) -> bool: + return self._writer.is_closing() or self._reader.at_eof() + + def close(self): + if not self._writer.is_closing(): + self._writer.write_eof() + self._writer.close() + + def write_message(self, pump: str, data: Any) -> None: + assert not self._writer.is_closing() + ser = self._formatter.format({"pump": pump, "data": data}) + payload = bytearray(str(len(ser)).encode("utf8")) + payload.extend(b":") + payload.extend(ser) + self._writer.write(payload) + # We're in sync context, we need to schedule draining the socket, which is async. + # If a drain is already scheduled then we don't need to reschedule. + if not self._drain_task: + self._drain_task = asyncio.create_task(self._drain_soon()) + + async def _drain_soon(self) -> None: + self._drain_task = None + await self._writer.drain() + + async def read_message(self) -> Dict: + assert not self._reader.at_eof() + + # Length is everything up until the first colon we see, stripping the colon off. + length = int((await self._reader.readuntil(b':'))[:-1].decode("utf8")) + if length > 0xffFFff: + raise ValueError(f"Unreasonable LEAP payload length of {length}") + # Everything after the colon is LLSD + parsed = self._parser.parse(await self._reader.readexactly(length)) + if not isinstance(parsed, dict): + raise ValueError(f"Expected LEAP message to be a dict, got {parsed!r}") + return parsed + + class LEAPAPIWrapper(abc.ABC): """Base class for classes wrapping specific LEAP APIs""" PUMP_NAME: Optional[str] = None @@ -504,8 +532,8 @@ class LEAPBridgeServer: self._client_connected_cb = client_connected_cb async def handle_connection(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter): - client = LEAPClient(reader, writer) - logging.info('Accepting LEAP connection from %r' % (client.address,)) + client = LEAPClient(LEAPProtocol(reader, writer)) + logging.info('Accepting LEAP connection from %r' % (writer.get_extra_info("peername", None),)) await client.connect() self.clients.add(client)