Split LEAPProtocol out of LEAPClient

This commit is contained in:
Salad Dais
2022-09-18 18:05:16 +00:00
parent e6423d2f43
commit 87d5e8340b
2 changed files with 109 additions and 81 deletions

View File

@@ -81,7 +81,7 @@ async def client_connected(client: LEAPClient):
# Clear out the textbox, note that this does _not_ work when path is specified!
# TODO: clearing a textbox isn't so nice. CTL+A doesn't work as expected even without a path,
# it leaves a capital "A" in the text editor. We get rid of it by doing backspace right after.
window_api.key_press(keysym="a", mask=["CTL"])
window_api.key_press(mask=["CTL"], keysym="a")
window_api.key_press(keysym="Backsp")
# Type some text

View File

@@ -24,50 +24,30 @@ import weakref
from hippolyzer.lib.base import llsd
class ConnectionStatus(enum.Enum):
READY = enum.auto()
CONNECTING = enum.auto()
CONNECTED = enum.auto()
DISCONNECTED = enum.auto()
@dataclasses.dataclass
class ListenerDetails:
listener: Optional[str] = dataclasses.field(default_factory=lambda: "PythonListener-%s" % uuid.uuid4())
queues: Set[asyncio.Queue] = dataclasses.field(default_factory=set)
class LEAPClient:
# TODO: better listener creation support
def __init__(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
self._reader = reader
self._writer = writer
"""Client for script -> viewer communication over the LEAP protocol"""
def __init__(self, protocol: AbstractLEAPProtocol):
self._protocol = protocol
# Pump used for receiving replies
self._reply_pump: Optional[str] = None
# Pump used for sending leap meta-commands to the viewer (getAPIs, etc.)
self.cmd_pump: Optional[str] = None
# Map of req id -> future held by requester to send responses to
# TODO: LRU dict with cancel on evict.
self._reply_futs: Dict[uuid.UUID, asyncio.Future] = {}
self._pump_listeners: Dict[str, ListenerDetails] = collections.defaultdict(ListenerDetails)
self._connection_status = ConnectionStatus.READY
self._drain_task = None
@property
def connected(self) -> bool:
return self._connection_status == ConnectionStatus.CONNECTED
@property
def address(self) -> Optional[Tuple]:
return self._writer.get_extra_info('peername', None)
async def connect(self) -> None:
"""Receive the "hello" message from the viewer and start the message pump"""
assert self._connection_status == ConnectionStatus.READY
self._connection_status = ConnectionStatus.CONNECTING
try:
welcome_message = await self._read_message()
welcome_message = await self._protocol.read_message()
self._reply_pump = welcome_message['pump']
self.cmd_pump = welcome_message['data']['command']
@@ -81,8 +61,8 @@ class LEAPClient:
"""Read and handle inbound messages in a background task"""
async def _pump_messages_forever():
try:
while not self._writer.is_closing() and not self._reader.at_eof():
self.handle_message(await self._read_message())
while not self._protocol.closed:
self.handle_message(await self._protocol.read_message())
except asyncio.IncompleteReadError:
pass
finally:
@@ -95,9 +75,9 @@ class LEAPClient:
def disconnect(self) -> None:
"""Close the connection and clean up any pending request futures"""
if self.connected:
logging.info('closing LEAP connection from %r' % (self.address,))
logging.info('closing LEAP connection')
self._connection_status = ConnectionStatus.DISCONNECTED
self._writer.close()
self._protocol.close()
# Clean up any pending request futures
for fut in list(self._reply_futs.values()):
@@ -131,58 +111,31 @@ class LEAPClient:
Post the event is done synchronously, only waiting for the reply is done async.
"""
assert self.connected
# If you don't pass in a dict for data, we have nowhere to stuff `reqid`.
# That means no reply tracking, meaning no future.
fut = None
if isinstance(data, dict):
# Store some state so we can track replies
# If we expect a reply to this event, we need to do some extra bookkeeping.
# There are apparently some commands for which we can never expect to get a reply.
# Don't add a reqid or reply fut map entry in that case, since it will never be resolved.
if expect_reply:
# If you don't pass in a dict for data, we have nowhere to stuff `reqid`.
# That means no reply tracking, meaning no future.
if not isinstance(data, dict):
raise ValueError(f"Must send a dict in `data` if you want a reply, you sent {data!r}")
# We need to mutate the dict, make a copy so that we don't mess with the caller's version.
data = data.copy()
# Tell the viewer the pump to send replies to
data["reply"] = self._reply_pump
# There are apparently some commands for which we can never expect to get a reply.
# Don't add a reqid or reply fut map entry in that case, since it will never be resolved.
if expect_reply:
# Tell the viewer the pump to send replies to
data["reply"] = self._reply_pump
req_id = uuid.uuid4()
data["reqid"] = req_id
req_id = uuid.uuid4()
data["reqid"] = req_id
fut = asyncio.Future()
# The future will be cleaned up when the Future is done.
fut.add_done_callback(self._cleanup_request_future)
self._reply_futs[req_id] = fut
fut = asyncio.Future()
# The future will be cleaned up when the Future is done.
fut.add_done_callback(self._cleanup_request_future)
self._reply_futs[req_id] = fut
elif expect_reply:
raise ValueError(f"Must send a dict in `data` if you want a reply, you sent {data!r}")
self._write_message(pump, data)
self._protocol.write_message(pump, data)
return fut
def _write_message(self, pump: str, data: Any) -> None:
assert self.connected
ser = llsd.format_notation({"pump": pump, "data": data})
payload = bytearray(str(len(ser)).encode("utf8"))
payload.extend(b":")
payload.extend(ser)
self._writer.write(payload)
# We're in sync context, we need to schedule draining the socket, which is async.
# If a drain is already scheduled then we don't need to reschedule.
if not self._drain_task:
self._drain_task = asyncio.create_task(self._drain_soon())
async def _drain_soon(self) -> None:
self._drain_task = None
await self._writer.drain()
async def _read_message(self) -> Any:
"""Read a single inbound LEAP message"""
assert self._connection_status in (ConnectionStatus.CONNECTED, ConnectionStatus.CONNECTING)
length = int((await self._reader.readuntil(b':')).decode("utf8").strip()[:-1])
if length > 0xffFFff:
raise ValueError(f"Unreasonable LEAP payload length of {length}")
parsed = llsd.parse_notation((await self._reader.readexactly(length)).strip())
return parsed
@contextlib.asynccontextmanager
async def listen_scoped(self, source_pump: str) -> AsyncContextManager[Callable[[], Awaitable[Any]]]:
"""Subscribe to events published on source_pump, allow awaiting them"""
@@ -239,11 +192,7 @@ class LEAPClient:
raise KeyError(f"Couldn't find {msg_queue!r} in pump listeners")
def handle_message(self, message: Any) -> bool:
"""
Handle an inbound message and try to route it to the right recipient
TODO: Events, somehow. Maybe a catch-all event as well?
"""
"""Handle an inbound message and try to route it to the right recipient"""
if not isinstance(message, dict):
logging.warning(f"Received a non-map message: {message!r}")
return False
@@ -287,6 +236,85 @@ class LEAPClient:
return
class ConnectionStatus(enum.Enum):
READY = enum.auto()
CONNECTING = enum.auto()
CONNECTED = enum.auto()
DISCONNECTED = enum.auto()
@dataclasses.dataclass
class ListenerDetails:
# We can only have one listener with a given name active at a time. Give each listener a unique name.
listener: Optional[str] = dataclasses.field(default_factory=lambda: "PythonListener-%s" % uuid.uuid4())
queues: Set[asyncio.Queue] = dataclasses.field(default_factory=set)
class AbstractLEAPProtocol(abc.ABC):
"""Interface for a class representing communication with a LEAP peer"""
closed: bool
@abc.abstractmethod
def close(self) -> None:
pass
@abc.abstractmethod
def write_message(self, pump: str, data: Any) -> None:
pass
@abc.abstractmethod
async def read_message(self) -> Dict:
pass
class LEAPProtocol(AbstractLEAPProtocol):
"""Wrapper for communication with a LEAP peer over an asyncio reader/writer pair"""
def __init__(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
self._reader = reader
self._writer = writer
self._parser = llsd.LLSDNotationParser()
self._formatter = llsd.LLSDNotationFormatter()
self._drain_task = None
@property
def closed(self) -> bool:
return self._writer.is_closing() or self._reader.at_eof()
def close(self):
if not self._writer.is_closing():
self._writer.write_eof()
self._writer.close()
def write_message(self, pump: str, data: Any) -> None:
assert not self._writer.is_closing()
ser = self._formatter.format({"pump": pump, "data": data})
payload = bytearray(str(len(ser)).encode("utf8"))
payload.extend(b":")
payload.extend(ser)
self._writer.write(payload)
# We're in sync context, we need to schedule draining the socket, which is async.
# If a drain is already scheduled then we don't need to reschedule.
if not self._drain_task:
self._drain_task = asyncio.create_task(self._drain_soon())
async def _drain_soon(self) -> None:
self._drain_task = None
await self._writer.drain()
async def read_message(self) -> Dict:
assert not self._reader.at_eof()
# Length is everything up until the first colon we see, stripping the colon off.
length = int((await self._reader.readuntil(b':'))[:-1].decode("utf8"))
if length > 0xffFFff:
raise ValueError(f"Unreasonable LEAP payload length of {length}")
# Everything after the colon is LLSD
parsed = self._parser.parse(await self._reader.readexactly(length))
if not isinstance(parsed, dict):
raise ValueError(f"Expected LEAP message to be a dict, got {parsed!r}")
return parsed
class LEAPAPIWrapper(abc.ABC):
"""Base class for classes wrapping specific LEAP APIs"""
PUMP_NAME: Optional[str] = None
@@ -504,8 +532,8 @@ class LEAPBridgeServer:
self._client_connected_cb = client_connected_cb
async def handle_connection(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
client = LEAPClient(reader, writer)
logging.info('Accepting LEAP connection from %r' % (client.address,))
client = LEAPClient(LEAPProtocol(reader, writer))
logging.info('Accepting LEAP connection from %r' % (writer.get_extra_info("peername", None),))
await client.connect()
self.clients.add(client)