diff --git a/hippolyzer/lib/proxy/task_scheduler.py b/hippolyzer/lib/proxy/task_scheduler.py index 26e1f86..e8bbfaf 100644 --- a/hippolyzer/lib/proxy/task_scheduler.py +++ b/hippolyzer/lib/proxy/task_scheduler.py @@ -63,9 +63,14 @@ class TaskScheduler: def shutdown(self): for task_data, task in self.tasks: task.cancel() - await_all = asyncio.gather(*(task for task_data, task in self.tasks)) - event_loop = asyncio.get_event_loop_policy().get_event_loop() - event_loop.run_until_complete(await_all) + + try: + event_loop = asyncio.get_running_loop() + await_all = asyncio.gather(*(task for task_data, task in self.tasks)) + event_loop.run_until_complete(await_all) + except RuntimeError: + pass + self.tasks.clear() def _task_done(self, task: asyncio.Task): for task_details in reversed(self.tasks): diff --git a/tests/proxy/test_httpflows.py b/tests/proxy/test_httpflows.py index 6a72eee..587e44e 100644 --- a/tests/proxy/test_httpflows.py +++ b/tests/proxy/test_httpflows.py @@ -19,7 +19,7 @@ class TestHTTPFlows(BaseProxyTest): "ViewerAsset": "http://assets.example.com", }) - def test_request_formatting(self): + async def test_request_formatting(self): req = tutils.treq(host="example.com", port=80) fake_flow = tflow.tflow(req=req, resp=tutils.tresp()) flow = HippoHTTPFlow.from_state(fake_flow.get_state(), self.session_manager) @@ -33,7 +33,7 @@ content-length: 7\r \r content""") - def test_binary_request_formatting(self): + async def test_binary_request_formatting(self): req = tutils.treq(host="example.com", port=80) fake_flow = tflow.tflow(req=req, resp=tutils.tresp()) flow = HippoHTTPFlow.from_state(fake_flow.get_state(), self.session_manager) @@ -47,7 +47,7 @@ X-Hippo-Escaped-Body: 1\r \r c\\x00ntent""") - def test_llsd_response_formatting(self): + async def test_llsd_response_formatting(self): fake_flow = tflow.tflow(req=tutils.treq(), resp=tutils.tresp()) flow = HippoHTTPFlow.from_state(fake_flow.get_state(), self.session_manager) # Half the time LLSD is sent with a random Content-Type and no PI indicating @@ -64,7 +64,7 @@ content-length: 33\r """) - def test_flow_state_serde(self): + async def test_flow_state_serde(self): fake_flow = tflow.tflow(req=tutils.treq(host="example.com"), resp=tutils.tresp()) flow = HippoHTTPFlow.from_state(fake_flow.get_state(), self.session_manager) # Make sure cap resolution works correctly @@ -73,7 +73,7 @@ content-length: 33\r new_flow = HippoHTTPFlow.from_state(flow_state, self.session_manager) self.assertIs(self.session, new_flow.cap_data.session()) - def test_http_asset_repo(self): + async def test_http_asset_repo(self): asset_repo = self.session_manager.asset_repo asset_id = asset_repo.create_asset(b"foobar", one_shot=True) req = tutils.treq(host="assets.example.com", path=f"/?animatn_id={asset_id}") @@ -84,7 +84,7 @@ content-length: 33\r self.assertTrue(asset_repo.try_serve_asset(flow)) self.assertEqual(b"foobar", flow.response.content) - def test_temporary_cap_resolution(self): + async def test_temporary_cap_resolution(self): self.region.register_cap("TempExample", "http://not.example.com", CapType.TEMPORARY) self.region.register_cap("TempExample", "http://not2.example.com", CapType.TEMPORARY) # Resolving the cap should consume it diff --git a/tests/proxy/test_message_filter.py b/tests/proxy/test_message_filter.py index 004b3b7..9493ec9 100644 --- a/tests/proxy/test_message_filter.py +++ b/tests/proxy/test_message_filter.py @@ -130,7 +130,7 @@ class MessageFilterTests(unittest.IsolatedAsyncioTestCase): # Make sure numbers outside 32bit range come through self.assertTrue(self._filter_matches("Foo.Bar.Foo == 0xFFffFFffFF", msg)) - def test_http_flow(self): + async def test_http_flow(self): session_manager = SessionManager(ProxySettings()) fake_flow = tflow.tflow(req=tutils.treq(), resp=tutils.tresp()) fake_flow.metadata["cap_data_ser"] = SerializedCapData( @@ -141,7 +141,7 @@ class MessageFilterTests(unittest.IsolatedAsyncioTestCase): self.assertTrue(self._filter_matches("FakeCap", entry)) self.assertFalse(self._filter_matches("NotFakeCap", entry)) - def test_http_header_filter(self): + async def test_http_header_filter(self): session_manager = SessionManager(ProxySettings()) fake_flow = tflow.tflow(req=tutils.treq(), resp=tutils.tresp()) fake_flow.request.headers["Cookie"] = 'foo="bar"' @@ -151,7 +151,7 @@ class MessageFilterTests(unittest.IsolatedAsyncioTestCase): self.assertTrue(self._filter_matches('Meta.ReqHeaders.cookie ~= "foo"', entry)) self.assertFalse(self._filter_matches('Meta.ReqHeaders.foobar ~= "foo"', entry)) - def test_export_import_http_flow(self): + async def test_export_import_http_flow(self): fake_flow = tflow.tflow(req=tutils.treq(), resp=tutils.tresp()) fake_flow.metadata["cap_data_ser"] = SerializedCapData( cap_name="FakeCap",