diff --git a/addon_examples/transfer_example.py b/addon_examples/transfer_example.py index 90c81b5..f3cce06 100644 --- a/addon_examples/transfer_example.py +++ b/addon_examples/transfer_example.py @@ -40,7 +40,7 @@ class TransferExampleAddon(BaseAddon): Block('AgentData', AgentID=session.agent_id, SessionID=session.id), Block('InventoryData', LocalID=session.selected.object_local), )) - inv_message = await region.message_handler.wait_for('ReplyTaskInventory', timeout=5.0) + inv_message = await region.message_handler.wait_for(('ReplyTaskInventory',), timeout=5.0) # Xfer the inventory file and look for a script xfer = await region.xfer_manager.request( diff --git a/addon_examples/turbo_object_inventory.py b/addon_examples/turbo_object_inventory.py index 12b959b..90e6824 100644 --- a/addon_examples/turbo_object_inventory.py +++ b/addon_examples/turbo_object_inventory.py @@ -65,7 +65,7 @@ class TurboObjectInventoryAddon(BaseAddon): # by marking it complete on the server-side. Re-send our RequestTaskInventory # To make sure there's a fresh copy. region.circuit.send_message(request_msg.take()) - inv_message = await region.message_handler.wait_for('ReplyTaskInventory', timeout=5.0) + inv_message = await region.message_handler.wait_for(('ReplyTaskInventory',), timeout=5.0) # No task inventory, send the reply as-is file_name = inv_message["InventoryData"]["Filename"] if not file_name: diff --git a/addon_examples/xfer_example.py b/addon_examples/xfer_example.py index 04a5f66..1619a3e 100644 --- a/addon_examples/xfer_example.py +++ b/addon_examples/xfer_example.py @@ -22,7 +22,7 @@ class XferExampleAddon(BaseAddon): )) # Wait for any MuteListUpdate, dropping it before it reaches the viewer - update_msg = await region.message_handler.wait_for('MuteListUpdate', timeout=5.0) + update_msg = await region.message_handler.wait_for(('MuteListUpdate',), timeout=5.0) mute_file_name = update_msg["MuteData"]["Filename"] if not mute_file_name: show_message("Nobody muted?") @@ -42,7 +42,7 @@ class XferExampleAddon(BaseAddon): Block('InventoryData', LocalID=session.selected.object_local), )) - inv_message = await region.message_handler.wait_for('ReplyTaskInventory', timeout=5.0) + inv_message = await region.message_handler.wait_for(('ReplyTaskInventory',), timeout=5.0) # Xfer doesn't need to be immediately awaited, multiple signals can be waited on. xfer = region.xfer_manager.request( diff --git a/hippolyzer/lib/base/message/circuit.py b/hippolyzer/lib/base/message/circuit.py index a001c4b..848a765 100644 --- a/hippolyzer/lib/base/message/circuit.py +++ b/hippolyzer/lib/base/message/circuit.py @@ -77,4 +77,4 @@ class ConnectionHolder(abc.ABC): lifetime of a session (due to region restarts, etc.) """ circuit: Optional[Circuit] - message_handler: MessageHandler[Message] + message_handler: MessageHandler[Message, str] diff --git a/hippolyzer/lib/base/message/message_handler.py b/hippolyzer/lib/base/message/message_handler.py index 56356ef..4d579da 100644 --- a/hippolyzer/lib/base/message/message_handler.py +++ b/hippolyzer/lib/base/message/message_handler.py @@ -28,28 +28,28 @@ from hippolyzer.lib.base.events import Event LOG = logging.getLogger(__name__) _T = TypeVar("_T") +_K = TypeVar("_K", bound=Hashable) MESSAGE_HANDLER = Callable[[_T], Any] PREDICATE = Callable[[_T], bool] -MESSAGE_NAMES = Union[str, Iterable[str]] +MESSAGE_NAMES = Iterable[_K] -class MessageHandler(Generic[_T]): - def __init__(self): - self.handlers: Dict[str, Event] = {} +class MessageHandler(Generic[_T, _K]): + def __init__(self, take_by_default: bool = True): + self.handlers: Dict[_K, Event] = {} + self.take_by_default = take_by_default - def register(self, message_name: str) -> Event: + def register(self, message_name: _K) -> Event: LOG.debug('Creating a monitor for %s' % message_name) return self.handlers.setdefault(message_name, Event()) - def subscribe(self, message_name: str, handler: MESSAGE_HANDLER) -> Event: + def subscribe(self, message_name: _K, handler: MESSAGE_HANDLER) -> Event: notifier = self.register(message_name) notifier.subscribe(handler) return notifier def _subscribe_all(self, message_names: MESSAGE_NAMES, handler: MESSAGE_HANDLER, predicate: Optional[PREDICATE] = None) -> List[Event]: - if isinstance(message_names, str): - message_names = (message_names,) notifiers = [self.register(name) for name in message_names] for n in notifiers: n.subscribe(handler, predicate=predicate) @@ -57,7 +57,7 @@ class MessageHandler(Generic[_T]): @contextlib.contextmanager def subscribe_async(self, message_names: MESSAGE_NAMES, predicate: Optional[PREDICATE] = None, - take: bool = True) -> ContextManager[Callable[[], Awaitable[_T]]]: + take: Optional[bool] = None) -> ContextManager[Callable[[], Awaitable[_T]]]: """ Subscribe to a set of message matching predicate while within a block @@ -69,6 +69,8 @@ class MessageHandler(Generic[_T]): If a subscriber is just an observer that will never drop or modify a message, take=False may be used and messages will be sent as usual. """ + if take is None: + take = self.take_by_default msg_queue = asyncio.Queue() def _handler_wrapper(message: _T): @@ -91,8 +93,8 @@ class MessageHandler(Generic[_T]): for n in notifiers: n.unsubscribe(_handler_wrapper) - def wait_for(self, message_names: MESSAGE_NAMES, - predicate: Optional[PREDICATE] = None, timeout=None, take=True) -> Awaitable[_T]: + def wait_for(self, message_names: MESSAGE_NAMES, predicate: Optional[PREDICATE] = None, + timeout: Optional[float] = None, take: Optional[bool] = None) -> Awaitable[_T]: """ Wait for a single instance one of message_names matching predicate @@ -101,8 +103,8 @@ class MessageHandler(Generic[_T]): sequence of packets, since multiple packets may come in after the future has already been marked completed, causing some to be missed. """ - if isinstance(message_names, str): - message_names = (message_names,) + if take is None: + take = self.take_by_default notifiers = [self.register(name) for name in message_names] fut = asyncio.get_event_loop().create_future() @@ -132,7 +134,7 @@ class MessageHandler(Generic[_T]): notifier.subscribe(_handler, predicate=predicate) return fut - def is_handled(self, message_name: str): + def is_handled(self, message_name: _K): return message_name in self.handlers def handle(self, message: _T): @@ -140,7 +142,7 @@ class MessageHandler(Generic[_T]): # Always try to call wildcard handlers self._handle_type('*', message) - def _handle_type(self, name: str, message: _T): + def _handle_type(self, name: _K, message: _T): handler = self.handlers.get(name) if not handler: return diff --git a/hippolyzer/lib/base/xfer_manager.py b/hippolyzer/lib/base/xfer_manager.py index 7e05627..c9a88ab 100644 --- a/hippolyzer/lib/base/xfer_manager.py +++ b/hippolyzer/lib/base/xfer_manager.py @@ -246,7 +246,7 @@ class XferManager: def complete_predicate(complete_msg: Message): return complete_msg["AssetBlock"]["UUID"] == asset_id - msg = await message_handler.wait_for('AssetUploadComplete', predicate=complete_predicate) + msg = await message_handler.wait_for(('AssetUploadComplete',), predicate=complete_predicate) if msg["AssetBlock"]["Success"] == 1: fut.set_result(asset_id) else: @@ -263,7 +263,7 @@ class XferManager: ): message_handler = self._connection_holder.message_handler request_msg = await message_handler.wait_for( - 'RequestXfer', predicate=request_predicate, timeout=5.0) + ('RequestXfer',), predicate=request_predicate, timeout=5.0) xfer.xfer_id = request_msg["XferID"]["ID"] packet_id = 0 @@ -282,5 +282,5 @@ class XferManager: # Don't care about the value, just want to know it was confirmed. if wait_for_confirm: await message_handler.wait_for( - "ConfirmXferPacket", predicate=xfer.is_our_message, timeout=5.0) + ("ConfirmXferPacket",), predicate=xfer.is_our_message, timeout=5.0) packet_id += 1 diff --git a/hippolyzer/lib/client/namecache.py b/hippolyzer/lib/client/namecache.py index ca2b31a..83ed6b9 100644 --- a/hippolyzer/lib/client/namecache.py +++ b/hippolyzer/lib/client/namecache.py @@ -39,7 +39,7 @@ class NameCache: def create_subscriptions( self, - message_handler: MessageHandler[Message], + message_handler: MessageHandler[Message, str], ): message_handler.subscribe("UUIDNameReply", self._handle_uuid_name_reply) diff --git a/hippolyzer/lib/client/object_manager.py b/hippolyzer/lib/client/object_manager.py index 9e8352e..f611a44 100644 --- a/hippolyzer/lib/client/object_manager.py +++ b/hippolyzer/lib/client/object_manager.py @@ -16,6 +16,7 @@ from typing import * from hippolyzer.lib.base.datatypes import UUID, Vector3 from hippolyzer.lib.base.helpers import proxify from hippolyzer.lib.base.message.message import Block, Message +from hippolyzer.lib.base.message.message_handler import MessageHandler from hippolyzer.lib.base.objects import ( normalize_object_update, normalize_terse_object_update, @@ -38,6 +39,7 @@ class UpdateType(enum.IntEnum): PROPERTIES = enum.auto() FAMILY = enum.auto() COSTS = enum.auto() + KILL = enum.auto() class ClientObjectManager: @@ -161,12 +163,30 @@ class ClientObjectManager: return futures +class ObjectEvent: + __slots__ = ("object", "updated", "update_type") + + object: Object + updated: Set[str] + update_type: UpdateType + + def __init__(self, obj: Object, updated: Set[str], update_type: UpdateType): + self.object = obj + self.updated = updated + self.update_type = update_type + + @property + def name(self) -> UpdateType: + return self.update_type + + class ClientWorldObjectManager: """Manages Objects for a session's whole world""" def __init__(self, session: BaseClientSession, settings: Settings, name_cache: Optional[NameCache]): self._session: BaseClientSession = session self._settings = settings self.name_cache = name_cache or NameCache() + self.events: MessageHandler[ObjectEvent, UpdateType] = MessageHandler(take_by_default=False) self._fullid_lookup: Dict[UUID, Object] = {} self._avatars: Dict[UUID, Avatar] = {} self._avatar_objects: Dict[UUID, Object] = {} @@ -549,9 +569,10 @@ class ClientWorldObjectManager: if obj.PCode == PCode.AVATAR and "NameValue" in updated_props: if obj.NameValue: self.name_cache.update(obj.FullID, obj.NameValue.to_dict()) + self.events.handle(ObjectEvent(obj, updated_props, update_type)) def _run_kill_object_hooks(self, obj: Object): - pass + self.events.handle(ObjectEvent(obj, set(), UpdateType.KILL)) def _rebuild_avatar_objects(self): # Get all avatars known through coarse locations and which region the location was in diff --git a/hippolyzer/lib/client/state.py b/hippolyzer/lib/client/state.py index f32793c..c2236fd 100644 --- a/hippolyzer/lib/client/state.py +++ b/hippolyzer/lib/client/state.py @@ -29,7 +29,7 @@ class BaseClientSession(abc.ABC): id: UUID agent_id: UUID secure_session_id: UUID - message_handler: MessageHandler[Message] + message_handler: MessageHandler[Message, str] regions: Sequence[BaseClientRegion] region_by_handle: Callable[[int], Optional[BaseClientRegion]] region_by_circuit_addr: Callable[[ADDR_TUPLE], Optional[BaseClientRegion]] diff --git a/hippolyzer/lib/proxy/namecache.py b/hippolyzer/lib/proxy/namecache.py index b284168..4db3a8b 100644 --- a/hippolyzer/lib/proxy/namecache.py +++ b/hippolyzer/lib/proxy/namecache.py @@ -17,8 +17,8 @@ if TYPE_CHECKING: class ProxyNameCache(NameCache): def create_subscriptions( self, - message_handler: MessageHandler[Message], - http_message_handler: Optional[MessageHandler[HippoHTTPFlow]] = None, + message_handler: MessageHandler[Message, str], + http_message_handler: Optional[MessageHandler[HippoHTTPFlow, str]] = None, ): super().create_subscriptions(message_handler) if http_message_handler is not None: diff --git a/hippolyzer/lib/proxy/region.py b/hippolyzer/lib/proxy/region.py index b373989..024ba33 100644 --- a/hippolyzer/lib/proxy/region.py +++ b/hippolyzer/lib/proxy/region.py @@ -63,8 +63,8 @@ class ProxiedRegion(BaseClientRegion): if seed_cap: self._caps["Seed"] = (CapType.NORMAL, seed_cap) self.session: Callable[[], Session] = weakref.ref(session) - self.message_handler: MessageHandler[Message] = MessageHandler() - self.http_message_handler: MessageHandler[HippoHTTPFlow] = MessageHandler() + self.message_handler: MessageHandler[Message, str] = MessageHandler() + self.http_message_handler: MessageHandler[HippoHTTPFlow, str] = MessageHandler() self.eq_manager = EventQueueManager(self) settings = session.session_manager.settings self.caps_client = ProxyCapsClient(settings, proxify(self)) diff --git a/hippolyzer/lib/proxy/sessions.py b/hippolyzer/lib/proxy/sessions.py index f3b3a29..ef33200 100644 --- a/hippolyzer/lib/proxy/sessions.py +++ b/hippolyzer/lib/proxy/sessions.py @@ -42,8 +42,8 @@ class Session(BaseClientSession): self.selected: SelectionModel = SelectionModel() self.regions: List[ProxiedRegion] = [] self.started_at = datetime.datetime.now() - self.message_handler: MessageHandler[Message] = MessageHandler() - self.http_message_handler: MessageHandler[HippoHTTPFlow] = MessageHandler() + self.message_handler: MessageHandler[Message, str] = MessageHandler() + self.http_message_handler: MessageHandler[HippoHTTPFlow, str] = MessageHandler() self.objects = ProxyWorldObjectManager(self, session_manager.settings, session_manager.name_cache) self._main_region = None diff --git a/tests/base/test_message_wrapper.py b/tests/base/test_message_wrapper.py index 0085719..1f98983 100644 --- a/tests/base/test_message_wrapper.py +++ b/tests/base/test_message_wrapper.py @@ -169,7 +169,7 @@ class TestMessage(unittest.TestCase): class TestMessageHandlers(unittest.IsolatedAsyncioTestCase): def setUp(self) -> None: - self.message_handler: MessageHandler[Message] = MessageHandler() + self.message_handler: MessageHandler[Message, str] = MessageHandler() def _fake_received_message(self, msg: Message): self.message_handler.handle(msg) @@ -203,7 +203,7 @@ class TestMessageHandlers(unittest.IsolatedAsyncioTestCase): self.assertEqual(len(foo_handlers), 0) async def test_subscription_no_take(self): - with self.message_handler.subscribe_async("Foo", take=False) as get_msg: + with self.message_handler.subscribe_async(("Foo",), take=False) as get_msg: msg = Message("Foo", Block("Bar", Baz=1, Biz=1)) self._fake_received_message(msg) # Should not copy @@ -212,7 +212,7 @@ class TestMessageHandlers(unittest.IsolatedAsyncioTestCase): self.assertFalse(msg.queued) async def test_wait_for(self): - fut = self.message_handler.wait_for("Foo", timeout=0.001, take=False) + fut = self.message_handler.wait_for(("Foo",), timeout=0.001, take=False) foo_handlers = self.message_handler.handlers['Foo'] # We are subscribed self.assertEqual(len(foo_handlers), 1) @@ -226,7 +226,7 @@ class TestMessageHandlers(unittest.IsolatedAsyncioTestCase): self.assertEqual(len(foo_handlers), 0) async def test_wait_for_take(self): - fut = self.message_handler.wait_for("Foo", timeout=0.001) + fut = self.message_handler.wait_for(("Foo",), timeout=0.001) foo_handlers = self.message_handler.handlers['Foo'] # We are subscribed self.assertEqual(len(foo_handlers), 1) diff --git a/tests/base/test_xfer_transfer.py b/tests/base/test_xfer_transfer.py index f8460db..73804b1 100644 --- a/tests/base/test_xfer_transfer.py +++ b/tests/base/test_xfer_transfer.py @@ -23,7 +23,7 @@ from hippolyzer.lib.base.xfer_manager import XferManager class MockHandlingCircuit(ProxiedCircuit): - def __init__(self, handler: MessageHandler[Message]): + def __init__(self, handler: MessageHandler[Message, str]): super().__init__(("127.0.0.1", 1), ("127.0.0.1", 2), None) self.handler = handler @@ -42,8 +42,8 @@ class BaseTransferTests(unittest.IsolatedAsyncioTestCase): LARGE_PAYLOAD = b"foobar" * 500 def setUp(self) -> None: - self.server_message_handler: MessageHandler[Message] = MessageHandler() - self.client_message_handler: MessageHandler[Message] = MessageHandler() + self.server_message_handler: MessageHandler[Message, str] = MessageHandler() + self.client_message_handler: MessageHandler[Message, str] = MessageHandler() # The client side should send messages to the server side's message handler # and vice-versa self.client_circuit = MockHandlingCircuit(self.server_message_handler) @@ -60,7 +60,7 @@ class XferManagerTests(BaseTransferTests): self.received_bytes: Optional[bytes] = None async def _handle_vfile_upload(self): - msg = await self.server_message_handler.wait_for('AssetUploadRequest', timeout=0.01) + msg = await self.server_message_handler.wait_for(('AssetUploadRequest',), timeout=0.01) asset_block = msg["AssetBlock"] transaction_id = asset_block["TransactionID"] asset_id = UUID.combine(transaction_id, self.secure_session_id) @@ -102,7 +102,7 @@ class TestTransferManager(BaseTransferTests): ) async def _handle_covenant_download(self): - msg = await self.server_message_handler.wait_for('TransferRequest', timeout=0.01) + msg = await self.server_message_handler.wait_for(('TransferRequest',), timeout=0.01) self.assertEqual(TransferSourceType.SIM_ESTATE, msg["TransferInfo"]["SourceType"]) tid = msg["TransferInfo"]["TransferID"] params: TransferRequestParamsSimEstate = msg["TransferInfo"][0].deserialize_var("Params") diff --git a/tests/proxy/integration/test_lludp.py b/tests/proxy/integration/test_lludp.py index 4018012..55d3e0b 100644 --- a/tests/proxy/integration/test_lludp.py +++ b/tests/proxy/integration/test_lludp.py @@ -247,7 +247,7 @@ class LLUDPIntegrationTests(BaseProxyTest): async def test_session_message_handler(self): self._setup_default_circuit() obj_update = self._make_objectupdate_compressed(1234) - fut = self.session.message_handler.wait_for('ObjectUpdateCompressed') + fut = self.session.message_handler.wait_for(('ObjectUpdateCompressed',)) self.protocol.datagram_received(obj_update, self.region_addr) self.assertEqual("ObjectUpdateCompressed", (await fut).name)