From ecb14197cf94bb91962e856ccb00d3ced256cdb1 Mon Sep 17 00:00:00 2001 From: Salad Dais Date: Thu, 9 Dec 2021 01:14:09 +0000 Subject: [PATCH] Make message log filter highlight every matched field Previously only the first match was being highlighted. --- hippolyzer/apps/proxy_gui.py | 2 +- hippolyzer/lib/proxy/message_filter.py | 31 +++++++++++++++----------- hippolyzer/lib/proxy/message_logger.py | 29 ++++++++++++++---------- 3 files changed, 36 insertions(+), 26 deletions(-) diff --git a/hippolyzer/apps/proxy_gui.py b/hippolyzer/apps/proxy_gui.py index 79a6fbe..819ba3d 100644 --- a/hippolyzer/apps/proxy_gui.py +++ b/hippolyzer/apps/proxy_gui.py @@ -364,7 +364,7 @@ class MessageLogWindow(QtWidgets.QMainWindow): # The string has a map of fields and their associated positions within the string, # use that to highlight any individual fields the filter matched on. if isinstance(req, SpannedString): - for field in self.model.filter.match(entry).fields: + for field in self.model.filter.match(entry, short_circuit=False).fields: field_span = req.spans.get(field) if not field_span: continue diff --git a/hippolyzer/lib/proxy/message_filter.py b/hippolyzer/lib/proxy/message_filter.py index 28e05b3..b1e8067 100644 --- a/hippolyzer/lib/proxy/message_filter.py +++ b/hippolyzer/lib/proxy/message_filter.py @@ -79,7 +79,7 @@ class MatchResult(typing.NamedTuple): class BaseFilterNode(abc.ABC): @abc.abstractmethod - def match(self, msg) -> MatchResult: + def match(self, msg, short_circuit=True) -> MatchResult: raise NotImplementedError() @property @@ -109,28 +109,33 @@ class BinaryFilterNode(BaseFilterNode, abc.ABC): class UnaryNotFilterNode(UnaryFilterNode): - def match(self, msg) -> MatchResult: + def match(self, msg, short_circuit=True) -> MatchResult: # Should we pass fields up here? Maybe not. - return MatchResult(not self.node.match(msg), []) + return MatchResult(not self.node.match(msg, short_circuit), []) class OrFilterNode(BinaryFilterNode): - def match(self, msg) -> MatchResult: - left_match = self.left_node.match(msg) - if left_match: + def match(self, msg, short_circuit=True) -> MatchResult: + left_match = self.left_node.match(msg, short_circuit) + if left_match and short_circuit: return MatchResult(True, left_match.fields) - right_match = self.right_node.match(msg) - if right_match: + + right_match = self.right_node.match(msg, short_circuit) + if right_match and short_circuit: return MatchResult(True, right_match.fields) + + if left_match or right_match: + # Fine since fields should be empty when result=False + return MatchResult(True, left_match.fields + right_match.fields) return MatchResult(False, []) class AndFilterNode(BinaryFilterNode): - def match(self, msg) -> MatchResult: - left_match = self.left_node.match(msg) + def match(self, msg, short_circuit=True) -> MatchResult: + left_match = self.left_node.match(msg, short_circuit) if not left_match: return MatchResult(False, []) - right_match = self.right_node.match(msg) + right_match = self.right_node.match(msg, short_circuit) if not right_match: return MatchResult(False, []) return MatchResult(True, left_match.fields + right_match.fields) @@ -142,8 +147,8 @@ class MessageFilterNode(BaseFilterNode): self.operator = operator self.value = value - def match(self, msg) -> MatchResult: - return msg.matches(self) + def match(self, msg, short_circuit=True) -> MatchResult: + return msg.matches(self, short_circuit) @property def children(self): diff --git a/hippolyzer/lib/proxy/message_logger.py b/hippolyzer/lib/proxy/message_logger.py index 2e89798..babffc1 100644 --- a/hippolyzer/lib/proxy/message_logger.py +++ b/hippolyzer/lib/proxy/message_logger.py @@ -366,7 +366,7 @@ class AbstractMessageLogEntry(abc.ABC): return self._val_matches(matcher.operator, self._get_meta(matcher.selector[1]), matcher.value) return None - def matches(self, matcher: "MessageFilterNode") -> "MatchResult": + def matches(self, matcher: "MessageFilterNode", short_circuit=True) -> "MatchResult": return MatchResult(self._base_matches(matcher) or False, []) @property @@ -671,7 +671,7 @@ class LLUDPMessageLogEntry(AbstractMessageLogEntry): def request(self, beautify=False, replacements=None): return HumanMessageSerializer.to_human_string(self.message, replacements, beautify) - def matches(self, matcher) -> "MatchResult": + def matches(self, matcher, short_circuit=True) -> "MatchResult": base_matched = self._base_matches(matcher) if base_matched is not None: return MatchResult(base_matched, []) @@ -685,6 +685,7 @@ class LLUDPMessageLogEntry(AbstractMessageLogEntry): # name, block_name, var_name(, subfield_name)? if selector_len not in (3, 4): return MatchResult(False, []) + found_field_keys = [] for block_name in message.blocks: if not fnmatch.fnmatchcase(block_name, matcher.selector[1]): continue @@ -697,11 +698,9 @@ class LLUDPMessageLogEntry(AbstractMessageLogEntry): if selector_len == 3: # We're just matching on the var existing, not having any particular value if matcher.value is None: - # TODO: Ability to disable short-circuiting when matching for display - # purposes, it's helpful to see every match in the message. - return MatchResult(True, [field_key]) - if self._val_matches(matcher.operator, block[var_name], matcher.value): - return MatchResult(True, [field_key]) + found_field_keys.append(field_key) + elif self._val_matches(matcher.operator, block[var_name], matcher.value): + found_field_keys.append(field_key) # Need to invoke a special unpacker elif selector_len == 4: try: @@ -712,15 +711,21 @@ class LLUDPMessageLogEntry(AbstractMessageLogEntry): if isinstance(deserialized, TaggedUnion): deserialized = deserialized.value if not isinstance(deserialized, dict): - return MatchResult(False, []) + continue for key in deserialized.keys(): if fnmatch.fnmatchcase(str(key), matcher.selector[3]): if matcher.value is None: - return MatchResult(True, [field_key]) - if self._val_matches(matcher.operator, deserialized[key], matcher.value): - return MatchResult(True, [field_key]) + # Short-circuiting checking individual subfields is fine since + # we only highlight fields anyway. + found_field_keys.append(field_key) + break + elif self._val_matches(matcher.operator, deserialized[key], matcher.value): + found_field_keys.append(field_key) + break - return MatchResult(False, []) + if short_circuit and found_field_keys: + return MatchResult(True, found_field_keys) + return MatchResult(bool(found_field_keys), found_field_keys) @property def summary(self):