Files
Hippolyzer/hippolyzer/lib/base/message/message_handler.py
2024-01-14 07:00:20 +00:00

156 lines
5.9 KiB
Python

"""
Copyright 2009, Linden Research, Inc.
See NOTICE.md for previous contributors
Copyright 2021, Salad Dais
All Rights Reserved.
This program is free software; you can redistribute it and/or
modify it under the terms of the GNU Lesser General Public
License as published by the Free Software Foundation; either
version 3 of the License, or (at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
Lesser General Public License for more details.
You should have received a copy of the GNU Lesser General Public License
along with this program; if not, write to the Free Software Foundation,
Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
"""
import asyncio
import contextlib
import logging
from typing import *
from hippolyzer.lib.base.events import Event
# Type of the message objects passed through a MessageHandler.
_T = TypeVar("_T")
# Type of the key messages are dispatched under (`handle()` uses `message.name`).
# Bound to Hashable because keys are stored in a dict.
_K = TypeVar("_K", bound=Hashable)

# Callback that receives a single message.
MESSAGE_HANDLER = Callable[[_T], Any]
# Filter forwarded to `Event.subscribe(predicate=...)`.
PREDICATE = Callable[[_T], bool]
# Keys to subscribe to. `str` stands in for the "*" wildcard key.
# TODO: Can't do `Iterable[Union[_K, Literal["*"]]]` apparently?
MESSAGE_NAMES = Iterable[Union[_K, str]]

LOG = logging.getLogger(__name__)
class MessageHandler(Generic[_T, _K]):
    """
    Route messages to per-key `Event`s.

    `handle()` notifies the Event registered under `message.name`, then the Event
    registered under the wildcard key "*", so "*" subscribers see every message.
    """

    def __init__(self, take_by_default: bool = True):
        # One Event per key, created lazily by `register()`
        self.handlers: Dict[_K, Event] = {}
        # Default for the `take` argument of `subscribe_async()` and `wait_for()`
        self.take_by_default = take_by_default

    def register(self, message_name: _K) -> Event:
        """
        Return the Event for `message_name`, creating it on first use.
        """
        # Lazy %-args: eager `%` formatting breaks when the key is a tuple
        LOG.debug('Creating a monitor for %s', message_name)
        event = self.handlers.get(message_name)
        if event is None:
            # Only construct an Event when one doesn't already exist
            event = self.handlers[message_name] = Event(message_name)
        return event

    def subscribe(self, message_name: Union[_K, Literal["*"]], handler: MESSAGE_HANDLER):
        """
        Subscribe `handler` to the Event for `message_name` (may be "*").
        """
        notifier = self.register(message_name)
        notifier.subscribe(handler)

    def _subscribe_all(self, message_names: MESSAGE_NAMES, handler: MESSAGE_HANDLER,
                       predicate: Optional[PREDICATE] = None) -> List[Event]:
        """
        Subscribe `handler` (with `predicate`) to every key in `message_names`.

        Returns the Events subscribed to, so the caller can unsubscribe later.
        """
        notifiers = [self.register(name) for name in message_names]
        for n in notifiers:
            n.subscribe(handler, predicate=predicate)
        return notifiers

    @contextlib.contextmanager
    def subscribe_async(self, message_names: MESSAGE_NAMES, predicate: Optional[PREDICATE] = None,
                        take: Optional[bool] = None) -> Iterator[Callable[[], Awaitable[_T]]]:
        """
        Subscribe to messages matching `predicate` while inside the `with` block.

        Yields a coroutine function; each call awaits and returns the next queued
        matching message. Messages are queued in arrival order.

        By default matching messages are taken out of the usual flow: they will not
        automatically be forwarded to the client, which gives the subscriber time to
        modify or drop them. Subscribers must send taken messages on manually if desired.

        If a subscriber only observes and will never drop or modify a message, use
        `take=False` and messages will be sent as usual. When `take` is None,
        `self.take_by_default` is used.

        All subscriptions are removed when the block exits, including on error.
        """
        if take is None:
            take = self.take_by_default
        msg_queue: asyncio.Queue = asyncio.Queue()

        def _handler_wrapper(message: _T):
            # The async consumer owns this message; queue the taken copy if taking
            if take:
                message = message.take()
            msg_queue.put_nowait(message)

        notifiers = self._subscribe_all(message_names, _handler_wrapper, predicate=predicate)

        async def _get_wrapper() -> _T:
            msg = await msg_queue.get()
            # Consumption is completion
            msg_queue.task_done()
            return msg

        try:
            yield _get_wrapper
        finally:
            for n in notifiers:
                n.unsubscribe(_handler_wrapper)

    def wait_for(self, message_names: MESSAGE_NAMES, predicate: Optional[PREDICATE] = None,
                 timeout: Optional[float] = None, take: Optional[bool] = None) -> Awaitable[_T]:
        """
        Return a future resolving to the first message in `message_names` matching `predicate`.

        The matching message is considered owned by the caller. Unless `take=False`, the
        future resolves to `message.take()`. When `take` is None, `self.take_by_default`
        is used.

        If `timeout` is truthy and nothing matches within `timeout` seconds, the future
        fails with `asyncio.exceptions.TimeoutError`. A falsy timeout (None or 0) waits
        indefinitely.

        If the caller cancels the future, the subscriptions are removed and later
        messages are left untouched.

        Don't use this to wait for a sequence of messages. More may arrive after the
        future has completed, and those would be missed.
        """
        if take is None:
            take = self.take_by_default
        notifiers = [self.register(name) for name in message_names]

        try:
            # Bind the future to the loop that will actually await it
            loop = asyncio.get_running_loop()
        except RuntimeError:
            # Called from synchronous code; keep the previous loop lookup
            loop = asyncio.get_event_loop_policy().get_event_loop()
        fut = loop.create_future()
        timeout_handle: Optional[asyncio.TimerHandle] = None
        subscribed = False

        def _cleanup() -> None:
            # Idempotent: the handler, the timeout and the done callback may each
            # reach this for the same wait. Unsubscribing twice must be avoided.
            nonlocal subscribed
            if timeout_handle is not None:
                timeout_handle.cancel()
            if not subscribed:
                return
            subscribed = False
            for n in notifiers:
                n.unsubscribe(_handler)

        def _on_timeout() -> None:
            if not fut.done():
                fut.set_exception(asyncio.exceptions.TimeoutError("Timed out waiting for packet"))
            _cleanup()

        def _handler(message: _T) -> None:
            if fut.done():
                # Cancelled by the awaiter before the done callback ran:
                # don't take a message nobody will receive.
                _cleanup()
                return
            # Whatever was awaiting this future now owns this message
            if take:
                message = message.take()
            fut.set_result(message)
            # Unregister synchronously so later messages in this dispatch aren't seen
            _cleanup()

        for notifier in notifiers:
            notifier.subscribe(_handler, predicate=predicate)
        subscribed = True

        if timeout:
            timeout_handle = loop.call_later(timeout, _on_timeout)
        # Covers cancellation of the future by the caller
        fut.add_done_callback(lambda _fut: _cleanup())
        return fut

    def is_handled(self, message_name: _K):
        """
        Return True if an Event has been registered for `message_name`.

        The Event may currently have no subscribers.
        """
        return message_name in self.handlers

    def handle(self, message: _T):
        """
        Dispatch `message` to subscribers of `message.name`, then to "*" subscribers.
        """
        self._handle_type(message.name, message)
        # Always try to call wildcard handlers
        self._handle_type('*', message)

    def _handle_type(self, name: Union[_K, Literal["*"]], message: _T):
        """
        Notify the Event registered under `name`, if it exists and has subscribers.
        """
        handler = self.handlers.get(name)
        if not handler or len(handler) == 0:
            return
        # Lazy %-args: eager `%` formatting breaks when the key is a tuple
        LOG.debug('Handling message : %s', name)
        handler(message)