From 613dd32a40a2b9b78fd39853ea6cbc12ebe2698a Mon Sep 17 00:00:00 2001 From: Salad Dais Date: Mon, 18 Dec 2023 03:29:40 +0000 Subject: [PATCH] Add tests for voice stuff --- hippolyzer/lib/voice/client.py | 4 +- hippolyzer/lib/voice/connection.py | 16 +++-- tests/voice/__init__.py | 0 tests/voice/test_voice.py | 109 +++++++++++++++++++++++++++++ 4 files changed, 123 insertions(+), 6 deletions(-) create mode 100644 tests/voice/__init__.py create mode 100644 tests/voice/test_voice.py diff --git a/hippolyzer/lib/voice/client.py b/hippolyzer/lib/voice/client.py index fadabf0..3f5ead1 100644 --- a/hippolyzer/lib/voice/client.py +++ b/hippolyzer/lib/voice/client.py @@ -44,7 +44,7 @@ class VoiceEvent(UserDict): super().__init__(event_dict) -class VoiceClient(object): +class VoiceClient: SERVER_URL = "http://www.bhr.vivox.com/api2/" # noqa def __init__(self, host, port): @@ -332,7 +332,7 @@ class VoiceClient(object): RESP_LOG.debug("%s %s %s %r" % ("Request", request_id, msg_type, data)) - asyncio.get_event_loop().create_task(self.vivox_conn.send(request_id, msg_type, data)) + asyncio.get_event_loop().create_task(self.vivox_conn.send_request(request_id, msg_type, data)) future = asyncio.Future() self._pending_req_futures[request_id] = future return future diff --git a/hippolyzer/lib/voice/connection.py b/hippolyzer/lib/voice/connection.py index 72c9d06..2e1b16f 100644 --- a/hippolyzer/lib/voice/connection.py +++ b/hippolyzer/lib/voice/connection.py @@ -90,15 +90,20 @@ class VivoxConnection: async def read_messages(self): # TODO: handle interrupted read while self._reader and not self._reader.at_eof() and not self._writer.is_closing(): - msg = await self._reader.readuntil(b"\n\n\n") - yield self.parse(msg[:-3]) + yield await self.read_message() + + async def read_message(self): + msg = await self._reader.readuntil(b"\n\n\n") + return self.parse(msg[:-3]) def parse(self, raw_msg) -> Tuple[str, Optional[str], Optional[str], dict]: parsed_msg = defusedxml.lxml.fromstring(raw_msg.decode("utf8")) msg_type = parsed_msg.tag request_id = parsed_msg.attrib.get("requestId", None) - dict_msg = xml_to_dict(parsed_msg)[1] + # There may be no params, just use an empty dict if that's the case + dict_msg = xml_to_dict(parsed_msg)[1] or {} + if msg_type == "Event": msg_action = parsed_msg.attrib.get("type") elif msg_type == "Response": @@ -107,6 +112,9 @@ class VivoxConnection: if 'InputXml' in dict_msg: del dict_msg['InputXml'] dict_msg = _clean_message(msg_action, parsed_msg, dict_msg) + elif msg_type == "Request": + msg_action = parsed_msg.attrib.get("action") + dict_msg = _clean_message(msg_action, parsed_msg, dict_msg) else: raise Exception("Unknown Vivox message type %r?" % msg_type) return msg_type, msg_action, request_id, dict_msg @@ -118,7 +126,7 @@ class VivoxConnection: weakref.finalize(drain_coro, drain_coro.close) return drain_coro - def send(self, request_id: str, msg_type: str, data: Any) -> Coroutine[Any, Any, None]: + def send_request(self, request_id: str, msg_type: str, data: Any) -> Coroutine[Any, Any, None]: elem = lxml.etree.Element("Request") elem.attrib["requestId"] = request_id elem.attrib["action"] = msg_type diff --git a/tests/voice/__init__.py b/tests/voice/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/voice/test_voice.py b/tests/voice/test_voice.py new file mode 100644 index 0000000..f8dd7be --- /dev/null +++ b/tests/voice/test_voice.py @@ -0,0 +1,109 @@ +from typing import * + +import asyncio +import unittest +from unittest import mock + +from hippolyzer.lib.voice.connection import VivoxConnection + + +def _make_transport(buf: Any): + transport = mock.Mock() + transport.write.side_effect = buf.extend + transport.is_closing.return_value = False + return transport + + +def _make_protocol(transport: Any): + protocol = mock.Mock(transport=transport) + protocol._drain_helper = mock.AsyncMock() + return protocol + + +class TestVivoxConnection(unittest.IsolatedAsyncioTestCase): + async def asyncSetUp(self): + self._writer_buf = bytearray() + self._transport = _make_transport(self._writer_buf) + self._protocol = _make_protocol(self._transport) + self.reader = asyncio.StreamReader() + self.writer = asyncio.StreamWriter(self._transport, self._protocol, self.reader, asyncio.get_event_loop()) + self.vivox_connection = VivoxConnection(self.reader, self.writer) + + async def test_read_request(self): + self.reader.feed_data( + b'1\n\n\n' + ) + self.reader.feed_eof() + msg_type, msg_action, request_id, body = await self.vivox_connection.read_message() + self.assertEqual("Request", msg_type) + self.assertEqual("Aux.GetRenderDevices.1", msg_action) + self.assertEqual("foobar", request_id) + self.assertDictEqual({"Foo": "1"}, body) + + async def test_read_response(self): + self.reader.feed_data( + b'0' + b'0' + b'\n\n\n' + ) + self.reader.feed_eof() + msg_type, msg_action, request_id, body = await self.vivox_connection.read_message() + self.assertEqual("Response", msg_type) + self.assertEqual("Connector.SetLocalMicVolume.1", msg_action) + self.assertEqual("foobar", request_id) + self.assertDictEqual( + {'ReturnCode': 0, 'Results': {'StatusCode': '0', 'StatusString': None}}, + body, + ) + + async def test_read_event(self): + self.reader.feed_data( + b'47' + b'0' + b'6Connectingfalse' + b'\n\n\n' + ) + self.reader.feed_eof() + msg_type, msg_action, request_id, body = await self.vivox_connection.read_message() + self.assertEqual("Event", msg_type) + self.assertEqual("MediaStreamUpdatedEvent", msg_action) + self.assertEqual(None, request_id) + self.assertDictEqual( + { + 'DurableMediaId': None, + 'Incoming': 'false', + 'SessionGroupHandle': '4', + 'SessionHandle': '7', + 'State': '6', + 'StateDescription': 'Connecting', + 'StatusCode': '0', + 'StatusString': None, + }, + body, + ) + + async def test_read_messages(self): + self.reader.feed_data( + b'1\n\n\n' + b'1\n\n\n' + ) + self.reader.feed_eof() + + i = 0 + async for msg_type, msg_action, request_id, body in self.vivox_connection.read_messages(): + if i == 0: + self.assertEqual("foobar", request_id) + else: + self.assertEqual("quux", request_id) + self.assertEqual("Request", msg_type) + + self.assertEqual("Aux.GetRenderDevices.1", msg_action) + self.assertDictEqual({"Foo": "1"}, body) + i += 1 + + async def test_send_message(self): + await self.vivox_connection.send_request("foo", "bar", {"baz": 1}) + self.assertEqual( + b'1\n\n\n', + self._writer_buf + )