Files
Hippolyzer/hippolyzer/lib/base/network/caps_client.py
Salad Dais 49f7ba960f Move tons more things to lib.base and lib.client
Put an abstract session and region implementation in client so things
that could be logically shared between client/proxy can be.

ObjectManager moved to client with proxy-specific details in
ProxyObjectManager.
2021-06-04 09:31:54 +00:00

173 lines
7.0 KiB
Python

from __future__ import annotations
import asyncio
import copy
import dataclasses
from types import TracebackType
from typing import *
import aiohttp
import multidict
from hippolyzer.lib.base import llsd as llsd_lib
class CapsClientResponse(aiohttp.ClientResponse):
    """
    Typing-only stand-in for a response that carries ``read_llsd()``.

    This class is never instantiated. The ``read_llsd`` coroutine is attached
    dynamically to an ordinary ``ClientResponse`` instance, and this subclass
    only exists so the type checker knows about that extra method.
    Consequently ``isinstance(resp, CapsClientResponse)`` will be false for
    the objects that are actually handed out.
    """

    async def read_llsd(self) -> Any:
        """Placeholder signature only; the real implementation is attached at runtime."""
        raise NotImplementedError()
class _HippoSessionRequestContextManager:
    """
    _SessionRequestContextManager but with a symmetrical API

    aiohttp.request() and aiohttp.ClientSession.request() have different APIs.
    One is sync returning a context manager, one is async returning a coro.
    aiohttp.request() also doesn't accept the arguments that we need for custom
    SSL contexts. To deal with requests that have existing sessions and those without,
    just give them both the same wrapper and don't close the session on context manager
    exit if it wasn't our session.

    Entering the context awaits the pending request and yields the response
    with a ``read_llsd()`` coroutine attached. Leaving the context closes the
    response and, when the session is owned by this wrapper, the session too.
    """
    __slots__ = ("_coro", "_resp", "_session", "_session_owned")

    def __init__(
            self,
            coro: Coroutine[asyncio.Future[Any], None, aiohttp.ClientResponse],
            session: aiohttp.ClientSession,
            session_owned: bool = True,
    ) -> None:
        """
        :param coro: un-awaited request coroutine that resolves to the response
        :param session: session the request was issued on
        :param session_owned: whether this wrapper is responsible for closing ``session``
        """
        self._coro = coro
        self._resp: Optional[aiohttp.ClientResponse] = None
        self._session = session
        self._session_owned = session_owned

    async def __aenter__(self) -> CapsClientResponse:
        """Await the request and return the response, closing an owned session on failure."""
        try:
            self._resp = await self._coro

            # We don't control creation of the ClientResponse, so tack on
            # a convenience method for reading LLSD.
            async def _read_llsd():
                return llsd_lib.parse_xml(await self._resp.read())

            self._resp.read_llsd = _read_llsd
        except BaseException:
            # Includes cancellation: nobody will call __aexit__ for us, so an
            # owned session has to be cleaned up right here.
            if self._session_owned:
                await self._session.close()
            raise
        else:
            # intentionally fooling the type system
            return self._resp  # type: ignore

    async def __aexit__(
            self,
            exc_type: Optional[Type[BaseException]],
            exc: Optional[BaseException],
            tb: Optional[TracebackType],
    ) -> None:
        """Close the response, then the session if we own it."""
        try:
            assert self._resp is not None
            self._resp.close()
        finally:
            # Must run even if closing the response blew up, otherwise an
            # owned session would be left open.
            if self._session_owned:
                await self._session.close()
# Containers accepted as a caps table: looked up by cap name to obtain the URL to request.
CAPS_DICT = Union[Mapping[str, str], multidict.MultiDict[str]]
class CapsClient:
    """
    Small HTTP client that can address endpoints either by full URL or by Cap name.

    Cap names are resolved through the caps mapping returned by ``_get_caps()``.
    Every request method returns an async context manager that yields the
    response and takes care of closing it (and any session created here).
    """

    def __init__(self, caps: Optional[CAPS_DICT] = None):
        """
        :param caps: optional mapping of Cap name to URL used to resolve Cap names
        """
        self._caps = caps

    def _request_fixups(self, cap_or_url: str, headers: Dict, proxy: Optional[str], ssl: Any):
        """
        Hook for adjusting request details just before the request is issued.

        The base implementation returns its arguments unchanged.
        """
        return cap_or_url, headers, proxy, ssl

    def _get_caps(self) -> Optional[CAPS_DICT]:
        """Return the caps mapping used to resolve Cap names, if any."""
        return self._caps

    def request(self, method: str, cap_or_url: str, *, path: str = "", data: Any = None,
                headers: Optional[Dict] = None, session: Optional[aiohttp.ClientSession] = None,
                llsd: Any = dataclasses.MISSING, params: Optional[Dict[str, Any]] = None,
                proxy: Optional[str] = None, skip_auto_headers: Optional[Sequence[str]] = None,
                **kwargs) -> _HippoSessionRequestContextManager:
        """
        Start an HTTP request against a Cap or a full URL.

        :param method: HTTP method name
        :param cap_or_url: full URL (anything starting with "http") or a Cap name
        :param path: suffix appended to the resolved Cap URL; not allowed with a full URL
        :param data: raw request body; replaced by the serialized ``llsd`` if that is given
        :param headers: extra headers; copied, the caller's mapping is not modified
        :param session: existing session to reuse; a throwaway one is created when omitted
        :param llsd: value to serialize as LLSD XML for the body (explicit ``None`` is allowed)
        :param params: query parameters; non-string values are stringified on a copy
        :param proxy: proxy passed through to the session
        :param skip_auto_headers: auto headers to suppress, defaults to ``("User-Agent",)``
        :param kwargs: passed through to the session's request call (``ssl`` is extracted first)
        :return: async context manager yielding the response
        :raises ValueError: ``path`` was given together with a full URL
        :raises RuntimeError: a Cap name was given but no caps mapping is available
        :raises KeyError: the Cap name is not present in the caps mapping
        """
        if cap_or_url.startswith("http"):
            if path:
                raise ValueError("Specifying both path and a full URL not supported")
        else:
            caps = self._get_caps()
            if caps is None:
                raise RuntimeError(f"Need a caps dict to request a Cap like {cap_or_url}")
            if cap_or_url not in caps:
                raise KeyError(f"{cap_or_url} is not a full URL and not a Cap")
            cap_or_url = caps[cap_or_url]
            if path:
                cap_or_url += path

        if params is not None:
            # Stringify on a shallow copy so the caller's mapping is left untouched.
            str_params = copy.copy(params)
            for pname, pval in params.items():
                if not isinstance(pval, str):
                    str_params[pname] = str(pval)
            params = str_params

        # Use an existing session if we have one to take advantage of connection pooling
        # otherwise create one (deferred until right before the request, see below.)
        session_owned = session is None

        headers = {} if headers is None else copy.copy(headers)

        # Use sentinel val so explicit `None` can be passed
        if llsd is not dataclasses.MISSING:
            data = llsd_lib.format_xml(llsd)
            # NOTE(review): the header defaults below are treated as part of the
            # LLSD branch -- confirm this nesting against upstream.
            # Sometimes needed even on GETs.
            if "Content-Type" not in headers:
                headers["Content-Type"] = "application/llsd+xml"
            # Always present, usually ignored by the server.
            if "Accept" not in headers:
                headers["Accept"] = "application/llsd+xml"
        # Ask to keep the connection open if we're sharing a session
        if not session_owned:
            headers["Connection"] = "keep-alive"
            headers["Keep-alive"] = "300"

        ssl = kwargs.pop('ssl', None)
        cap_or_url, headers, proxy, ssl = self._request_fixups(cap_or_url, headers, proxy, ssl)

        if session is None:
            # Only create our own session once everything that might raise has
            # run, so a failure above can't leave an unclosed session behind.
            session = aiohttp.ClientSession(
                connector=aiohttp.TCPConnector(force_close=True),
                connector_owner=True
            )

        resp = session._request(method, cap_or_url, data=data, headers=headers,  # noqa: need internal call
                                params=params, ssl=ssl, proxy=proxy,
                                skip_auto_headers=skip_auto_headers or ("User-Agent",), **kwargs)
        return _HippoSessionRequestContextManager(resp, session, session_owned=session_owned)

    def get(self, cap_or_url: str, *, path: str = "", headers: Optional[dict] = None,
            session: Optional[aiohttp.ClientSession] = None, params: Optional[Dict[str, Any]] = None,
            proxy: Optional[str] = None, **kwargs) -> _HippoSessionRequestContextManager:
        """Shorthand for ``request("GET", ...)``."""
        return self.request("GET", cap_or_url=cap_or_url, path=path, headers=headers,
                            session=session, params=params, proxy=proxy, **kwargs)

    def post(self, cap_or_url: str, *, path: str = "", data: Any = None,
             headers: Optional[dict] = None, session: Optional[aiohttp.ClientSession] = None,
             llsd: Any = dataclasses.MISSING, params: Optional[Dict[str, Any]] = None,
             proxy: Optional[str] = None, **kwargs) -> _HippoSessionRequestContextManager:
        """Shorthand for ``request("POST", ...)``."""
        return self.request("POST", cap_or_url=cap_or_url, path=path, headers=headers, data=data,
                            llsd=llsd, session=session, params=params, proxy=proxy, **kwargs)

    def put(self, cap_or_url: str, *, path: str = "", data: Any = None,
            headers: Optional[dict] = None, session: Optional[aiohttp.ClientSession] = None,
            llsd: Any = dataclasses.MISSING, params: Optional[Dict[str, Any]] = None,
            proxy: Optional[str] = None, **kwargs) -> _HippoSessionRequestContextManager:
        """Shorthand for ``request("PUT", ...)``."""
        return self.request("PUT", cap_or_url=cap_or_url, path=path, headers=headers, data=data,
                            llsd=llsd, session=session, params=params, proxy=proxy, **kwargs)