From 0f2e933be1cf53a3edb4f60735589fbd2068a8d9 Mon Sep 17 00:00:00 2001 From: Salad Dais Date: Wed, 20 Dec 2023 22:26:03 +0000 Subject: [PATCH] Make legacy input schema round-trip correctly --- hippolyzer/lib/base/inventory.py | 49 ++++++++++++++++++++-------- hippolyzer/lib/base/legacy_schema.py | 5 +-- hippolyzer/lib/base/serialization.py | 9 ++++- tests/base/test_legacy_schema.py | 3 ++ 4 files changed, 49 insertions(+), 17 deletions(-) diff --git a/hippolyzer/lib/base/inventory.py b/hippolyzer/lib/base/inventory.py index 9d26e2f..8dc6008 100644 --- a/hippolyzer/lib/base/inventory.py +++ b/hippolyzer/lib/base/inventory.py @@ -5,6 +5,7 @@ It's typically only used for object contents now. """ from __future__ import annotations +import abc import dataclasses import datetime as dt import logging @@ -111,9 +112,21 @@ class InventoryBase(SchemaBase): return cls._obj_from_dict(obj_dict) def to_writer(self, writer: StringIO): - writer.write(f"\t{self.SCHEMA_NAME}\t0\n") + writer.write(f"\t{self.SCHEMA_NAME}") + if self.SCHEMA_NAME == "permissions": + writer.write(" 0\n") + else: + writer.write("\t0\n") writer.write("\t{\n") - for field_name, field in self._get_fields_dict().items(): + + # Make sure the ID field always comes first, if there is one. + fields_dict = {} + if hasattr(self, "ID_ATTR"): + fields_dict = {getattr(self, "ID_ATTR"): None} + # update()ing will put all fields that aren't yet in the dict after the ID attr. + fields_dict.update(self._get_fields_dict()) + + for field_name, field in fields_dict.items(): spec = field.metadata.get("spec") # Not meant to be serialized if not spec: @@ -291,13 +304,20 @@ class InventorySaleInfo(InventoryBase): sale_price: int = schema_field(SchemaInt) -@dataclasses.dataclass -class InventoryNodeBase(InventoryBase): - ID_ATTR: ClassVar[str] - +class _HasName(abc.ABC): + """ + Only exists so that we can assert that all subclasses should have this without forcing + a particular serialization order, as would happen if this was present on InventoryNodeBase. + """ name: str + +@dataclasses.dataclass +class InventoryNodeBase(InventoryBase, _HasName): + ID_ATTR: ClassVar[str] + parent_id: Optional[UUID] = schema_field(SchemaUUID) + model: Optional[InventoryModel] = dataclasses.field( default=None, init=False, hash=False, compare=False, repr=False ) @@ -339,7 +359,6 @@ class InventoryNodeBase(InventoryBase): @dataclasses.dataclass class InventoryContainerBase(InventoryNodeBase): type: str = schema_field(SchemaStr) - name: str = schema_field(SchemaMultilineStr) @property def children(self) -> Sequence[InventoryNodeBase]: @@ -386,6 +405,7 @@ class InventoryObject(InventoryContainerBase): ID_ATTR: ClassVar[str] = "obj_id" obj_id: UUID = schema_field(SchemaUUID) + name: str = schema_field(SchemaMultilineStr) metadata: Optional[Dict[str, Any]] = schema_field(SchemaLLSD, default=None, include_none=True) __hash__ = InventoryNodeBase.__hash__ @@ -399,6 +419,7 @@ class InventoryCategory(InventoryContainerBase): cat_id: UUID = schema_field(SchemaUUID) pref_type: str = schema_field(SchemaStr, llsd_name="preferred_type") + name: str = schema_field(SchemaMultilineStr) owner_id: UUID = schema_field(SchemaUUID) version: int = schema_field(SchemaInt) metadata: Optional[Dict[str, Any]] = schema_field(SchemaLLSD, default=None, include_none=True) @@ -412,17 +433,17 @@ class InventoryItem(InventoryNodeBase): ID_ATTR: ClassVar[str] = "item_id" item_id: UUID = schema_field(SchemaUUID) - type: str = schema_field(SchemaStr) - inv_type: str = schema_field(SchemaStr) - flags: int = schema_field(SchemaFlagField) - name: str = schema_field(SchemaMultilineStr) - desc: str = schema_field(SchemaMultilineStr) - creation_date: dt.datetime = schema_field(SchemaDate, llsd_name="created_at") permissions: InventoryPermissions = schema_field(InventoryPermissions) - sale_info: InventorySaleInfo = schema_field(InventorySaleInfo) asset_id: Optional[UUID] = schema_field(SchemaUUID, default=None) shadow_id: Optional[UUID] = schema_field(SchemaUUID, default=None) + type: Optional[str] = schema_field(SchemaStr, default=None) + inv_type: Optional[str] = schema_field(SchemaStr, default=None) + flags: Optional[int] = schema_field(SchemaFlagField, default=None) + sale_info: Optional[InventorySaleInfo] = schema_field(InventorySaleInfo, default=None) + name: Optional[str] = schema_field(SchemaMultilineStr, default=None) + desc: Optional[str] = schema_field(SchemaMultilineStr, default=None) metadata: Optional[Dict[str, Any]] = schema_field(SchemaLLSD, default=None, include_none=True) + creation_date: Optional[dt.datetime] = schema_field(SchemaDate, llsd_name="created_at", default=None) __hash__ = InventoryNodeBase.__hash__ diff --git a/hippolyzer/lib/base/legacy_schema.py b/hippolyzer/lib/base/legacy_schema.py index 4249e6f..53c6392 100644 --- a/hippolyzer/lib/base/legacy_schema.py +++ b/hippolyzer/lib/base/legacy_schema.py @@ -116,11 +116,12 @@ class SchemaLLSD(SchemaFieldSerializer[_T]): """Arbitrary LLSD embedded in a field""" @classmethod def deserialize(cls, val: str) -> _T: - return llsd.parse_xml(val.encode("utf8")) + return llsd.parse_xml(val.partition("|")[0].encode("utf8")) @classmethod def serialize(cls, val: _T) -> str: - return llsd.format_xml(val).decode("utf8") + # Don't include the XML header + return llsd.format_xml(val).split(b">", 1)[1].decode("utf8") + "\n|" def schema_field(spec: Type[Union[SchemaBase, SchemaFieldSerializer]], *, default=dataclasses.MISSING, init=True, diff --git a/hippolyzer/lib/base/serialization.py b/hippolyzer/lib/base/serialization.py index 3ff3465..94c153d 100644 --- a/hippolyzer/lib/base/serialization.py +++ b/hippolyzer/lib/base/serialization.py @@ -1580,8 +1580,15 @@ def bitfield_field(bits: int, *, adapter: Optional[Adapter] = None, default=0, i class BitfieldDataclass(DataclassAdapter): - def __init__(self, data_cls: Type, + PRIM_SPEC: ClassVar[Optional[SerializablePrimitive]] = None + + def __init__(self, data_cls: Optional[Type] = None, prim_spec: Optional[SerializablePrimitive] = None, shift: bool = True): + if not dataclasses.is_dataclass(data_cls): + raise ValueError(f"{data_cls!r} is not a dataclass") + if prim_spec is None: + prim_spec = getattr(data_cls, 'PRIM_SPEC', None) + super().__init__(data_cls, prim_spec) self._shift = shift self._bitfield_spec = self._build_bitfield(data_cls) diff --git a/tests/base/test_legacy_schema.py b/tests/base/test_legacy_schema.py index ab643ec..6f01282 100644 --- a/tests/base/test_legacy_schema.py +++ b/tests/base/test_legacy_schema.py @@ -146,6 +146,9 @@ class TestLegacyInv(unittest.TestCase): new_model.root.name = "foo" self.assertNotEqual(self.model, new_model) + def test_legacy_serialization(self): + self.assertEqual(SIMPLE_INV, self.model.to_str()) + def test_difference_added(self): new_model = InventoryModel.from_llsd(self.model.to_llsd()) diff = self.model.get_differences(new_model)