Start writing voice client tests

This commit is contained in:
Salad Dais
2023-12-18 05:27:32 +00:00
parent 613dd32a40
commit 1165769aca
3 changed files with 113 additions and 29 deletions

View File

@@ -1,3 +1,5 @@
from __future__ import annotations
import asyncio
import base64
import json
@@ -324,7 +326,7 @@ class VoiceClient:
})
def send_message(self, msg_type: str, data: Any) -> asyncio.Future[dict]:
request_id = str(uuid.uuid4())
request_id = self._make_request_id()
# This is apparently what the viewer does, not clear if
# request_id has any semantic significance
if msg_type == "Session.Create.1":
@@ -352,20 +354,19 @@ class VoiceClient:
while not self.vivox_conn:
await asyncio.sleep(0.001)
async for msg_type, msg_action, request_id, dict_msg in self.vivox_conn.read_messages():
async for msg in self.vivox_conn.read_messages():
try:
if msg_type == "Event":
RESP_LOG.debug("%s %s %r" % ("Event", msg_action, dict_msg))
self.message_handler.handle(VoiceEvent(msg_action, dict_msg))
RESP_LOG.debug(repr(msg))
if msg.type == "Event":
self.message_handler.handle(VoiceEvent(msg.action, msg.data))
# Spin off handler tasks for each event so that we don't block polling
_ = asyncio.get_event_loop().create_task(self._dispatch_received_event(msg_action, dict_msg))
elif msg_type == "Response":
RESP_LOG.debug("%s %s %s %r" % ("Response", request_id, msg_action, dict_msg))
_ = asyncio.get_event_loop().create_task(self._dispatch_received_event(msg.action, msg.data))
elif msg.type == "Response":
# Might not have this request ID if it was sent directly via the socket
if request_id in self._pending_req_futures:
self._pending_req_futures[request_id].set_result(dict_msg)
del self._pending_req_futures[request_id]
if msg.request_id in self._pending_req_futures:
self._pending_req_futures[msg.request_id].set_result(msg.data)
del self._pending_req_futures[msg.request_id]
except Exception:
LOG.exception("Error in response handler?")
@@ -450,3 +451,6 @@ class VoiceClient:
if self._region_global_x is not None:
pos = self.region_pos
self.channel_info_updated.notify(pos)
def _make_request_id(self):
return str(uuid.uuid4())

View File

@@ -1,14 +1,22 @@
# TODO: some fancy parser that parses everything into
# dicts or objects using schemas.
from __future__ import annotations
import asyncio
import weakref
from typing import Any, Tuple, Optional, Coroutine
from typing import Any, Optional, Coroutine, NamedTuple
import defusedxml.lxml
import lxml.etree
class VivoxMessage(NamedTuple):
type: str
action: str
request_id: Optional[str]
data: dict
def xml_to_dict(element):
return element.tag, dict(map(xml_to_dict, element)) or element.text
@@ -96,7 +104,7 @@ class VivoxConnection:
msg = await self._reader.readuntil(b"\n\n\n")
return self.parse(msg[:-3])
def parse(self, raw_msg) -> Tuple[str, Optional[str], Optional[str], dict]:
def parse(self, raw_msg) -> VivoxMessage:
parsed_msg = defusedxml.lxml.fromstring(raw_msg.decode("utf8"))
msg_type = parsed_msg.tag
request_id = parsed_msg.attrib.get("requestId", None)
@@ -114,10 +122,9 @@ class VivoxConnection:
dict_msg = _clean_message(msg_action, parsed_msg, dict_msg)
elif msg_type == "Request":
msg_action = parsed_msg.attrib.get("action")
dict_msg = _clean_message(msg_action, parsed_msg, dict_msg)
else:
raise Exception("Unknown Vivox message type %r?" % msg_type)
return msg_type, msg_action, request_id, dict_msg
return VivoxMessage(msg_type, msg_action, request_id, dict_msg)
def send_raw(self, buf: bytes) -> Coroutine[Any, Any, None]:
self._writer.write(buf + b"\n\n\n")
@@ -126,12 +133,24 @@ class VivoxConnection:
weakref.finalize(drain_coro, drain_coro.close)
return drain_coro
def send_request(self, request_id: str, msg_type: str, data: Any) -> Coroutine[Any, Any, None]:
elem = lxml.etree.Element("Request")
elem.attrib["requestId"] = request_id
elem.attrib["action"] = msg_type
if msg_type == "Account.WebCall.1":
def send_request(self, request_id: str, action: str, data: Any) -> Coroutine[Any, Any, None]:
if action == "Account.WebCall.1":
data = dict(data)
data["Parameters"] = _build_webcall_params(data["Parameters"])
return self._send_request_response("Request", request_id, action, data)
def send_response(self, request_id: str, action: str, data: Any) -> Coroutine[Any, Any, None]:
return self._send_request_response("Response", request_id, action, data)
def _send_request_response(self, msg_type: str, request_id: str, action: str, data: Any):
elem = lxml.etree.Element(msg_type)
elem.attrib["requestId"] = request_id
elem.attrib["action"] = action
serialized = lxml.etree.tostring(buildxml(elem, data))
return self.send_raw(serialized)
def send_event(self, event_type: str, data: Any) -> Coroutine[Any, Any, None]:
elem = lxml.etree.Element("Event")
elem.attrib["type"] = event_type
serialized = lxml.etree.tostring(buildxml(elem, data))
return self.send_raw(serialized)

View File

@@ -4,12 +4,13 @@ import asyncio
import unittest
from unittest import mock
from hippolyzer.lib.voice.client import VoiceClient
from hippolyzer.lib.voice.connection import VivoxConnection
def _make_transport(buf: Any):
def _make_transport(write_func):
transport = mock.Mock()
transport.write.side_effect = buf.extend
transport.write.side_effect = write_func
transport.is_closing.return_value = False
return transport
@@ -23,7 +24,7 @@ def _make_protocol(transport: Any):
class TestVivoxConnection(unittest.IsolatedAsyncioTestCase):
async def asyncSetUp(self):
self._writer_buf = bytearray()
self._transport = _make_transport(self._writer_buf)
self._transport = _make_transport(self._writer_buf.extend)
self._protocol = _make_protocol(self._transport)
self.reader = asyncio.StreamReader()
self.writer = asyncio.StreamWriter(self._transport, self._protocol, self.reader, asyncio.get_event_loop())
@@ -90,15 +91,15 @@ class TestVivoxConnection(unittest.IsolatedAsyncioTestCase):
self.reader.feed_eof()
i = 0
async for msg_type, msg_action, request_id, body in self.vivox_connection.read_messages():
async for msg in self.vivox_connection.read_messages():
if i == 0:
self.assertEqual("foobar", request_id)
self.assertEqual("foobar", msg.request_id)
else:
self.assertEqual("quux", request_id)
self.assertEqual("Request", msg_type)
self.assertEqual("quux", msg.request_id)
self.assertEqual("Aux.GetRenderDevices.1", msg_action)
self.assertDictEqual({"Foo": "1"}, body)
self.assertEqual("Request", msg.type)
self.assertEqual("Aux.GetRenderDevices.1", msg.action)
self.assertDictEqual({"Foo": "1"}, msg.data)
i += 1
async def test_send_message(self):
@@ -107,3 +108,63 @@ class TestVivoxConnection(unittest.IsolatedAsyncioTestCase):
b'<Request requestId="foo" action="bar"><baz>1</baz></Request>\n\n\n',
self._writer_buf
)
class TestVoiceClient(unittest.IsolatedAsyncioTestCase):
async def asyncSetUp(self):
self._client_transport = _make_transport(
lambda *args: asyncio.get_event_loop().call_soon(self.server_reader.feed_data, *args)
)
self._client_protocol = _make_protocol(self._client_transport)
self.client_reader = asyncio.StreamReader()
self.client_writer = asyncio.StreamWriter(
self._client_transport,
self._client_protocol,
self.client_reader,
asyncio.get_event_loop()
)
self._server_transport = _make_transport(
lambda *args: asyncio.get_event_loop().call_soon(self.client_reader.feed_data, *args)
)
self._server_protocol = _make_protocol(self._server_transport)
self.server_reader = asyncio.StreamReader()
self.server_writer = asyncio.StreamWriter(
self._server_transport,
self._server_protocol,
self.server_reader,
asyncio.get_event_loop()
)
self.client_connection = VivoxConnection(self.client_reader, self.client_writer)
self.server_connection = VivoxConnection(self.server_reader, self.server_writer)
self.client = VoiceClient("127.0.0.1", 0)
self.client.vivox_conn = self.client_connection
def _make_request_id():
_make_request_id.i += 1
return str(_make_request_id.i)
_make_request_id.i = 0
self.client._make_request_id = _make_request_id
async def test_connection(self):
async def _serve_login():
await self.server_connection.send_event(
"VoiceServiceConnectionStateChangedEvent",
{
"Connected": 1,
"Platform": "Linux",
"Version": 1,
"DataDirectory": "/tmp/whatever",
}
)
self.assertEqual(
('Request', 'Aux.GetCaptureDevices.1', '1', {}),
await self.server_connection.read_message()
)
serve_coro = asyncio.get_event_loop().create_task(_serve_login())
# Await this here so we can see any exceptions
await serve_coro