From 4cc1513e58dbf4c52e338c522b53d77a1dc96cb5 Mon Sep 17 00:00:00 2001 From: Salad Dais Date: Fri, 15 Dec 2023 19:07:17 +0000 Subject: [PATCH] Correct type signatures in MessageHandler --- hippolyzer/lib/base/message/message_handler.py | 7 ++++--- hippolyzer/lib/client/hippo_client.py | 2 +- tests/client/test_hippo_client.py | 2 +- 3 files changed, 6 insertions(+), 5 deletions(-) diff --git a/hippolyzer/lib/base/message/message_handler.py b/hippolyzer/lib/base/message/message_handler.py index 5ed6413..89b8d59 100644 --- a/hippolyzer/lib/base/message/message_handler.py +++ b/hippolyzer/lib/base/message/message_handler.py @@ -31,7 +31,8 @@ _T = TypeVar("_T") _K = TypeVar("_K", bound=Hashable) MESSAGE_HANDLER = Callable[[_T], Any] PREDICATE = Callable[[_T], bool] -MESSAGE_NAMES = Iterable[_K] +# TODO: Can't do `Iterable[Union[_K, Literal["*"]]` apparently? +MESSAGE_NAMES = Iterable[Union[_K, LiteralString]] class MessageHandler(Generic[_T, _K]): @@ -43,7 +44,7 @@ class MessageHandler(Generic[_T, _K]): LOG.debug('Creating a monitor for %s' % message_name) return self.handlers.setdefault(message_name, Event()) - def subscribe(self, message_name: _K, handler: MESSAGE_HANDLER) -> Event: + def subscribe(self, message_name: Union[_K, Literal["*"]], handler: MESSAGE_HANDLER) -> Event: notifier = self.register(message_name) notifier.subscribe(handler) return notifier @@ -145,7 +146,7 @@ class MessageHandler(Generic[_T, _K]): # Always try to call wildcard handlers self._handle_type('*', message) - def _handle_type(self, name: _K, message: _T): + def _handle_type(self, name: Union[_K, Literal["*"]], message: _T): handler = self.handlers.get(name) if not handler: return diff --git a/hippolyzer/lib/client/hippo_client.py b/hippolyzer/lib/client/hippo_client.py index ce00937..e329d2f 100644 --- a/hippolyzer/lib/client/hippo_client.py +++ b/hippolyzer/lib/client/hippo_client.py @@ -103,7 +103,7 @@ class HippoClientRegion(BaseClientRegion): def __init__(self, circuit_addr, seed_cap: str, session: HippoClientSession, handle=None): super().__init__() self.caps = multidict.MultiDict() - self.message_handler = MessageHandler(take_by_default=False) + self.message_handler: MessageHandler[Message, str] = MessageHandler(take_by_default=False) self.circuit_addr = circuit_addr self.handle = handle if seed_cap: diff --git a/tests/client/test_hippo_client.py b/tests/client/test_hippo_client.py index 2592025..2173cd0 100644 --- a/tests/client/test_hippo_client.py +++ b/tests/client/test_hippo_client.py @@ -105,7 +105,7 @@ class TestHippoClient(unittest.IsolatedAsyncioTestCase): } async def asyncSetUp(self): - self.server_handler = MessageHandler() + self.server_handler: MessageHandler[Message, str] = MessageHandler() self.server_transport = PacketForwardingTransport() self.server_circuit = Circuit(("127.0.0.1", 2), ("127.0.0.1", 99), self.server_transport) self.server = MockServer(self.server_circuit, self.server_handler)