diff --git a/hippolyzer/lib/base/serialization.py b/hippolyzer/lib/base/serialization.py index 659b864..83a1276 100644 --- a/hippolyzer/lib/base/serialization.py +++ b/hippolyzer/lib/base/serialization.py @@ -890,7 +890,23 @@ class TupleCoord(SerializableBase): return cls.COORD_CLS -class QuantizedTupleCoord(TupleCoord): +class EncodedTupleCoord(TupleCoord, abc.ABC): + _elem_specs: Sequence[SERIALIZABLE_TYPE] + + def serialize(self, vals, writer: BufferWriter, ctx): + vals = self._vals_to_tuple(vals) + for spec, val in zip(self._elem_specs, vals): + writer.write(spec, val, ctx=ctx) + + def deserialize(self, reader: Reader, ctx): + vals = (reader.read(spec, ctx=ctx) for spec in self._elem_specs) + val = self.COORD_CLS(*vals) + if self.need_pod(reader): + return tuple(val) + return val + + +class QuantizedTupleCoord(EncodedTupleCoord): def __init__(self, lower=None, upper=None, component_scales=None): super().__init__() if component_scales: @@ -906,17 +922,14 @@ class QuantizedTupleCoord(TupleCoord): ) assert len(self._elem_specs) == self.NUM_ELEMS - def serialize(self, vals, writer: BufferWriter, ctx): - vals = self._vals_to_tuple(vals) - for spec, val in zip(self._elem_specs, vals): - writer.write(spec, val, ctx=ctx) - def deserialize(self, reader: Reader, ctx): - vals = (reader.read(spec, ctx=ctx) for spec in self._elem_specs) - val = self.COORD_CLS(*vals) - if self.need_pod(reader): - return tuple(val) - return val +class FixedPointTupleCoord(EncodedTupleCoord): + def __init__(self, int_bits: int, frac_bits: int, signed: bool): + super().__init__() + self._elem_specs = tuple( + FixedPoint(self.ELEM_SPEC, int_bits, frac_bits, signed) + for _ in range(self.NUM_ELEMS) + ) class Vector3(TupleCoord): @@ -992,6 +1005,12 @@ class Vector4U8(QuantizedTupleCoord): COORD_CLS = dtypes.Vector4 +class FixedPointVector3U16(FixedPointTupleCoord): + ELEM_SPEC = U16 + NUM_ELEMS = 3 + COORD_CLS = dtypes.Vector3 + + class OptionalPrefixed(SerializableBase): """Field prefixed by a U8 indicating whether or not it's present""" OPTIONAL = True diff --git a/hippolyzer/lib/base/templates.py b/hippolyzer/lib/base/templates.py index 1fc08af..4fe1afe 100644 --- a/hippolyzer/lib/base/templates.py +++ b/hippolyzer/lib/base/templates.py @@ -1176,12 +1176,8 @@ PSYS_BLOCK_TEMPLATE = se.Template({ "BurstSpeedMin": se.FixedPoint(se.U16, 8, 8), "BurstSpeedMax": se.FixedPoint(se.U16, 8, 8), "BurstPartCount": se.U8, - "VelX": se.FixedPoint(se.U16, 8, 7, signed=True), - "VelY": se.FixedPoint(se.U16, 8, 7, signed=True), - "VelZ": se.FixedPoint(se.U16, 8, 7, signed=True), - "AccelX": se.FixedPoint(se.U16, 8, 7, signed=True), - "AccelY": se.FixedPoint(se.U16, 8, 7, signed=True), - "AccelZ": se.FixedPoint(se.U16, 8, 7, signed=True), + "Vel": se.FixedPointVector3U16(8, 7, signed=True), + "Accel": se.FixedPointVector3U16(8, 7, signed=True), "Texture": se.UUID, "Target": se.UUID, }) diff --git a/tests/base/test_serialization.py b/tests/base/test_serialization.py index a115e67..7697ef5 100644 --- a/tests/base/test_serialization.py +++ b/tests/base/test_serialization.py @@ -615,6 +615,16 @@ class QuantizedFloatSerializationTests(BaseSerializationTest): self.assertEqual(-2.0, reader.read(spec)) self.assertEqual(1.0, reader.read(spec)) + def test_fixed_point_tuplecoord(self): + expected_bytes = b"\xff\x80\x00\x00\x7f\x7f" + spec = se.FixedPointVector3U16(8, 7, signed=True) + self.writer.write_bytes(expected_bytes) + vec: Vector3 = self._get_reader().read(spec) + self._assert_coords_fuzzy_equals(tuple(vec), (255.0, -256.0, -1.0078)) + self.writer.clear() + self.writer.write(spec, vec) + self.assertEqual(expected_bytes, self.writer.copy_buffer()) + class NameValueSerializationTests(BaseSerializationTest): EXAMPLE_NAMEVALUES = b'DisplayName STRING RW DS unicodename\n' \