Add ability to wait on object-related events

This commit is contained in:
Salad Dais
2021-06-12 10:43:16 +00:00
parent 8347b341f5
commit 41fd67577a
15 changed files with 65 additions and 42 deletions

View File

@@ -40,7 +40,7 @@ class TransferExampleAddon(BaseAddon):
Block('AgentData', AgentID=session.agent_id, SessionID=session.id),
Block('InventoryData', LocalID=session.selected.object_local),
))
inv_message = await region.message_handler.wait_for('ReplyTaskInventory', timeout=5.0)
inv_message = await region.message_handler.wait_for(('ReplyTaskInventory',), timeout=5.0)
# Xfer the inventory file and look for a script
xfer = await region.xfer_manager.request(

View File

@@ -65,7 +65,7 @@ class TurboObjectInventoryAddon(BaseAddon):
# by marking it complete on the server-side. Re-send our RequestTaskInventory
# To make sure there's a fresh copy.
region.circuit.send_message(request_msg.take())
inv_message = await region.message_handler.wait_for('ReplyTaskInventory', timeout=5.0)
inv_message = await region.message_handler.wait_for(('ReplyTaskInventory',), timeout=5.0)
# No task inventory, send the reply as-is
file_name = inv_message["InventoryData"]["Filename"]
if not file_name:

View File

@@ -22,7 +22,7 @@ class XferExampleAddon(BaseAddon):
))
# Wait for any MuteListUpdate, dropping it before it reaches the viewer
update_msg = await region.message_handler.wait_for('MuteListUpdate', timeout=5.0)
update_msg = await region.message_handler.wait_for(('MuteListUpdate',), timeout=5.0)
mute_file_name = update_msg["MuteData"]["Filename"]
if not mute_file_name:
show_message("Nobody muted?")
@@ -42,7 +42,7 @@ class XferExampleAddon(BaseAddon):
Block('InventoryData', LocalID=session.selected.object_local),
))
inv_message = await region.message_handler.wait_for('ReplyTaskInventory', timeout=5.0)
inv_message = await region.message_handler.wait_for(('ReplyTaskInventory',), timeout=5.0)
# Xfer doesn't need to be immediately awaited, multiple signals can be waited on.
xfer = region.xfer_manager.request(

View File

@@ -77,4 +77,4 @@ class ConnectionHolder(abc.ABC):
lifetime of a session (due to region restarts, etc.)
"""
circuit: Optional[Circuit]
message_handler: MessageHandler[Message]
message_handler: MessageHandler[Message, str]

View File

@@ -28,28 +28,28 @@ from hippolyzer.lib.base.events import Event
LOG = logging.getLogger(__name__)
_T = TypeVar("_T")
_K = TypeVar("_K", bound=Hashable)
MESSAGE_HANDLER = Callable[[_T], Any]
PREDICATE = Callable[[_T], bool]
MESSAGE_NAMES = Union[str, Iterable[str]]
MESSAGE_NAMES = Iterable[_K]
class MessageHandler(Generic[_T]):
def __init__(self):
self.handlers: Dict[str, Event] = {}
class MessageHandler(Generic[_T, _K]):
def __init__(self, take_by_default: bool = True):
self.handlers: Dict[_K, Event] = {}
self.take_by_default = take_by_default
def register(self, message_name: str) -> Event:
def register(self, message_name: _K) -> Event:
LOG.debug('Creating a monitor for %s' % message_name)
return self.handlers.setdefault(message_name, Event())
def subscribe(self, message_name: str, handler: MESSAGE_HANDLER) -> Event:
def subscribe(self, message_name: _K, handler: MESSAGE_HANDLER) -> Event:
notifier = self.register(message_name)
notifier.subscribe(handler)
return notifier
def _subscribe_all(self, message_names: MESSAGE_NAMES, handler: MESSAGE_HANDLER,
predicate: Optional[PREDICATE] = None) -> List[Event]:
if isinstance(message_names, str):
message_names = (message_names,)
notifiers = [self.register(name) for name in message_names]
for n in notifiers:
n.subscribe(handler, predicate=predicate)
@@ -57,7 +57,7 @@ class MessageHandler(Generic[_T]):
@contextlib.contextmanager
def subscribe_async(self, message_names: MESSAGE_NAMES, predicate: Optional[PREDICATE] = None,
take: bool = True) -> ContextManager[Callable[[], Awaitable[_T]]]:
take: Optional[bool] = None) -> ContextManager[Callable[[], Awaitable[_T]]]:
"""
Subscribe to a set of message matching predicate while within a block
@@ -69,6 +69,8 @@ class MessageHandler(Generic[_T]):
If a subscriber is just an observer that will never drop or modify a message, take=False
may be used and messages will be sent as usual.
"""
if take is None:
take = self.take_by_default
msg_queue = asyncio.Queue()
def _handler_wrapper(message: _T):
@@ -91,8 +93,8 @@ class MessageHandler(Generic[_T]):
for n in notifiers:
n.unsubscribe(_handler_wrapper)
def wait_for(self, message_names: MESSAGE_NAMES,
predicate: Optional[PREDICATE] = None, timeout=None, take=True) -> Awaitable[_T]:
def wait_for(self, message_names: MESSAGE_NAMES, predicate: Optional[PREDICATE] = None,
timeout: Optional[float] = None, take: Optional[bool] = None) -> Awaitable[_T]:
"""
Wait for a single instance one of message_names matching predicate
@@ -101,8 +103,8 @@ class MessageHandler(Generic[_T]):
sequence of packets, since multiple packets may come in after the future has already
been marked completed, causing some to be missed.
"""
if isinstance(message_names, str):
message_names = (message_names,)
if take is None:
take = self.take_by_default
notifiers = [self.register(name) for name in message_names]
fut = asyncio.get_event_loop().create_future()
@@ -132,7 +134,7 @@ class MessageHandler(Generic[_T]):
notifier.subscribe(_handler, predicate=predicate)
return fut
def is_handled(self, message_name: str):
def is_handled(self, message_name: _K):
return message_name in self.handlers
def handle(self, message: _T):
@@ -140,7 +142,7 @@ class MessageHandler(Generic[_T]):
# Always try to call wildcard handlers
self._handle_type('*', message)
def _handle_type(self, name: str, message: _T):
def _handle_type(self, name: _K, message: _T):
handler = self.handlers.get(name)
if not handler:
return

View File

@@ -246,7 +246,7 @@ class XferManager:
def complete_predicate(complete_msg: Message):
return complete_msg["AssetBlock"]["UUID"] == asset_id
msg = await message_handler.wait_for('AssetUploadComplete', predicate=complete_predicate)
msg = await message_handler.wait_for(('AssetUploadComplete',), predicate=complete_predicate)
if msg["AssetBlock"]["Success"] == 1:
fut.set_result(asset_id)
else:
@@ -263,7 +263,7 @@ class XferManager:
):
message_handler = self._connection_holder.message_handler
request_msg = await message_handler.wait_for(
'RequestXfer', predicate=request_predicate, timeout=5.0)
('RequestXfer',), predicate=request_predicate, timeout=5.0)
xfer.xfer_id = request_msg["XferID"]["ID"]
packet_id = 0
@@ -282,5 +282,5 @@ class XferManager:
# Don't care about the value, just want to know it was confirmed.
if wait_for_confirm:
await message_handler.wait_for(
"ConfirmXferPacket", predicate=xfer.is_our_message, timeout=5.0)
("ConfirmXferPacket",), predicate=xfer.is_our_message, timeout=5.0)
packet_id += 1

View File

@@ -39,7 +39,7 @@ class NameCache:
def create_subscriptions(
self,
message_handler: MessageHandler[Message],
message_handler: MessageHandler[Message, str],
):
message_handler.subscribe("UUIDNameReply", self._handle_uuid_name_reply)

View File

@@ -16,6 +16,7 @@ from typing import *
from hippolyzer.lib.base.datatypes import UUID, Vector3
from hippolyzer.lib.base.helpers import proxify
from hippolyzer.lib.base.message.message import Block, Message
from hippolyzer.lib.base.message.message_handler import MessageHandler
from hippolyzer.lib.base.objects import (
normalize_object_update,
normalize_terse_object_update,
@@ -38,6 +39,7 @@ class UpdateType(enum.IntEnum):
PROPERTIES = enum.auto()
FAMILY = enum.auto()
COSTS = enum.auto()
KILL = enum.auto()
class ClientObjectManager:
@@ -161,12 +163,30 @@ class ClientObjectManager:
return futures
class ObjectEvent:
__slots__ = ("object", "updated", "update_type")
object: Object
updated: Set[str]
update_type: UpdateType
def __init__(self, obj: Object, updated: Set[str], update_type: UpdateType):
self.object = obj
self.updated = updated
self.update_type = update_type
@property
def name(self) -> UpdateType:
return self.update_type
class ClientWorldObjectManager:
"""Manages Objects for a session's whole world"""
def __init__(self, session: BaseClientSession, settings: Settings, name_cache: Optional[NameCache]):
self._session: BaseClientSession = session
self._settings = settings
self.name_cache = name_cache or NameCache()
self.events: MessageHandler[ObjectEvent, UpdateType] = MessageHandler(take_by_default=False)
self._fullid_lookup: Dict[UUID, Object] = {}
self._avatars: Dict[UUID, Avatar] = {}
self._avatar_objects: Dict[UUID, Object] = {}
@@ -549,9 +569,10 @@ class ClientWorldObjectManager:
if obj.PCode == PCode.AVATAR and "NameValue" in updated_props:
if obj.NameValue:
self.name_cache.update(obj.FullID, obj.NameValue.to_dict())
self.events.handle(ObjectEvent(obj, updated_props, update_type))
def _run_kill_object_hooks(self, obj: Object):
pass
self.events.handle(ObjectEvent(obj, set(), UpdateType.KILL))
def _rebuild_avatar_objects(self):
# Get all avatars known through coarse locations and which region the location was in

View File

@@ -29,7 +29,7 @@ class BaseClientSession(abc.ABC):
id: UUID
agent_id: UUID
secure_session_id: UUID
message_handler: MessageHandler[Message]
message_handler: MessageHandler[Message, str]
regions: Sequence[BaseClientRegion]
region_by_handle: Callable[[int], Optional[BaseClientRegion]]
region_by_circuit_addr: Callable[[ADDR_TUPLE], Optional[BaseClientRegion]]

View File

@@ -17,8 +17,8 @@ if TYPE_CHECKING:
class ProxyNameCache(NameCache):
def create_subscriptions(
self,
message_handler: MessageHandler[Message],
http_message_handler: Optional[MessageHandler[HippoHTTPFlow]] = None,
message_handler: MessageHandler[Message, str],
http_message_handler: Optional[MessageHandler[HippoHTTPFlow, str]] = None,
):
super().create_subscriptions(message_handler)
if http_message_handler is not None:

View File

@@ -63,8 +63,8 @@ class ProxiedRegion(BaseClientRegion):
if seed_cap:
self._caps["Seed"] = (CapType.NORMAL, seed_cap)
self.session: Callable[[], Session] = weakref.ref(session)
self.message_handler: MessageHandler[Message] = MessageHandler()
self.http_message_handler: MessageHandler[HippoHTTPFlow] = MessageHandler()
self.message_handler: MessageHandler[Message, str] = MessageHandler()
self.http_message_handler: MessageHandler[HippoHTTPFlow, str] = MessageHandler()
self.eq_manager = EventQueueManager(self)
settings = session.session_manager.settings
self.caps_client = ProxyCapsClient(settings, proxify(self))

View File

@@ -42,8 +42,8 @@ class Session(BaseClientSession):
self.selected: SelectionModel = SelectionModel()
self.regions: List[ProxiedRegion] = []
self.started_at = datetime.datetime.now()
self.message_handler: MessageHandler[Message] = MessageHandler()
self.http_message_handler: MessageHandler[HippoHTTPFlow] = MessageHandler()
self.message_handler: MessageHandler[Message, str] = MessageHandler()
self.http_message_handler: MessageHandler[HippoHTTPFlow, str] = MessageHandler()
self.objects = ProxyWorldObjectManager(self, session_manager.settings, session_manager.name_cache)
self._main_region = None

View File

@@ -169,7 +169,7 @@ class TestMessage(unittest.TestCase):
class TestMessageHandlers(unittest.IsolatedAsyncioTestCase):
def setUp(self) -> None:
self.message_handler: MessageHandler[Message] = MessageHandler()
self.message_handler: MessageHandler[Message, str] = MessageHandler()
def _fake_received_message(self, msg: Message):
self.message_handler.handle(msg)
@@ -203,7 +203,7 @@ class TestMessageHandlers(unittest.IsolatedAsyncioTestCase):
self.assertEqual(len(foo_handlers), 0)
async def test_subscription_no_take(self):
with self.message_handler.subscribe_async("Foo", take=False) as get_msg:
with self.message_handler.subscribe_async(("Foo",), take=False) as get_msg:
msg = Message("Foo", Block("Bar", Baz=1, Biz=1))
self._fake_received_message(msg)
# Should not copy
@@ -212,7 +212,7 @@ class TestMessageHandlers(unittest.IsolatedAsyncioTestCase):
self.assertFalse(msg.queued)
async def test_wait_for(self):
fut = self.message_handler.wait_for("Foo", timeout=0.001, take=False)
fut = self.message_handler.wait_for(("Foo",), timeout=0.001, take=False)
foo_handlers = self.message_handler.handlers['Foo']
# We are subscribed
self.assertEqual(len(foo_handlers), 1)
@@ -226,7 +226,7 @@ class TestMessageHandlers(unittest.IsolatedAsyncioTestCase):
self.assertEqual(len(foo_handlers), 0)
async def test_wait_for_take(self):
fut = self.message_handler.wait_for("Foo", timeout=0.001)
fut = self.message_handler.wait_for(("Foo",), timeout=0.001)
foo_handlers = self.message_handler.handlers['Foo']
# We are subscribed
self.assertEqual(len(foo_handlers), 1)

View File

@@ -23,7 +23,7 @@ from hippolyzer.lib.base.xfer_manager import XferManager
class MockHandlingCircuit(ProxiedCircuit):
def __init__(self, handler: MessageHandler[Message]):
def __init__(self, handler: MessageHandler[Message, str]):
super().__init__(("127.0.0.1", 1), ("127.0.0.1", 2), None)
self.handler = handler
@@ -42,8 +42,8 @@ class BaseTransferTests(unittest.IsolatedAsyncioTestCase):
LARGE_PAYLOAD = b"foobar" * 500
def setUp(self) -> None:
self.server_message_handler: MessageHandler[Message] = MessageHandler()
self.client_message_handler: MessageHandler[Message] = MessageHandler()
self.server_message_handler: MessageHandler[Message, str] = MessageHandler()
self.client_message_handler: MessageHandler[Message, str] = MessageHandler()
# The client side should send messages to the server side's message handler
# and vice-versa
self.client_circuit = MockHandlingCircuit(self.server_message_handler)
@@ -60,7 +60,7 @@ class XferManagerTests(BaseTransferTests):
self.received_bytes: Optional[bytes] = None
async def _handle_vfile_upload(self):
msg = await self.server_message_handler.wait_for('AssetUploadRequest', timeout=0.01)
msg = await self.server_message_handler.wait_for(('AssetUploadRequest',), timeout=0.01)
asset_block = msg["AssetBlock"]
transaction_id = asset_block["TransactionID"]
asset_id = UUID.combine(transaction_id, self.secure_session_id)
@@ -102,7 +102,7 @@ class TestTransferManager(BaseTransferTests):
)
async def _handle_covenant_download(self):
msg = await self.server_message_handler.wait_for('TransferRequest', timeout=0.01)
msg = await self.server_message_handler.wait_for(('TransferRequest',), timeout=0.01)
self.assertEqual(TransferSourceType.SIM_ESTATE, msg["TransferInfo"]["SourceType"])
tid = msg["TransferInfo"]["TransferID"]
params: TransferRequestParamsSimEstate = msg["TransferInfo"][0].deserialize_var("Params")

View File

@@ -247,7 +247,7 @@ class LLUDPIntegrationTests(BaseProxyTest):
async def test_session_message_handler(self):
self._setup_default_circuit()
obj_update = self._make_objectupdate_compressed(1234)
fut = self.session.message_handler.wait_for('ObjectUpdateCompressed')
fut = self.session.message_handler.wait_for(('ObjectUpdateCompressed',))
self.protocol.datagram_received(obj_update, self.region_addr)
self.assertEqual("ObjectUpdateCompressed", (await fut).name)