Split LEAPProtocol out of LEAPClient
This commit is contained in:
@@ -81,7 +81,7 @@ async def client_connected(client: LEAPClient):
|
||||
# Clear out the textbox, note that this does _not_ work when path is specified!
|
||||
# TODO: clearing a textbox isn't so nice. CTL+A doesn't work as expected even without a path,
|
||||
# it leaves a capital "A" in the text editor. We get rid of it by doing backspace right after.
|
||||
window_api.key_press(keysym="a", mask=["CTL"])
|
||||
window_api.key_press(mask=["CTL"], keysym="a")
|
||||
window_api.key_press(keysym="Backsp")
|
||||
|
||||
# Type some text
|
||||
|
||||
@@ -24,50 +24,30 @@ import weakref
|
||||
from hippolyzer.lib.base import llsd
|
||||
|
||||
|
||||
class ConnectionStatus(enum.Enum):
|
||||
READY = enum.auto()
|
||||
CONNECTING = enum.auto()
|
||||
CONNECTED = enum.auto()
|
||||
DISCONNECTED = enum.auto()
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class ListenerDetails:
|
||||
listener: Optional[str] = dataclasses.field(default_factory=lambda: "PythonListener-%s" % uuid.uuid4())
|
||||
queues: Set[asyncio.Queue] = dataclasses.field(default_factory=set)
|
||||
|
||||
|
||||
class LEAPClient:
|
||||
# TODO: better listener creation support
|
||||
def __init__(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
|
||||
self._reader = reader
|
||||
self._writer = writer
|
||||
"""Client for script -> viewer communication over the LEAP protocol"""
|
||||
def __init__(self, protocol: AbstractLEAPProtocol):
|
||||
self._protocol = protocol
|
||||
# Pump used for receiving replies
|
||||
self._reply_pump: Optional[str] = None
|
||||
# Pump used for sending leap meta-commands to the viewer (getAPIs, etc.)
|
||||
self.cmd_pump: Optional[str] = None
|
||||
# Map of req id -> future held by requester to send responses to
|
||||
# TODO: LRU dict with cancel on evict.
|
||||
self._reply_futs: Dict[uuid.UUID, asyncio.Future] = {}
|
||||
self._pump_listeners: Dict[str, ListenerDetails] = collections.defaultdict(ListenerDetails)
|
||||
self._connection_status = ConnectionStatus.READY
|
||||
self._drain_task = None
|
||||
|
||||
@property
|
||||
def connected(self) -> bool:
|
||||
return self._connection_status == ConnectionStatus.CONNECTED
|
||||
|
||||
@property
|
||||
def address(self) -> Optional[Tuple]:
|
||||
return self._writer.get_extra_info('peername', None)
|
||||
|
||||
async def connect(self) -> None:
|
||||
"""Receive the "hello" message from the viewer and start the message pump"""
|
||||
assert self._connection_status == ConnectionStatus.READY
|
||||
self._connection_status = ConnectionStatus.CONNECTING
|
||||
|
||||
try:
|
||||
welcome_message = await self._read_message()
|
||||
welcome_message = await self._protocol.read_message()
|
||||
self._reply_pump = welcome_message['pump']
|
||||
self.cmd_pump = welcome_message['data']['command']
|
||||
|
||||
@@ -81,8 +61,8 @@ class LEAPClient:
|
||||
"""Read and handle inbound messages in a background task"""
|
||||
async def _pump_messages_forever():
|
||||
try:
|
||||
while not self._writer.is_closing() and not self._reader.at_eof():
|
||||
self.handle_message(await self._read_message())
|
||||
while not self._protocol.closed:
|
||||
self.handle_message(await self._protocol.read_message())
|
||||
except asyncio.IncompleteReadError:
|
||||
pass
|
||||
finally:
|
||||
@@ -95,9 +75,9 @@ class LEAPClient:
|
||||
def disconnect(self) -> None:
|
||||
"""Close the connection and clean up any pending request futures"""
|
||||
if self.connected:
|
||||
logging.info('closing LEAP connection from %r' % (self.address,))
|
||||
logging.info('closing LEAP connection')
|
||||
self._connection_status = ConnectionStatus.DISCONNECTED
|
||||
self._writer.close()
|
||||
self._protocol.close()
|
||||
|
||||
# Clean up any pending request futures
|
||||
for fut in list(self._reply_futs.values()):
|
||||
@@ -131,58 +111,31 @@ class LEAPClient:
|
||||
Post the event is done synchronously, only waiting for the reply is done async.
|
||||
"""
|
||||
assert self.connected
|
||||
# If you don't pass in a dict for data, we have nowhere to stuff `reqid`.
|
||||
# That means no reply tracking, meaning no future.
|
||||
fut = None
|
||||
if isinstance(data, dict):
|
||||
# Store some state so we can track replies
|
||||
# If we expect a reply to this event, we need to do some extra bookkeeping.
|
||||
# There are apparently some commands for which we can never expect to get a reply.
|
||||
# Don't add a reqid or reply fut map entry in that case, since it will never be resolved.
|
||||
if expect_reply:
|
||||
# If you don't pass in a dict for data, we have nowhere to stuff `reqid`.
|
||||
# That means no reply tracking, meaning no future.
|
||||
if not isinstance(data, dict):
|
||||
raise ValueError(f"Must send a dict in `data` if you want a reply, you sent {data!r}")
|
||||
# We need to mutate the dict, make a copy so that we don't mess with the caller's version.
|
||||
data = data.copy()
|
||||
# Tell the viewer the pump to send replies to
|
||||
data["reply"] = self._reply_pump
|
||||
|
||||
# There are apparently some commands for which we can never expect to get a reply.
|
||||
# Don't add a reqid or reply fut map entry in that case, since it will never be resolved.
|
||||
if expect_reply:
|
||||
# Tell the viewer the pump to send replies to
|
||||
data["reply"] = self._reply_pump
|
||||
req_id = uuid.uuid4()
|
||||
data["reqid"] = req_id
|
||||
|
||||
req_id = uuid.uuid4()
|
||||
data["reqid"] = req_id
|
||||
fut = asyncio.Future()
|
||||
# The future will be cleaned up when the Future is done.
|
||||
fut.add_done_callback(self._cleanup_request_future)
|
||||
self._reply_futs[req_id] = fut
|
||||
|
||||
fut = asyncio.Future()
|
||||
# The future will be cleaned up when the Future is done.
|
||||
fut.add_done_callback(self._cleanup_request_future)
|
||||
self._reply_futs[req_id] = fut
|
||||
elif expect_reply:
|
||||
raise ValueError(f"Must send a dict in `data` if you want a reply, you sent {data!r}")
|
||||
|
||||
self._write_message(pump, data)
|
||||
self._protocol.write_message(pump, data)
|
||||
return fut
|
||||
|
||||
def _write_message(self, pump: str, data: Any) -> None:
|
||||
assert self.connected
|
||||
ser = llsd.format_notation({"pump": pump, "data": data})
|
||||
payload = bytearray(str(len(ser)).encode("utf8"))
|
||||
payload.extend(b":")
|
||||
payload.extend(ser)
|
||||
self._writer.write(payload)
|
||||
# We're in sync context, we need to schedule draining the socket, which is async.
|
||||
# If a drain is already scheduled then we don't need to reschedule.
|
||||
if not self._drain_task:
|
||||
self._drain_task = asyncio.create_task(self._drain_soon())
|
||||
|
||||
async def _drain_soon(self) -> None:
|
||||
self._drain_task = None
|
||||
await self._writer.drain()
|
||||
|
||||
async def _read_message(self) -> Any:
|
||||
"""Read a single inbound LEAP message"""
|
||||
assert self._connection_status in (ConnectionStatus.CONNECTED, ConnectionStatus.CONNECTING)
|
||||
|
||||
length = int((await self._reader.readuntil(b':')).decode("utf8").strip()[:-1])
|
||||
if length > 0xffFFff:
|
||||
raise ValueError(f"Unreasonable LEAP payload length of {length}")
|
||||
parsed = llsd.parse_notation((await self._reader.readexactly(length)).strip())
|
||||
return parsed
|
||||
|
||||
@contextlib.asynccontextmanager
|
||||
async def listen_scoped(self, source_pump: str) -> AsyncContextManager[Callable[[], Awaitable[Any]]]:
|
||||
"""Subscribe to events published on source_pump, allow awaiting them"""
|
||||
@@ -239,11 +192,7 @@ class LEAPClient:
|
||||
raise KeyError(f"Couldn't find {msg_queue!r} in pump listeners")
|
||||
|
||||
def handle_message(self, message: Any) -> bool:
|
||||
"""
|
||||
Handle an inbound message and try to route it to the right recipient
|
||||
|
||||
TODO: Events, somehow. Maybe a catch-all event as well?
|
||||
"""
|
||||
"""Handle an inbound message and try to route it to the right recipient"""
|
||||
if not isinstance(message, dict):
|
||||
logging.warning(f"Received a non-map message: {message!r}")
|
||||
return False
|
||||
@@ -287,6 +236,85 @@ class LEAPClient:
|
||||
return
|
||||
|
||||
|
||||
class ConnectionStatus(enum.Enum):
|
||||
READY = enum.auto()
|
||||
CONNECTING = enum.auto()
|
||||
CONNECTED = enum.auto()
|
||||
DISCONNECTED = enum.auto()
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class ListenerDetails:
|
||||
# We can only have one listener with a given name active at a time. Give each listener a unique name.
|
||||
listener: Optional[str] = dataclasses.field(default_factory=lambda: "PythonListener-%s" % uuid.uuid4())
|
||||
queues: Set[asyncio.Queue] = dataclasses.field(default_factory=set)
|
||||
|
||||
|
||||
class AbstractLEAPProtocol(abc.ABC):
|
||||
"""Interface for a class representing communication with a LEAP peer"""
|
||||
closed: bool
|
||||
|
||||
@abc.abstractmethod
|
||||
def close(self) -> None:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def write_message(self, pump: str, data: Any) -> None:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
async def read_message(self) -> Dict:
|
||||
pass
|
||||
|
||||
|
||||
class LEAPProtocol(AbstractLEAPProtocol):
|
||||
"""Wrapper for communication with a LEAP peer over an asyncio reader/writer pair"""
|
||||
def __init__(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
|
||||
self._reader = reader
|
||||
self._writer = writer
|
||||
self._parser = llsd.LLSDNotationParser()
|
||||
self._formatter = llsd.LLSDNotationFormatter()
|
||||
self._drain_task = None
|
||||
|
||||
@property
|
||||
def closed(self) -> bool:
|
||||
return self._writer.is_closing() or self._reader.at_eof()
|
||||
|
||||
def close(self):
|
||||
if not self._writer.is_closing():
|
||||
self._writer.write_eof()
|
||||
self._writer.close()
|
||||
|
||||
def write_message(self, pump: str, data: Any) -> None:
|
||||
assert not self._writer.is_closing()
|
||||
ser = self._formatter.format({"pump": pump, "data": data})
|
||||
payload = bytearray(str(len(ser)).encode("utf8"))
|
||||
payload.extend(b":")
|
||||
payload.extend(ser)
|
||||
self._writer.write(payload)
|
||||
# We're in sync context, we need to schedule draining the socket, which is async.
|
||||
# If a drain is already scheduled then we don't need to reschedule.
|
||||
if not self._drain_task:
|
||||
self._drain_task = asyncio.create_task(self._drain_soon())
|
||||
|
||||
async def _drain_soon(self) -> None:
|
||||
self._drain_task = None
|
||||
await self._writer.drain()
|
||||
|
||||
async def read_message(self) -> Dict:
|
||||
assert not self._reader.at_eof()
|
||||
|
||||
# Length is everything up until the first colon we see, stripping the colon off.
|
||||
length = int((await self._reader.readuntil(b':'))[:-1].decode("utf8"))
|
||||
if length > 0xffFFff:
|
||||
raise ValueError(f"Unreasonable LEAP payload length of {length}")
|
||||
# Everything after the colon is LLSD
|
||||
parsed = self._parser.parse(await self._reader.readexactly(length))
|
||||
if not isinstance(parsed, dict):
|
||||
raise ValueError(f"Expected LEAP message to be a dict, got {parsed!r}")
|
||||
return parsed
|
||||
|
||||
|
||||
class LEAPAPIWrapper(abc.ABC):
|
||||
"""Base class for classes wrapping specific LEAP APIs"""
|
||||
PUMP_NAME: Optional[str] = None
|
||||
@@ -504,8 +532,8 @@ class LEAPBridgeServer:
|
||||
self._client_connected_cb = client_connected_cb
|
||||
|
||||
async def handle_connection(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
|
||||
client = LEAPClient(reader, writer)
|
||||
logging.info('Accepting LEAP connection from %r' % (client.address,))
|
||||
client = LEAPClient(LEAPProtocol(reader, writer))
|
||||
logging.info('Accepting LEAP connection from %r' % (writer.get_extra_info("peername", None),))
|
||||
await client.connect()
|
||||
|
||||
self.clients.add(client)
|
||||
|
||||
Reference in New Issue
Block a user