From 9df0224fbf46ff2fb41efc7a0872d78e10aaad8a Mon Sep 17 00:00:00 2001 From: Salad Dais Date: Thu, 3 Jun 2021 08:01:28 +0000 Subject: [PATCH] Split CapsClient into proxy and non-proxy version --- hippolyzer/apps/proxy_gui.py | 4 +- hippolyzer/lib/base/network/caps_client.py | 171 +++++++++++++++++++++ hippolyzer/lib/proxy/caps_client.py | 165 +------------------- hippolyzer/lib/proxy/region.py | 6 +- tests/proxy/test_capsclient.py | 3 +- 5 files changed, 184 insertions(+), 165 deletions(-) create mode 100644 hippolyzer/lib/base/network/caps_client.py diff --git a/hippolyzer/apps/proxy_gui.py b/hippolyzer/apps/proxy_gui.py index 750f236..8ea40fc 100644 --- a/hippolyzer/apps/proxy_gui.py +++ b/hippolyzer/apps/proxy_gui.py @@ -39,7 +39,7 @@ import hippolyzer.lib.base.serialization as se from hippolyzer.lib.base.network.transport import Direction, WrappingUDPTransport from hippolyzer.lib.proxy.addons import BaseInteractionManager, AddonManager from hippolyzer.lib.proxy.ca_utils import setup_ca_everywhere -from hippolyzer.lib.proxy.caps_client import CapsClient +from hippolyzer.lib.proxy.caps_client import ProxyCapsClient from hippolyzer.lib.proxy.http_proxy import create_proxy_master, HTTPFlowContext from hippolyzer.lib.proxy.message_logger import LLUDPMessageLogEntry, AbstractMessageLogEntry from hippolyzer.lib.proxy.region import ProxiedRegion @@ -719,7 +719,7 @@ class MessageBuilderWindow(QtWidgets.QMainWindow): return val def _sendHTTPRequest(self, method, uri, headers, body): - caps_client = CapsClient() + caps_client = ProxyCapsClient() async def _send_request(): req = caps_client.request(method, uri, headers=headers, data=body) diff --git a/hippolyzer/lib/base/network/caps_client.py b/hippolyzer/lib/base/network/caps_client.py new file mode 100644 index 0000000..93c7c1e --- /dev/null +++ b/hippolyzer/lib/base/network/caps_client.py @@ -0,0 +1,171 @@ +from __future__ import annotations + +import asyncio +import copy +import dataclasses +from types import TracebackType +from typing import * + +import aiohttp +import multidict + +from hippolyzer.lib.base import llsd as llsd_lib + + +class CapsClientResponse(aiohttp.ClientResponse): + """ + Not actually instantiated, used for lying to the type system + since we'll dynamically put this onto a ClientResponse instance + Will fail isinstance(). + """ + async def read_llsd(self) -> Any: + raise NotImplementedError() + + +class _HippoSessionRequestContextManager: + """ + _SessionRequestContextManager but with a symmetrical API + + aiohttp.request() and aiohttp.ClientSession.request() have different APIs. + One is sync returning a context manager, one is async returning a coro. + aiohttp.request() also doesn't accept the arguments that we need for custom + SSL contexts. To deal with requests that have existing sessions and those without, + just give them both the same wrapper and don't close the session on context manager + exit if it wasn't our session. + """ + __slots__ = ("_coro", "_resp", "_session", "_session_owned") + + def __init__( + self, + coro: Coroutine[asyncio.Future[Any], None, aiohttp.ClientResponse], + session: aiohttp.ClientSession, + session_owned: bool = True, + ) -> None: + self._coro = coro + self._resp: Optional[aiohttp.ClientResponse] = None + self._session = session + self._session_owned = session_owned + + async def __aenter__(self) -> CapsClientResponse: + try: + self._resp = await self._coro + + # We don't control creation of the ClientResponse, so tack on + # a convenience method for reading LLSD. + async def _read_llsd(): + return llsd_lib.parse_xml(await self._resp.read()) + self._resp.read_llsd = _read_llsd + except BaseException: + if self._session_owned: + await self._session.close() + raise + else: + # intentionally fooling the type system + return self._resp # type: ignore + + async def __aexit__( + self, + exc_type: Optional[Type[BaseException]], + exc: Optional[BaseException], + tb: Optional[TracebackType], + ) -> None: + assert self._resp is not None + self._resp.close() + if self._session_owned: + await self._session.close() + + +CAPS_DICT = Union[ + Mapping[str, str], + multidict.MultiDict[Tuple[Any, str]] +] + + +class CapsClient: + def __init__(self, caps: Optional[CAPS_DICT] = None): + self._caps = caps + + def _request_fixups(self, cap_or_url: str, headers: Dict, proxy: Optional[bool], ssl: Any): + return cap_or_url, headers, proxy, ssl + + def request(self, method: str, cap_or_url: str, *, path: str = "", data: Any = None, + headers: Optional[Dict] = None, session: Optional[aiohttp.ClientSession] = None, + llsd: Any = dataclasses.MISSING, params: Optional[Dict[str, Any]] = None, + proxy: Optional[str] = None, skip_auto_headers: Optional[Sequence[str]] = None, + **kwargs) -> _HippoSessionRequestContextManager: + if cap_or_url.startswith("http"): + if path: + raise ValueError("Specifying both path and a full URL not supported") + else: + if self._caps is None: + raise RuntimeError(f"Need a caps dict to request a Cap like {cap_or_url}") + if cap_or_url not in self._caps: + raise KeyError(f"{cap_or_url} is not a full URL and not a Cap") + cap_or_url = self._caps[cap_or_url] + # Stupid hack for proxy multidicts that contain a tuple of `(cap_type, cap_url)` + if isinstance(cap_or_url, tuple): + cap_or_url = cap_or_url[-1] + if path: + cap_or_url += path + + if params is not None: + for pname, pval in params.items(): + if not isinstance(pval, str): + params[pname] = str(pval) + + session_owned = False + # Use an existing session if we have one to take advantage of connection pooling + # otherwise create one + if session is None: + session_owned = True + session = aiohttp.ClientSession( + connector=aiohttp.TCPConnector(force_close=True), + connector_owner=True + ) + + if headers is None: + headers = {} + else: + headers = copy.copy(headers) + + # Use sentinel val so explicit `None` can be passed + if llsd is not dataclasses.MISSING: + data = llsd_lib.format_xml(llsd) + # Sometimes needed even on GETs. + if "Content-Type" not in headers: + headers["Content-Type"] = "application/llsd+xml" + # Always present, usually ignored by the server. + if "Accept" not in headers: + headers["Accept"] = "application/llsd+xml" + # Ask to keep the connection open if we're sharing a session + if not session_owned: + headers["Connection"] = "keep-alive" + headers["Keep-alive"] = "300" + + ssl = kwargs.pop('ssl', None) + cap_or_url, headers, proxy, ssl = self._request_fixups(cap_or_url, headers, proxy, ssl) + + resp = session._request(method, cap_or_url, data=data, headers=headers, # noqa: need internal call + params=params, ssl=ssl, proxy=proxy, + skip_auto_headers=skip_auto_headers or ("User-Agent",), **kwargs) + return _HippoSessionRequestContextManager(resp, session, session_owned=session_owned) + + def get(self, cap_or_url: str, *, path: str = "", headers: Optional[dict] = None, + session: Optional[aiohttp.ClientSession] = None, params: Optional[Dict[str, Any]] = None, + proxy: Optional[str] = None, **kwargs) -> _HippoSessionRequestContextManager: + return self.request("GET", cap_or_url=cap_or_url, path=path, headers=headers, + session=session, params=params, proxy=proxy, **kwargs) + + def post(self, cap_or_url: str, *, path: str = "", data: Any = None, + headers: Optional[dict] = None, session: Optional[aiohttp.ClientSession] = None, + llsd: Any = dataclasses.MISSING, params: Optional[Dict[str, Any]] = None, + proxy: Optional[str] = None, **kwargs) -> _HippoSessionRequestContextManager: + return self.request("POST", cap_or_url=cap_or_url, path=path, headers=headers, data=data, + llsd=llsd, session=session, params=params, proxy=proxy, **kwargs) + + def put(self, cap_or_url: str, *, path: str = "", data: Any = None, + headers: Optional[dict] = None, session: Optional[aiohttp.ClientSession] = None, + llsd: Any = dataclasses.MISSING, params: Optional[Dict[str, Any]] = None, + proxy: Optional[str] = None, **kwargs) -> _HippoSessionRequestContextManager: + return self.request("PUT", cap_or_url=cap_or_url, path=path, headers=headers, data=data, + llsd=llsd, session=session, params=params, proxy=proxy, **kwargs) diff --git a/hippolyzer/lib/proxy/caps_client.py b/hippolyzer/lib/proxy/caps_client.py index 5423f50..91fda36 100644 --- a/hippolyzer/lib/proxy/caps_client.py +++ b/hippolyzer/lib/proxy/caps_client.py @@ -1,147 +1,20 @@ from __future__ import annotations -import asyncio -import copy -import dataclasses import os import re import sys -from types import TracebackType from typing import * -import aiohttp - -from hippolyzer.lib.base import llsd as llsd_lib -from hippolyzer.lib.base.helpers import proxify - -if TYPE_CHECKING: - from hippolyzer.lib.proxy.region import ProxiedRegion +from hippolyzer.lib.base.network.caps_client import CapsClient -class CapsClientResponse(aiohttp.ClientResponse): - """ - Not actually instantiated, used for lying to the type system - since we'll dynamically put this onto a ClientResponse instance - Will fail isinstance(). - """ - async def read_llsd(self) -> Any: - raise NotImplementedError() - - -class _HippoSessionRequestContextManager: - """ - _SessionRequestContextManager but with a symmetrical API - - aiohttp.request() and aiohttp.ClientSession.request() have different APIs. - One is sync returning a context manager, one is async returning a coro. - aiohttp.request() also doesn't accept the arguments that we need for custom - SSL contexts. To deal with requests that have existing sessions and those without, - just give them both the same wrapper and don't close the session on context manager - exit if it wasn't our session. - """ - __slots__ = ("_coro", "_resp", "_session", "_session_owned") - - def __init__( - self, - coro: Coroutine[asyncio.Future[Any], None, aiohttp.ClientResponse], - session: aiohttp.ClientSession, - session_owned: bool = True, - ) -> None: - self._coro = coro - self._resp: Optional[aiohttp.ClientResponse] = None - self._session = session - self._session_owned = session_owned - - async def __aenter__(self) -> CapsClientResponse: - try: - self._resp = await self._coro - - # We don't control creation of the ClientResponse, so tack on - # a convenience method for reading LLSD. - async def _read_llsd(): - return llsd_lib.parse_xml(await self._resp.read()) - self._resp.read_llsd = _read_llsd - except BaseException: - if self._session_owned: - await self._session.close() - raise - else: - # intentionally fooling the type system - return self._resp # type: ignore - - async def __aexit__( - self, - exc_type: Optional[Type[BaseException]], - exc: Optional[BaseException], - tb: Optional[TracebackType], - ) -> None: - assert self._resp is not None - self._resp.close() - if self._session_owned: - await self._session.close() - - -class CapsClient: - def __init__(self, region: Optional[ProxiedRegion] = None): - self._region: Optional[ProxiedRegion] = proxify(region) - - def request(self, method: str, cap_or_url: str, *, path: str = "", data: Any = None, - headers: Optional[dict] = None, session: Optional[aiohttp.ClientSession] = None, - llsd: Any = dataclasses.MISSING, params: Optional[Dict[str, Any]] = None, - proxy: Optional[str] = None, skip_auto_headers: Optional[Sequence[str]] = None, - **kwargs) -> _HippoSessionRequestContextManager: - if cap_or_url.startswith("http"): - if path: - raise ValueError("Specifying both path and a full URL not supported") - else: - if self._region is None: - raise RuntimeError(f"Need a region to request a Cap like {cap_or_url}") - if cap_or_url not in self._region.caps: - raise KeyError(f"{cap_or_url} is not a full URL and not a Cap") - cap_or_url = self._region.caps[cap_or_url] - if path: - cap_or_url += path - - if params is not None: - for pname, pval in params.items(): - if not isinstance(pval, str): - params[pname] = str(pval) - - session_owned = False - # Use an existing session if we have one to take advantage of connection pooling - # otherwise create one - if session is None: - session_owned = True - session = aiohttp.ClientSession( - connector=aiohttp.TCPConnector(force_close=True), - connector_owner=True - ) - - if headers is None: - headers = {} - else: - headers = copy.copy(headers) - - # Use sentinel val so explicit `None` can be passed - if llsd is not dataclasses.MISSING: - data = llsd_lib.format_xml(llsd) - # Sometimes needed even on GETs. - if "Content-Type" not in headers: - headers["Content-Type"] = "application/llsd+xml" - # Always present, usually ignored by the server. - if "Accept" not in headers: - headers["Accept"] = "application/llsd+xml" - # Ask to keep the connection open if we're sharing a session - if not session_owned: - headers["Connection"] = "keep-alive" - headers["Keep-alive"] = "300" - # We go through the proxy by default, tack on a header letting mitmproxy know the - # request came from us so we can tag the request as injected. The header will be popped - # off before passing through to the server. - ssl = kwargs.pop('ssl', None) +class ProxyCapsClient(CapsClient): + def _request_fixups(self, cap_or_url: str, headers: Dict, proxy: Optional[bool], ssl: Any): # We want to proxy this through Hippolyzer if proxy is None: - # Always set this so we know this request was from the proxy + # We go through the proxy by default, tack on a header letting mitmproxy know the + # request came from us so we can tag the request as injected. The header will be popped + # off before passing through to the server. headers["X-Hippo-Injected"] = "1" # TODO: Have a setting for this proxy_port = int(os.environ.get("HIPPO_HTTP_PORT", 9062)) @@ -160,28 +33,4 @@ class CapsClient: if sys.platform == "win32" and cap_or_url.startswith("https:"): headers["X-Hippo-Windows-SSL-Hack"] = "1" cap_or_url = re.sub(r"^https:", "http:", cap_or_url) - - resp = session._request(method, cap_or_url, data=data, headers=headers, # noqa: need internal call - params=params, ssl=ssl, proxy=proxy, - skip_auto_headers=skip_auto_headers or ("User-Agent",), **kwargs) - return _HippoSessionRequestContextManager(resp, session, session_owned=session_owned) - - def get(self, cap_or_url: str, *, path: str = "", headers: Optional[dict] = None, - session: Optional[aiohttp.ClientSession] = None, params: Optional[Dict[str, Any]] = None, - proxy: Optional[str] = None, **kwargs) -> _HippoSessionRequestContextManager: - return self.request("GET", cap_or_url=cap_or_url, path=path, headers=headers, - session=session, params=params, proxy=proxy, **kwargs) - - def post(self, cap_or_url: str, *, path: str = "", data: Any = None, - headers: Optional[dict] = None, session: Optional[aiohttp.ClientSession] = None, - llsd: Any = dataclasses.MISSING, params: Optional[Dict[str, Any]] = None, - proxy: Optional[str] = None, **kwargs) -> _HippoSessionRequestContextManager: - return self.request("POST", cap_or_url=cap_or_url, path=path, headers=headers, data=data, - llsd=llsd, session=session, params=params, proxy=proxy, **kwargs) - - def put(self, cap_or_url: str, *, path: str = "", data: Any = None, - headers: Optional[dict] = None, session: Optional[aiohttp.ClientSession] = None, - llsd: Any = dataclasses.MISSING, params: Optional[Dict[str, Any]] = None, - proxy: Optional[str] = None, **kwargs) -> _HippoSessionRequestContextManager: - return self.request("PUT", cap_or_url=cap_or_url, path=path, headers=headers, data=data, - llsd=llsd, session=session, params=params, proxy=proxy, **kwargs) + return cap_or_url, headers, proxy, ssl diff --git a/hippolyzer/lib/proxy/region.py b/hippolyzer/lib/proxy/region.py index 818a442..c589754 100644 --- a/hippolyzer/lib/proxy/region.py +++ b/hippolyzer/lib/proxy/region.py @@ -11,9 +11,10 @@ import urllib.parse import multidict from hippolyzer.lib.base.datatypes import Vector3, UUID +from hippolyzer.lib.base.message.message import Message from hippolyzer.lib.base.message.message_handler import MessageHandler from hippolyzer.lib.base.objects import handle_to_global_pos -from hippolyzer.lib.proxy.caps_client import CapsClient +from hippolyzer.lib.proxy.caps_client import ProxyCapsClient from hippolyzer.lib.proxy.circuit import ProxiedCircuit from hippolyzer.lib.proxy.objects import ObjectManager from hippolyzer.lib.base.transfer_manager import TransferManager @@ -22,7 +23,6 @@ from hippolyzer.lib.base.xfer_manager import XferManager if TYPE_CHECKING: from hippolyzer.lib.proxy.sessions import Session from hippolyzer.lib.proxy.http_flow import HippoHTTPFlow - from hippolyzer.lib.base.message.message import Message class CapType(enum.Enum): @@ -60,7 +60,7 @@ class ProxiedRegion: self.message_handler: MessageHandler[Message] = MessageHandler() self.http_message_handler: MessageHandler[HippoHTTPFlow] = MessageHandler() self.eq_manager = EventQueueManager(self) - self.caps_client = CapsClient(self) + self.caps_client = ProxyCapsClient(self._caps) self.objects = ObjectManager(self, use_vo_cache=True) self._recalc_caps() diff --git a/tests/proxy/test_capsclient.py b/tests/proxy/test_capsclient.py index 5f3dda1..88493d3 100644 --- a/tests/proxy/test_capsclient.py +++ b/tests/proxy/test_capsclient.py @@ -2,7 +2,6 @@ import aiohttp import aioresponses from yarl import URL -from hippolyzer.lib.proxy.caps_client import CapsClient from hippolyzer.lib.proxy.region import ProxiedRegion from . import BaseProxyTest @@ -12,7 +11,7 @@ class TestCapsClient(BaseProxyTest): def setUp(self) -> None: super().setUp() self.region = ProxiedRegion(("127.0.0.1", 1), "", self.session) - self.caps_client = CapsClient(self.region) + self.caps_client = self.region.caps_client async def test_bare_url_works(self): with aioresponses.aioresponses() as m: