Allow filtering message log on HTTP headers
This commit is contained in:
@@ -3,7 +3,7 @@ import ast
|
||||
import typing
|
||||
|
||||
from arpeggio import Optional, ZeroOrMore, EOF, \
|
||||
ParserPython, PTNodeVisitor, visit_parse_tree, RegExMatch
|
||||
ParserPython, PTNodeVisitor, visit_parse_tree, RegExMatch, OneOrMore
|
||||
|
||||
|
||||
def literal():
|
||||
@@ -26,7 +26,9 @@ def literal():
|
||||
|
||||
|
||||
def identifier():
|
||||
return RegExMatch(r'[a-zA-Z*]([a-zA-Z0-9_*]+)?')
|
||||
# Identifiers are allowed to have "-". It's not a special character
|
||||
# in our grammar, and we expect them to show up in some places, like header names.
|
||||
return RegExMatch(r'[a-zA-Z*]([a-zA-Z0-9_*-]+)?')
|
||||
|
||||
|
||||
def field_specifier():
|
||||
@@ -42,7 +44,7 @@ def unary_expression():
|
||||
|
||||
|
||||
def meta_field_specifier():
|
||||
return "Meta", ".", identifier
|
||||
return "Meta", OneOrMore(".", identifier)
|
||||
|
||||
|
||||
def enum_field_specifier():
|
||||
@@ -155,7 +157,7 @@ class MessageFilterNode(BaseFilterNode):
|
||||
return self.selector, self.operator, self.value
|
||||
|
||||
|
||||
class MetaFieldSpecifier(str):
|
||||
class MetaFieldSpecifier(tuple):
|
||||
pass
|
||||
|
||||
|
||||
@@ -181,7 +183,7 @@ class MessageFilterVisitor(PTNodeVisitor):
|
||||
return LiteralValue(ast.literal_eval(node.value))
|
||||
|
||||
def visit_meta_field_specifier(self, _node, children):
|
||||
return MetaFieldSpecifier(children[0])
|
||||
return MetaFieldSpecifier(children)
|
||||
|
||||
def visit_enum_field_specifier(self, _node, children):
|
||||
return EnumFieldSpecifier(*children)
|
||||
|
||||
@@ -235,7 +235,7 @@ class AbstractMessageLogEntry(abc.ABC):
|
||||
obj = self.region.objects.lookup_localid(selected_local)
|
||||
return obj and obj.FullID
|
||||
|
||||
def _get_meta(self, name: str):
|
||||
def _get_meta(self, name: str) -> typing.Any:
|
||||
# Slight difference in semantics. Filters are meant to return the same
|
||||
# thing no matter when they're run, so SelectedLocal and friends resolve
|
||||
# to the selected items _at the time the message was logged_. To handle
|
||||
@@ -308,7 +308,9 @@ class AbstractMessageLogEntry(abc.ABC):
|
||||
|
||||
def _val_matches(self, operator, val, expected):
|
||||
if isinstance(expected, MetaFieldSpecifier):
|
||||
expected = self._get_meta(str(expected))
|
||||
if len(expected) != 1:
|
||||
raise ValueError(f"Can only support single-level Meta specifiers, not {expected!r}")
|
||||
expected = self._get_meta(str(expected[0]))
|
||||
if not isinstance(expected, (int, float, bytes, str, type(None), tuple)):
|
||||
if callable(expected):
|
||||
expected = expected()
|
||||
@@ -362,8 +364,14 @@ class AbstractMessageLogEntry(abc.ABC):
|
||||
if matcher.value or matcher.operator:
|
||||
return False
|
||||
return self._packet_root_matches(matcher.selector[0])
|
||||
if len(matcher.selector) == 2 and matcher.selector[0] == "Meta":
|
||||
return self._val_matches(matcher.operator, self._get_meta(matcher.selector[1]), matcher.value)
|
||||
if matcher.selector[0] == "Meta":
|
||||
if len(matcher.selector) == 2:
|
||||
return self._val_matches(matcher.operator, self._get_meta(matcher.selector[1]), matcher.value)
|
||||
elif len(matcher.selector) == 3:
|
||||
meta_dict = self._get_meta(matcher.selector[1])
|
||||
if not meta_dict or not hasattr(meta_dict, 'get'):
|
||||
return False
|
||||
return self._val_matches(matcher.operator, meta_dict.get(matcher.selector[2]), matcher.value)
|
||||
return None
|
||||
|
||||
def matches(self, matcher: "MessageFilterNode", short_circuit=True) -> "MatchResult":
|
||||
@@ -541,6 +549,18 @@ class HTTPMessageLogEntry(AbstractMessageLogEntry):
|
||||
return "application/xml"
|
||||
return content_type
|
||||
|
||||
def _get_meta(self, name: str) -> typing.Any:
|
||||
lower_name = name.lower()
|
||||
if lower_name == "url":
|
||||
return self.flow.request.url
|
||||
elif lower_name == "reqheaders":
|
||||
return self.flow.request.headers
|
||||
elif lower_name == "respheaders":
|
||||
return self.flow.response.headers
|
||||
elif lower_name == "host":
|
||||
return self.flow.request.host.lower()
|
||||
return super()._get_meta(name)
|
||||
|
||||
def to_dict(self):
|
||||
val = super().to_dict()
|
||||
val['flow'] = self.flow.get_state()
|
||||
|
||||
@@ -141,6 +141,16 @@ class MessageFilterTests(unittest.IsolatedAsyncioTestCase):
|
||||
self.assertTrue(self._filter_matches("FakeCap", entry))
|
||||
self.assertFalse(self._filter_matches("NotFakeCap", entry))
|
||||
|
||||
def test_http_header_filter(self):
    # Filters of the form Meta.ReqHeaders.<header> should match against the
    # value of that request header on an HTTP log entry.
    manager = SessionManager(ProxySettings())

    # Start from a canned flow, give its request a Cookie header, then rebuild
    # the flow from its serialized state before wrapping it in a log entry.
    template_flow = tflow.tflow(req=tutils.treq(), resp=tutils.tresp())
    template_flow.request.headers["Cookie"] = 'foo="bar"'
    rebuilt_flow = HippoHTTPFlow.from_state(template_flow.get_state(), manager)
    log_entry = HTTPMessageLogEntry(rebuilt_flow)

    # The header map is case-insensitive!
    present_header_filter = 'Meta.ReqHeaders.cookie ~= "foo"'
    self.assertTrue(self._filter_matches(present_header_filter, log_entry))

    # A header that was never set must not match.
    absent_header_filter = 'Meta.ReqHeaders.foobar ~= "foo"'
    self.assertFalse(self._filter_matches(absent_header_filter, log_entry))
|
||||
|
||||
def test_export_import_http_flow(self):
|
||||
fake_flow = tflow.tflow(req=tutils.treq(), resp=tutils.tresp())
|
||||
fake_flow.metadata["cap_data_ser"] = SerializedCapData(
|
||||
|
||||
Reference in New Issue
Block a user