From 1165769aca17ef4db7732efe8a0d935e113ca5cc Mon Sep 17 00:00:00 2001 From: Salad Dais Date: Mon, 18 Dec 2023 05:27:32 +0000 Subject: [PATCH] Start writing voice client tests --- hippolyzer/lib/voice/client.py | 26 +++++----- hippolyzer/lib/voice/connection.py | 37 ++++++++++---- tests/voice/test_voice.py | 79 ++++++++++++++++++++++++++---- 3 files changed, 113 insertions(+), 29 deletions(-) diff --git a/hippolyzer/lib/voice/client.py b/hippolyzer/lib/voice/client.py index 3f5ead1..7bddb40 100644 --- a/hippolyzer/lib/voice/client.py +++ b/hippolyzer/lib/voice/client.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import asyncio import base64 import json @@ -324,7 +326,7 @@ class VoiceClient: }) def send_message(self, msg_type: str, data: Any) -> asyncio.Future[dict]: - request_id = str(uuid.uuid4()) + request_id = self._make_request_id() # This is apparently what the viewer does, not clear if # request_id has any semantic significance if msg_type == "Session.Create.1": @@ -352,20 +354,19 @@ class VoiceClient: while not self.vivox_conn: await asyncio.sleep(0.001) - async for msg_type, msg_action, request_id, dict_msg in self.vivox_conn.read_messages(): + async for msg in self.vivox_conn.read_messages(): try: - if msg_type == "Event": - RESP_LOG.debug("%s %s %r" % ("Event", msg_action, dict_msg)) - self.message_handler.handle(VoiceEvent(msg_action, dict_msg)) + RESP_LOG.debug(repr(msg)) + if msg.type == "Event": + self.message_handler.handle(VoiceEvent(msg.action, msg.data)) # Spin off handler tasks for each event so that we don't block polling - _ = asyncio.get_event_loop().create_task(self._dispatch_received_event(msg_action, dict_msg)) - elif msg_type == "Response": - RESP_LOG.debug("%s %s %s %r" % ("Response", request_id, msg_action, dict_msg)) + _ = asyncio.get_event_loop().create_task(self._dispatch_received_event(msg.action, msg.data)) + elif msg.type == "Response": # Might not have this request ID if it was sent directly via the socket - if request_id in self._pending_req_futures: - self._pending_req_futures[request_id].set_result(dict_msg) - del self._pending_req_futures[request_id] + if msg.request_id in self._pending_req_futures: + self._pending_req_futures[msg.request_id].set_result(msg.data) + del self._pending_req_futures[msg.request_id] except Exception: LOG.exception("Error in response handler?") @@ -450,3 +451,6 @@ class VoiceClient: if self._region_global_x is not None: pos = self.region_pos self.channel_info_updated.notify(pos) + + def _make_request_id(self): + return str(uuid.uuid4()) diff --git a/hippolyzer/lib/voice/connection.py b/hippolyzer/lib/voice/connection.py index 2e1b16f..5eabeb9 100644 --- a/hippolyzer/lib/voice/connection.py +++ b/hippolyzer/lib/voice/connection.py @@ -1,14 +1,22 @@ # TODO: some fancy parser that parses everything into # dicts or objects using schemas. +from __future__ import annotations import asyncio import weakref -from typing import Any, Tuple, Optional, Coroutine +from typing import Any, Optional, Coroutine, NamedTuple import defusedxml.lxml import lxml.etree +class VivoxMessage(NamedTuple): + type: str + action: str + request_id: Optional[str] + data: dict + + def xml_to_dict(element): return element.tag, dict(map(xml_to_dict, element)) or element.text @@ -96,7 +104,7 @@ class VivoxConnection: msg = await self._reader.readuntil(b"\n\n\n") return self.parse(msg[:-3]) - def parse(self, raw_msg) -> Tuple[str, Optional[str], Optional[str], dict]: + def parse(self, raw_msg) -> VivoxMessage: parsed_msg = defusedxml.lxml.fromstring(raw_msg.decode("utf8")) msg_type = parsed_msg.tag request_id = parsed_msg.attrib.get("requestId", None) @@ -114,10 +122,9 @@ class VivoxConnection: dict_msg = _clean_message(msg_action, parsed_msg, dict_msg) elif msg_type == "Request": msg_action = parsed_msg.attrib.get("action") - dict_msg = _clean_message(msg_action, parsed_msg, dict_msg) else: raise Exception("Unknown Vivox message type %r?" % msg_type) - return msg_type, msg_action, request_id, dict_msg + return VivoxMessage(msg_type, msg_action, request_id, dict_msg) def send_raw(self, buf: bytes) -> Coroutine[Any, Any, None]: self._writer.write(buf + b"\n\n\n") @@ -126,12 +133,24 @@ class VivoxConnection: weakref.finalize(drain_coro, drain_coro.close) return drain_coro - def send_request(self, request_id: str, msg_type: str, data: Any) -> Coroutine[Any, Any, None]: - elem = lxml.etree.Element("Request") - elem.attrib["requestId"] = request_id - elem.attrib["action"] = msg_type - if msg_type == "Account.WebCall.1": + def send_request(self, request_id: str, action: str, data: Any) -> Coroutine[Any, Any, None]: + if action == "Account.WebCall.1": data = dict(data) data["Parameters"] = _build_webcall_params(data["Parameters"]) + return self._send_request_response("Request", request_id, action, data) + + def send_response(self, request_id: str, action: str, data: Any) -> Coroutine[Any, Any, None]: + return self._send_request_response("Response", request_id, action, data) + + def _send_request_response(self, msg_type: str, request_id: str, action: str, data: Any): + elem = lxml.etree.Element(msg_type) + elem.attrib["requestId"] = request_id + elem.attrib["action"] = action + serialized = lxml.etree.tostring(buildxml(elem, data)) + return self.send_raw(serialized) + + def send_event(self, event_type: str, data: Any) -> Coroutine[Any, Any, None]: + elem = lxml.etree.Element("Event") + elem.attrib["type"] = event_type serialized = lxml.etree.tostring(buildxml(elem, data)) return self.send_raw(serialized) diff --git a/tests/voice/test_voice.py b/tests/voice/test_voice.py index f8dd7be..bd7b14e 100644 --- a/tests/voice/test_voice.py +++ b/tests/voice/test_voice.py @@ -4,12 +4,13 @@ import asyncio import unittest from unittest import mock +from hippolyzer.lib.voice.client import VoiceClient from hippolyzer.lib.voice.connection import VivoxConnection -def _make_transport(buf: Any): +def _make_transport(write_func): transport = mock.Mock() - transport.write.side_effect = buf.extend + transport.write.side_effect = write_func transport.is_closing.return_value = False return transport @@ -23,7 +24,7 @@ def _make_protocol(transport: Any): class TestVivoxConnection(unittest.IsolatedAsyncioTestCase): async def asyncSetUp(self): self._writer_buf = bytearray() - self._transport = _make_transport(self._writer_buf) + self._transport = _make_transport(self._writer_buf.extend) self._protocol = _make_protocol(self._transport) self.reader = asyncio.StreamReader() self.writer = asyncio.StreamWriter(self._transport, self._protocol, self.reader, asyncio.get_event_loop()) @@ -90,15 +91,15 @@ class TestVivoxConnection(unittest.IsolatedAsyncioTestCase): self.reader.feed_eof() i = 0 - async for msg_type, msg_action, request_id, body in self.vivox_connection.read_messages(): + async for msg in self.vivox_connection.read_messages(): if i == 0: - self.assertEqual("foobar", request_id) + self.assertEqual("foobar", msg.request_id) else: - self.assertEqual("quux", request_id) - self.assertEqual("Request", msg_type) + self.assertEqual("quux", msg.request_id) - self.assertEqual("Aux.GetRenderDevices.1", msg_action) - self.assertDictEqual({"Foo": "1"}, body) + self.assertEqual("Request", msg.type) + self.assertEqual("Aux.GetRenderDevices.1", msg.action) + self.assertDictEqual({"Foo": "1"}, msg.data) i += 1 async def test_send_message(self): @@ -107,3 +108,63 @@ class TestVivoxConnection(unittest.IsolatedAsyncioTestCase): b'1\n\n\n', self._writer_buf ) + + +class TestVoiceClient(unittest.IsolatedAsyncioTestCase): + async def asyncSetUp(self): + self._client_transport = _make_transport( + lambda *args: asyncio.get_event_loop().call_soon(self.server_reader.feed_data, *args) + ) + self._client_protocol = _make_protocol(self._client_transport) + self.client_reader = asyncio.StreamReader() + self.client_writer = asyncio.StreamWriter( + self._client_transport, + self._client_protocol, + self.client_reader, + asyncio.get_event_loop() + ) + + self._server_transport = _make_transport( + lambda *args: asyncio.get_event_loop().call_soon(self.client_reader.feed_data, *args) + ) + self._server_protocol = _make_protocol(self._server_transport) + self.server_reader = asyncio.StreamReader() + self.server_writer = asyncio.StreamWriter( + self._server_transport, + self._server_protocol, + self.server_reader, + asyncio.get_event_loop() + ) + + self.client_connection = VivoxConnection(self.client_reader, self.client_writer) + self.server_connection = VivoxConnection(self.server_reader, self.server_writer) + self.client = VoiceClient("127.0.0.1", 0) + self.client.vivox_conn = self.client_connection + + def _make_request_id(): + _make_request_id.i += 1 + return str(_make_request_id.i) + + _make_request_id.i = 0 + + self.client._make_request_id = _make_request_id + + async def test_connection(self): + async def _serve_login(): + await self.server_connection.send_event( + "VoiceServiceConnectionStateChangedEvent", + { + "Connected": 1, + "Platform": "Linux", + "Version": 1, + "DataDirectory": "/tmp/whatever", + } + ) + self.assertEqual( + ('Request', 'Aux.GetCaptureDevices.1', '1', {}), + await self.server_connection.read_message() + ) + + serve_coro = asyncio.get_event_loop().create_task(_serve_login()) + # Await this here so we can see any exceptions + await serve_coro