diff --git a/hippolyzer/lib/base/events.py b/hippolyzer/lib/base/events.py index c8aa0ac..0fe5af3 100644 --- a/hippolyzer/lib/base/events.py +++ b/hippolyzer/lib/base/events.py @@ -18,7 +18,7 @@ You should have received a copy of the GNU Lesser General Public License along with this program; if not, write to the Free Software Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. """ - +import asyncio from logging import getLogger logger = getLogger('utilities.events') @@ -54,13 +54,22 @@ class Event: def notify(self, args): for handler in self.subscribers[:]: - instance, inner_args, kwargs, one_shot, predicate = handler + handler, inner_args, kwargs, one_shot, predicate = handler if predicate and not predicate(args): continue if one_shot: - self.unsubscribe(instance, *inner_args, **kwargs) - if instance(args, *inner_args, **kwargs) and not one_shot: - self.unsubscribe(instance, *inner_args, **kwargs) + self.unsubscribe(handler, *inner_args, **kwargs) + if asyncio.iscoroutinefunction(handler): + # Note that unsubscription may be delayed due to asyncio scheduling :) + + async def _run_handler_wrapper(): + unsubscribe = await handler(args, *inner_args, **kwargs) + if unsubscribe: + self.unsubscribe(handler, *inner_args, **kwargs) + asyncio.get_event_loop().create_task(_run_handler_wrapper()) + else: + if handler(args, *inner_args, **kwargs) and not one_shot: + self.unsubscribe(handler, *inner_args, **kwargs) def __len__(self): return len(self.subscribers) diff --git a/hippolyzer/lib/client/hippo_client.py b/hippolyzer/lib/client/hippo_client.py index 1b7bfbc..f33b06f 100644 --- a/hippolyzer/lib/client/hippo_client.py +++ b/hippolyzer/lib/client/hippo_client.py @@ -35,6 +35,7 @@ LOG = logging.getLogger(__name__) class HippoCapsClient(CapsClient): def _request_fixups(self, cap_or_url: str, headers: Dict, proxy: Optional[bool], ssl: Any): headers["User-Agent"] = f"Hippolyzer/v{version('hippolyzer')}" + return cap_or_url, headers, proxy, False class HippoClientProtocol(asyncio.DatagramProtocol): @@ -120,7 +121,7 @@ class HippoClientSession(BaseClientSession): handle: Optional[int] = None) -> HippoClientRegion: return super().register_region(circuit_addr, seed_url, handle) # type:ignore - def open_circuit(self, circuit_addr): + async def open_circuit(self, circuit_addr): for region in self.regions: if region.circuit_addr == circuit_addr: valid_circuit = False @@ -136,13 +137,13 @@ class HippoClientSession(BaseClientSession): if valid_circuit: # TODO: This is a little bit crap, we need to know if a UseCircuitCode was ever ACKed # before we can start sending other packets, otherwise we might have a race. - region.circuit.send_reliable( + await region.circuit.send_reliable( Message( "UseCircuitCode", Block("CircuitCode", Code=self.circuit_code, SessionID=self.id, ID=self.agent_id), ) ) - # TODO: set this in a callback for UseCircuitCode ACK + # TODO: What happens if a circuit code is invalid, again? region.circuit.is_alive = True return valid_circuit return False @@ -375,10 +376,13 @@ class HippoClient(BaseClientSessionManager): self.session.transport, self.session.protocol = await self._create_transport() self._resend_task = asyncio.create_task(self._attempt_resends()) - assert self.session.open_circuit(self.session.regions[-1].circuit_addr) + assert await self.session.open_circuit(self.session.regions[-1].circuit_addr) region = self.session.regions[-1] self.session.main_region = region + # Kick this off and await it later + seed_resp_fut = region.caps_client.post("Seed", session=self.http_session, llsd=list(self.SUPPORTED_CAPS)) + # Register first so we can handle it even if the ack happens after the message is sent region_handshake_fut = region.message_handler.wait_for(("RegionHandshake",)) await region.circuit.send_reliable( @@ -423,6 +427,8 @@ class HippoClient(BaseClientSessionManager): ) ) ) + async with seed_resp_fut as seed_resp: + region.update_caps(await seed_resp.read_llsd()) async def logout(self): if not self.session: diff --git a/tests/base/test_events.py b/tests/base/test_events.py new file mode 100644 index 0000000..4079cfa --- /dev/null +++ b/tests/base/test_events.py @@ -0,0 +1,51 @@ +import asyncio +import unittest +from unittest.mock import MagicMock + +from hippolyzer.lib.base.events import Event + + +class TestEvents(unittest.IsolatedAsyncioTestCase): + async def asyncSetUp(self): + self.event = Event() + + async def test_trigger_sync(self): + mock = MagicMock(return_value=False) + self.event.subscribe(mock) + self.event.notify("foo") + mock.assert_called_with("foo") + self.assertIn(mock, [x[0] for x in self.event.subscribers]) + + async def test_trigger_sync_unsub(self): + mock = MagicMock(return_value=True) + self.event.subscribe(mock) + self.event.notify("foo") + mock.assert_called_with("foo") + self.assertNotIn(mock, [x[0] for x in self.event.subscribers]) + + async def test_trigger_async(self): + called = asyncio.Event() + mock = MagicMock() + + async def _mock_wrapper(*args, **kwargs): + called.set() + mock(*args, **kwargs) + self.event.subscribe(_mock_wrapper) + self.event.notify("foo") + await called.wait() + mock.assert_called_with("foo") + self.assertIn(_mock_wrapper, [x[0] for x in self.event.subscribers]) + + async def test_trigger_async_unsub(self): + called = asyncio.Event() + mock = MagicMock() + + async def _mock_wrapper(*args, **kwargs): + called.set() + mock(*args, **kwargs) + return True + self.event.subscribe(_mock_wrapper) + self.event.notify("foo") + await called.wait() + mock.assert_called_with("foo") + self.assertNotIn(_mock_wrapper, [x[0] for x in self.event.subscribers]) diff --git a/tests/base/test_settings.py b/tests/base/test_settings.py index 738875a..4c29bcf 100644 --- a/tests/base/test_settings.py +++ b/tests/base/test_settings.py @@ -23,13 +23,7 @@ import unittest from hippolyzer.lib.base.settings import Settings -class TestEvents(unittest.TestCase): - def setUp(self): - pass - - def tearDown(self): - pass - +class TestSettings(unittest.TestCase): def test_base_settings(self): settings = Settings() self.assertEqual(settings.ENABLE_DEFERRED_PACKET_PARSING, True) diff --git a/tests/client/test_hippo_client.py b/tests/client/test_hippo_client.py index f281549..60085d2 100644 --- a/tests/client/test_hippo_client.py +++ b/tests/client/test_hippo_client.py @@ -99,6 +99,7 @@ class TestHippoClient(unittest.IsolatedAsyncioTestCase): async def _do_login(): with aioresponses.aioresponses() as m: m.post(self.FAKE_LOGIN_URI, body=self._make_fake_login_body()) + m.post(self.FAKE_LOGIN_RESP['seed_capability'], body="") await client.login("foo", "bar", login_uri=self.FAKE_LOGIN_URI) await client.logout()