From f317ad88f12a67dd920aae6146cf01f9e388c8ef Mon Sep 17 00:00:00 2001 From: Imene Kerboua Date: Sun, 23 Jun 2024 14:27:59 +0200 Subject: [PATCH] Format files using ruff --- .gitignore | 2 +- CHANGELOG.md | 2 +- LICENSE | 2 +- Makefile | 5 + README.md | 4 +- plume_python/__init__.py | 7 - plume_python/cli.py | 86 ++++++--- plume_python/export/xdf_exporter.py | 82 ++++++-- plume_python/export/xdf_writer.py | 124 +++++++++--- plume_python/file_reader.py | 2 +- plume_python/parser.py | 1 - plume_python/record.py | 2 - plume_python/utils/dataframe.py | 46 +++-- plume_python/utils/game_object.py | 13 +- plume_python/utils/transform.py | 113 +++++++---- poetry.lock | 179 +++++++++++++++++- pyproject.toml | 1 + tests/.gitignore | 2 +- tests/test_compute_world_transform.py | 13 +- tests/test_export_xdf.py | 4 +- ...st_find_game_object_identifiers_by_name.py | 5 +- tests/test_parser.py | 14 +- 22 files changed, 552 insertions(+), 157 deletions(-) diff --git a/.gitignore b/.gitignore index 424791f..28cefba 100644 --- a/.gitignore +++ b/.gitignore @@ -165,4 +165,4 @@ cython_debug/ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore # and can be added to the global gitignore or merged into this file. For a more nuclear # option (not recommended) you can uncomment the following to ignore the entire idea folder. -#.idea/ \ No newline at end of file +#.idea/ diff --git a/CHANGELOG.md b/CHANGELOG.md index 186b9de..ccd4c70 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -23,4 +23,4 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed -- Fixed a bug where extracting samples by time range would throw an exception if the record contained timeless samples. \ No newline at end of file +- Fixed a bug where extracting samples by time range would throw an exception if the record contained timeless samples. diff --git a/LICENSE b/LICENSE index e72bfdd..f288702 100644 --- a/LICENSE +++ b/LICENSE @@ -671,4 +671,4 @@ into proprietary programs. If your program is a subroutine library, you may consider it more useful to permit linking proprietary applications with the library. If this is what you want to do, use the GNU Lesser General Public License instead of this License. But first, please read -. \ No newline at end of file +. diff --git a/Makefile b/Makefile index 0f57f21..d3e4018 100644 --- a/Makefile +++ b/Makefile @@ -7,3 +7,8 @@ install: tests: @echo "--- 🧪 Running tests ---" poetry run pytest + +.PHONY: lint +lint: + @echo "--- 🧹 Linting code ---" + poetry run pre-commit run --all-files diff --git a/README.md b/README.md index ae35544..86494a0 100644 --- a/README.md +++ b/README.md @@ -53,7 +53,7 @@ For more advanced usage, the package can be imported in a Python script: import plume_python as plm from plume_python.utils.dataframe import samples_to_dataframe, record_to_dataframes from plume_python.samples.unity import transform_pb2 -from plume_python.export import xdf_exporter +from plume_python.export import xdf_exporter from plume_python.utils.game_object import find_names_by_guid, find_first_identifier_by_name # Load a record file @@ -136,4 +136,4 @@ Sophie VILLENAVE - sophie.villenave@ec-lyon.fr ``` [Button Docs]: https://img.shields.io/badge/Explore%20the%20docs-%E2%86%92-brightgreen -[Explore the docs]: https://liris-xr.github.io/PLUME/ \ No newline at end of file +[Explore the docs]: https://liris-xr.github.io/PLUME/ diff --git a/plume_python/__init__.py b/plume_python/__init__.py index d59f629..e69de29 100644 --- a/plume_python/__init__.py +++ b/plume_python/__init__.py @@ -1,7 +0,0 @@ -from . import file_reader -from . import parser -from . import record -from . import utils -from . import samples -from . import export -from . import cli diff --git a/plume_python/cli.py b/plume_python/cli.py index 869a7c5..e883d44 100644 --- a/plume_python/cli.py +++ b/plume_python/cli.py @@ -5,9 +5,16 @@ from plume_python import parser from plume_python.export.xdf_exporter import export_xdf_from_record from plume_python.samples import sample_types_from_names -from plume_python.utils.dataframe import record_to_dataframes, samples_to_dataframe, world_transforms_to_dataframe -from plume_python.utils.game_object import find_names_by_guid, find_identifiers_by_name, \ - find_identifier_by_game_object_id +from plume_python.utils.dataframe import ( + record_to_dataframes, + samples_to_dataframe, + world_transforms_to_dataframe, +) +from plume_python.utils.game_object import ( + find_names_by_guid, + find_identifiers_by_name, + find_identifier_by_game_object_id, +) from plume_python.utils.transform import compute_transform_time_series @@ -17,33 +24,37 @@ def cli(): @click.command() -@click.argument('record_path', type=click.Path(exists=True, readable=True)) -@click.option('--xdf_output_path', type=click.Path(writable=True)) +@click.argument("record_path", type=click.Path(exists=True, readable=True)) +@click.option("--xdf_output_path", type=click.Path(writable=True)) def export_xdf(record_path: str, xdf_output_path: str | None): """Export a XDF file including LSL samples and markers.""" - if not record_path.endswith('.plm'): + if not record_path.endswith(".plm"): click.echo(err=True, message="Input file must be a .plm file") return if xdf_output_path is None: - xdf_output_path = record_path.replace('.plm', '.xdf') + xdf_output_path = record_path.replace(".plm", ".xdf") if os.path.exists(xdf_output_path): - if not click.confirm(f"File '{xdf_output_path}' already exists, do you want to overwrite it?"): + if not click.confirm( + f"File '{xdf_output_path}' already exists, do you want to overwrite it?" + ): return with open(xdf_output_path, "wb") as xdf_output_file: record = parser.parse_record_from_file(record_path) export_xdf_from_record(xdf_output_file, record) - click.echo('Exported xdf from record: ' + record_path + ' to ' + xdf_output_path) + click.echo( + "Exported xdf from record: " + record_path + " to " + xdf_output_path + ) @click.command() -@click.argument('record_path', type=click.Path(exists=True, readable=True)) -@click.argument('guid', type=click.STRING) +@click.argument("record_path", type=click.Path(exists=True, readable=True)) +@click.argument("guid", type=click.STRING) def find_name(record_path: str, guid: str): """Find the name(s) of a GameObject with the given GUID in the record.""" - if not record_path.endswith('.plm'): + if not record_path.endswith(".plm"): click.echo(err=True, message="Input file must be a .plm file") return @@ -58,11 +69,11 @@ def find_name(record_path: str, guid: str): @click.command() -@click.argument('record_path', type=click.Path(exists=True, readable=True)) -@click.argument('name', type=click.STRING) +@click.argument("record_path", type=click.Path(exists=True, readable=True)) +@click.argument("name", type=click.STRING) def find_guid(record_path: str, name: str): """Find the GUID(s) of a GameObject by the given name.""" - if not record_path.endswith('.plm'): + if not record_path.endswith(".plm"): click.echo(err=True, message="Input file must be a .plm file") return @@ -77,11 +88,11 @@ def find_guid(record_path: str, name: str): @click.command() -@click.argument('record_path', type=click.Path(exists=True, readable=True)) -@click.argument('guid', type=click.STRING) +@click.argument("record_path", type=click.Path(exists=True, readable=True)) +@click.argument("guid", type=click.STRING) def export_world_transforms(record_path: str, guid: str): """Export world transforms of a GameObject with the given GUID to a CSV file.""" - if not record_path.endswith('.plm'): + if not record_path.endswith(".plm"): click.echo(err=True, message="Input file must be a .plm file") return @@ -94,38 +105,53 @@ def export_world_transforms(record_path: str, guid: str): time_series = compute_transform_time_series(record, identifier.transform_id) df = world_transforms_to_dataframe(time_series) - file_path = record_path.replace('.plm', f'_{guid}_world_transform.csv') + file_path = record_path.replace(".plm", f"_{guid}_world_transform.csv") df.to_csv(file_path) @click.command() -@click.argument('record_path', type=click.Path(exists=True, readable=True)) -@click.argument('output_dir', type=click.Path(exists=True, writable=True)) -@click.option('--filter', default="all", show_default=True, type=click.STRING, - help="Comma separated list of sample types to export (eg. 'TransformUpdate,GameObjectUpdate')") +@click.argument("record_path", type=click.Path(exists=True, readable=True)) +@click.argument("output_dir", type=click.Path(exists=True, writable=True)) +@click.option( + "--filter", + default="all", + show_default=True, + type=click.STRING, + help="Comma separated list of sample types to export (eg. 'TransformUpdate,GameObjectUpdate')", +) def export_csv(record_path: str, output_dir: str | None, filter: str): """Export samples from the record to CSV files.""" - if not record_path.endswith('.plm'): + if not record_path.endswith(".plm"): click.echo(err=True, message="Input file must be a .plm file") return record = parser.parse_record_from_file(record_path) - filters = [d.strip() for d in filter.split(',')] + filters = [d.strip() for d in filter.split(",")] - if filters == ['all'] or filters == ['*']: + if filters == ["all"] or filters == ["*"]: dataframes = record_to_dataframes(record) for sample_type, df in dataframes.items(): - file_path = os.path.join(output_dir, sample_type.__name__ + '.csv') + file_path = os.path.join(output_dir, sample_type.__name__ + ".csv") df.to_csv(file_path) - click.echo('Exported CSV for sample type: ' + sample_type.__name__ + ' to ' + file_path) + click.echo( + "Exported CSV for sample type: " + + sample_type.__name__ + + " to " + + file_path + ) else: sample_types = sample_types_from_names(filters) for sample_type in sample_types: df = samples_to_dataframe(record.get_samples_by_type(sample_type)) - file_path = os.path.join(output_dir, sample_type.__name__ + '.csv') + file_path = os.path.join(output_dir, sample_type.__name__ + ".csv") df.to_csv(file_path) - click.echo('Exported CSV for sample type: ' + sample_type.__name__ + ' to ' + file_path) + click.echo( + "Exported CSV for sample type: " + + sample_type.__name__ + + " to " + + file_path + ) cli.add_command(export_csv) diff --git a/plume_python/export/xdf_exporter.py b/plume_python/export/xdf_exporter.py index 7d36173..80d1ad9 100644 --- a/plume_python/export/xdf_exporter.py +++ b/plume_python/export/xdf_exporter.py @@ -1,12 +1,26 @@ -from plume_python.export.xdf_writer import * +from plume_python.export.xdf_writer import ( + STR_ENCODING, + write_file_header, + write_stream_header, + write_stream_sample, + write_stream_footer, +) from plume_python.record import Record from plume_python.samples.common import marker_pb2 from plume_python.samples.lsl import lsl_stream_pb2 from typing import BinaryIO +import xml.etree.ElementTree as ET +import numpy as np + def export_xdf_from_record(output_file: BinaryIO, record: Record): - datetime_str = record.get_metadata().start_time.ToDatetime().astimezone().strftime('%Y-%m-%dT%H:%M:%S%z') + datetime_str = ( + record.get_metadata() + .start_time.ToDatetime() + .astimezone() + .strftime("%Y-%m-%dT%H:%M:%S%z") + ) # Add a colon separator to the offset segment datetime_str = "{0}:{1}".format(datetime_str[:-2], datetime_str[-2:]) @@ -26,7 +40,9 @@ def export_xdf_from_record(output_file: BinaryIO, record: Record): for lsl_open_stream in record[lsl_stream_pb2.StreamOpen]: xml_header = ET.fromstring(lsl_open_stream.payload.xml_header) - stream_id = np.uint64(lsl_open_stream.payload.stream_id) + 1 # reserve id = 1 for the marker stream + stream_id = ( + np.uint64(lsl_open_stream.payload.stream_id) + 1 + ) # reserve id = 1 for the marker stream channel_format = xml_header.find("channel_format").text stream_channel_format[stream_id] = channel_format stream_min_time[stream_id] = None @@ -50,19 +66,33 @@ def export_xdf_from_record(output_file: BinaryIO, record: Record): stream_sample_count[stream_id] += 1 if channel_format == "string": - val = np.array([x for x in lsl_sample.payload.string_values.value], dtype=np.str_) + val = np.array( + [x for x in lsl_sample.payload.string_values.value], dtype=np.str_ + ) elif channel_format == "int8": - val = np.array([x for x in lsl_sample.payload.int8_values.value], dtype=np.int8) + val = np.array( + [x for x in lsl_sample.payload.int8_values.value], dtype=np.int8 + ) elif channel_format == "int16": - val = np.array([x for x in lsl_sample.payload.int16_values.value], dtype=np.int16) + val = np.array( + [x for x in lsl_sample.payload.int16_values.value], dtype=np.int16 + ) elif channel_format == "int32": - val = np.array([x for x in lsl_sample.payload.int32_values.value], dtype=np.int32) + val = np.array( + [x for x in lsl_sample.payload.int32_values.value], dtype=np.int32 + ) elif channel_format == "int64": - val = np.array([x for x in lsl_sample.payload.int64_values.value], dtype=np.int64) + val = np.array( + [x for x in lsl_sample.payload.int64_values.value], dtype=np.int64 + ) elif channel_format == "float32": - val = np.array([x for x in lsl_sample.payload.float_values.value], dtype=np.float32) + val = np.array( + [x for x in lsl_sample.payload.float_values.value], dtype=np.float32 + ) elif channel_format == "double64": - val = np.array([x for x in lsl_sample.payload.double_values.value], dtype=np.float64) + val = np.array( + [x for x in lsl_sample.payload.double_values.value], dtype=np.float64 + ) else: raise ValueError(f"Unsupported channel format: {channel_format}") @@ -71,9 +101,15 @@ def export_xdf_from_record(output_file: BinaryIO, record: Record): for marker_sample in record[marker_pb2.Marker]: t = marker_sample.timestamp / 1_000_000_000.0 # convert time to seconds - if stream_min_time[marker_stream_id] is None or t < stream_min_time[marker_stream_id]: + if ( + stream_min_time[marker_stream_id] is None + or t < stream_min_time[marker_stream_id] + ): stream_min_time[marker_stream_id] = t - if stream_max_time[marker_stream_id] is None or t > stream_max_time[marker_stream_id]: + if ( + stream_max_time[marker_stream_id] is None + or t > stream_max_time[marker_stream_id] + ): stream_max_time[marker_stream_id] = t if marker_stream_id not in stream_sample_count: @@ -87,13 +123,23 @@ def export_xdf_from_record(output_file: BinaryIO, record: Record): for lsl_close_stream in record[lsl_stream_pb2.StreamClose]: stream_id = np.uint64(lsl_close_stream.payload.stream_id) + 1 sample_count = stream_sample_count[stream_id] - write_stream_footer(output_file, stream_min_time[stream_id], stream_max_time[stream_id], sample_count, - stream_id) + write_stream_footer( + output_file, + stream_min_time[stream_id], + stream_max_time[stream_id], + sample_count, + stream_id, + ) # Write marker stream footer # stream_id = 1 is reserved for the marker stream - write_stream_footer(output_file, stream_min_time[marker_stream_id], stream_max_time[marker_stream_id], - stream_sample_count[marker_stream_id], marker_stream_id) + write_stream_footer( + output_file, + stream_min_time[marker_stream_id], + stream_max_time[marker_stream_id], + stream_sample_count[marker_stream_id], + marker_stream_id, + ) def write_marker_stream_header(output_buf, marker_stream_id): @@ -109,4 +155,6 @@ def write_marker_stream_header(output_buf, marker_stream_id): channel_count_el.text = "1" nominal_srate_el.text = "0.0" xml = ET.tostring(info_el, encoding=STR_ENCODING, xml_declaration=True) - write_stream_header(output_buf, xml, marker_stream_id) # stream_id = 1 is reserved for the marker stream + write_stream_header( + output_buf, xml, marker_stream_id + ) # stream_id = 1 is reserved for the marker stream diff --git a/plume_python/export/xdf_writer.py b/plume_python/export/xdf_writer.py index 6a1cabb..db5b855 100644 --- a/plume_python/export/xdf_writer.py +++ b/plume_python/export/xdf_writer.py @@ -18,7 +18,19 @@ int64=np.int64, ) -DataType = np.int8 | np.int16 | np.int32 | np.int64 | np.uint8 | np.uint16 | np.uint32 | np.uint64 | np.float32 | np.float64 | str +DataType = ( + np.int8 + | np.int16 + | np.int32 + | np.int64 + | np.uint8 + | np.uint16 + | np.uint32 + | np.uint64 + | np.float32 + | np.float64 + | str +) class ChunkTag(Enum): @@ -32,18 +44,19 @@ class ChunkTag(Enum): def write_file_header(output: BinaryIO, version: str, datetime: str): - output.write(b'XDF:') + output.write(b"XDF:") info_element = ET.Element("info") version_element = ET.SubElement(info_element, "version") datetime_element = ET.SubElement(info_element, "datetime") version_element.text = version datetime_element.text = datetime - xml_str = ET.tostring( - info_element, xml_declaration=True, encoding=STR_ENCODING) + xml_str = ET.tostring(info_element, xml_declaration=True, encoding=STR_ENCODING) write_chunk(output, ChunkTag.FILE_HEADER, xml_str) -def write_chunk(output: BinaryIO, chunk_tag: ChunkTag, content: bytes, stream_id: np.uint32 = None): +def write_chunk( + output: BinaryIO, chunk_tag: ChunkTag, content: bytes, stream_id: np.uint32 = None +): if not isinstance(content, bytes): raise Exception("Content should be bytes.") @@ -61,15 +74,27 @@ def write_chunk(output: BinaryIO, chunk_tag: ChunkTag, content: bytes, stream_id write(output, content) -def write_stream_header(output: BinaryIO, xml_header: str | bytes, stream_id: np.uint32 = None): +def write_stream_header( + output: BinaryIO, xml_header: str | bytes, stream_id: np.uint32 = None +): if isinstance(xml_header, str): xml_header = bytes(xml_header, encoding=STR_ENCODING) - write_chunk(output, ChunkTag.STREAM_HEADER, xml_header, None if stream_id is None else np.uint32(stream_id)) - - -def write_stream_footer(output: BinaryIO, first_timestamp: float, last_timestamp: float, - sample_count: int, stream_id: np.uint32 = None): + write_chunk( + output, + ChunkTag.STREAM_HEADER, + xml_header, + None if stream_id is None else np.uint32(stream_id), + ) + + +def write_stream_footer( + output: BinaryIO, + first_timestamp: float, + last_timestamp: float, + sample_count: int, + stream_id: np.uint32 = None, +): first_timestamp = np.float64(first_timestamp) last_timestamp = np.float64(last_timestamp) sample_count = np.uint64(sample_count) @@ -81,23 +106,42 @@ def write_stream_footer(output: BinaryIO, first_timestamp: float, last_timestamp last_timestamp_element.text = str(last_timestamp) sample_count_element.text = str(sample_count) - xml_str = ET.tostring( - info_element, xml_declaration=True, encoding=STR_ENCODING) - write_chunk(output, ChunkTag.STREAM_FOOTER, xml_str, None if stream_id is None else np.uint32(stream_id)) - - -def write_stream_sample(output: BinaryIO, sample: np.ndarray, timestamp: float, channel_format: str, - stream_id: np.uint32 = None): + xml_str = ET.tostring(info_element, xml_declaration=True, encoding=STR_ENCODING) + write_chunk( + output, + ChunkTag.STREAM_FOOTER, + xml_str, + None if stream_id is None else np.uint32(stream_id), + ) + + +def write_stream_sample( + output: BinaryIO, + sample: np.ndarray, + timestamp: float, + channel_format: str, + stream_id: np.uint32 = None, +): if channel_format not in formats: raise Exception("Unsupported channel format '{}'".format(channel_format)) fmt = formats[channel_format] - write_stream_sample_chunk(output, np.array([sample], dtype=fmt), [timestamp], - channel_format, None if stream_id is None else np.uint32(stream_id)) - - -def write_stream_sample_chunk(output: BinaryIO, chunk: np.ndarray, timestamps: list[float], channel_format: str, - stream_id: np.uint32 = None): + write_stream_sample_chunk( + output, + np.array([sample], dtype=fmt), + [timestamp], + channel_format, + None if stream_id is None else np.uint32(stream_id), + ) + + +def write_stream_sample_chunk( + output: BinaryIO, + chunk: np.ndarray, + timestamps: list[float], + channel_format: str, + stream_id: np.uint32 = None, +): if channel_format not in formats: raise Exception("Unsupported channel format '{}'".format(channel_format)) @@ -133,7 +177,12 @@ def write_stream_sample_chunk(output: BinaryIO, chunk: np.ndarray, timestamps: l else: raise Exception("Unsupported data type " + str(type(channel))) - write_chunk(output, ChunkTag.SAMPLES, tmp_output.getvalue(), None if stream_id is None else np.uint32(stream_id)) + write_chunk( + output, + ChunkTag.SAMPLES, + tmp_output.getvalue(), + None if stream_id is None else np.uint32(stream_id), + ) def write_timestamp(output: BinaryIO, timestamp: Optional[float] = None): @@ -159,9 +208,28 @@ def write_variable_length_integer(output: BinaryIO, val: np.uint64): write(output, np.uint64(val)) -def write_fixed_length_integer(output: BinaryIO, - val: np.int8 | np.int16 | np.int32 | np.int64 | np.uint8 | np.uint16 | np.uint32 | np.uint64): - if not isinstance(val, np.int8 | np.int16 | np.int32 | np.int64 | np.uint8 | np.uint16 | np.uint32 | np.uint64): +def write_fixed_length_integer( + output: BinaryIO, + val: np.int8 + | np.int16 + | np.int32 + | np.int64 + | np.uint8 + | np.uint16 + | np.uint32 + | np.uint64, +): + if not isinstance( + val, + np.int8 + | np.int16 + | np.int32 + | np.int64 + | np.uint8 + | np.uint16 + | np.uint32 + | np.uint64, + ): raise Exception("Unsupported data type " + str(type(val))) write(output, np.uint8(np.dtype(val).itemsize)) diff --git a/plume_python/file_reader.py b/plume_python/file_reader.py index c896755..70423f1 100644 --- a/plume_python/file_reader.py +++ b/plume_python/file_reader.py @@ -10,7 +10,7 @@ def _is_lz4_compressed(raw_bytes: bytes) -> bool: def read_file(filepath: str) -> BinaryIO: - if not filepath.endswith('.plm'): + if not filepath.endswith(".plm"): raise ValueError("File must be a .plm file") with open(filepath, "rb") as file: diff --git a/plume_python/parser.py b/plume_python/parser.py index acc414e..c4a50f3 100644 --- a/plume_python/parser.py +++ b/plume_python/parser.py @@ -41,7 +41,6 @@ def parse_record_from_stream(data_stream: BinaryIO) -> Record: last_timestamp: Optional[int] = None for packed_sample in tqdm(packed_samples, desc="Unpacking samples", unit="samples"): - unpacked_payload = unpack_any(packed_sample.payload, default_descriptor_pool) if unpacked_payload is None: continue diff --git a/plume_python/record.py b/plume_python/record.py index 7d12064..a1f49c0 100644 --- a/plume_python/record.py +++ b/plume_python/record.py @@ -4,7 +4,6 @@ from google.protobuf.message import Message -from .samples import record_pb2 T = TypeVar("T", bound=Message) @@ -65,7 +64,6 @@ def get_samples_in_time_range( samples_in_time_range = {} for payload_type, samples in self.samples_by_type.items(): - samples = [ sample for sample in samples diff --git a/plume_python/utils/dataframe.py b/plume_python/utils/dataframe.py index e08c4f5..e7736f1 100644 --- a/plume_python/utils/dataframe.py +++ b/plume_python/utils/dataframe.py @@ -8,10 +8,12 @@ from plume_python.record import Sample, FrameDataSample, Record -T = TypeVar('T', bound=Message) +T = TypeVar("T", bound=Message) -def world_transforms_to_dataframe(world_transforms: list[TimestampedTransform]) -> pd.DataFrame: +def world_transforms_to_dataframe( + world_transforms: list[TimestampedTransform], +) -> pd.DataFrame: if len(world_transforms) == 0: return pd.DataFrame() @@ -21,18 +23,21 @@ def world_transforms_to_dataframe(world_transforms: list[TimestampedTransform]) world_position = world_transform.get_world_position() world_rotation = world_transform.get_world_rotation() world_scale = world_transform.get_world_scale() - world_transform_data.append({"timestamp": world_transform.timestamp, - "position_x": world_position[0], - "position_y": world_position[1], - "position_z": world_position[2], - "rotation_x": world_rotation.x, - "rotation_y": world_rotation.y, - "rotation_z": world_rotation.z, - "rotation_w": world_rotation.w, - "scale_x": world_scale[0], - "scale_y": world_scale[1], - "scale_z": world_scale[2] - }) + world_transform_data.append( + { + "timestamp": world_transform.timestamp, + "position_x": world_position[0], + "position_y": world_position[1], + "position_z": world_position[2], + "rotation_x": world_rotation.x, + "rotation_y": world_rotation.y, + "rotation_z": world_rotation.z, + "rotation_w": world_rotation.w, + "scale_x": world_scale[0], + "scale_y": world_scale[1], + "scale_z": world_scale[2], + } + ) return pd.json_normalize(world_transform_data) @@ -47,13 +52,20 @@ def samples_to_dataframe(samples: list[Sample[T]]) -> pd.DataFrame: frame_samples = cast(list[FrameDataSample[T]], samples) for frame_sample in frame_samples: sample_payload_fields_value = MessageToDict(frame_sample.payload, True) - sample_data.append({"timestamp": frame_sample.timestamp, - "frame_number": frame_sample.frame_number} | sample_payload_fields_value) + sample_data.append( + { + "timestamp": frame_sample.timestamp, + "frame_number": frame_sample.frame_number, + } + | sample_payload_fields_value + ) else: for sample in samples: sample_payload_fields_value = MessageToDict(sample.payload, True) if sample.is_timestamped(): - sample_data.append({"timestamp": sample.timestamp} | sample_payload_fields_value) + sample_data.append( + {"timestamp": sample.timestamp} | sample_payload_fields_value + ) else: sample_data.append(sample_payload_fields_value) diff --git a/plume_python/utils/game_object.py b/plume_python/utils/game_object.py index a0fa438..015a1a0 100644 --- a/plume_python/utils/game_object.py +++ b/plume_python/utils/game_object.py @@ -24,7 +24,9 @@ def find_first_name_by_guid(record: Record, guid: str) -> Optional[str]: return None -def find_identifier_by_game_object_id(record: Record, game_object_id: str) -> Optional[GameObjectIdentifier]: +def find_identifier_by_game_object_id( + record: Record, game_object_id: str +) -> Optional[GameObjectIdentifier]: for go_update_sample in record[GameObjectUpdate]: go_update = go_update_sample.payload if go_update.id.game_object_id == game_object_id: @@ -39,13 +41,18 @@ def find_identifiers_by_name(record: Record, name: str) -> list[GameObjectIdenti for go_update_sample in record[GameObjectUpdate]: go_update = go_update_sample.payload if go_update.HasField("name"): - if name == go_update.name and go_update.id.game_object_id not in known_guids: + if ( + name == go_update.name + and go_update.id.game_object_id not in known_guids + ): identifiers.append(go_update.id) known_guids.add(go_update.id.game_object_id) return identifiers -def find_first_identifier_by_name(record: Record, name: str) -> Optional[GameObjectIdentifier]: +def find_first_identifier_by_name( + record: Record, name: str +) -> Optional[GameObjectIdentifier]: for go_update_sample in record[GameObjectUpdate]: go_update = go_update_sample.payload if go_update.HasField("name"): diff --git a/plume_python/utils/transform.py b/plume_python/utils/transform.py index 6429654..8a8481e 100644 --- a/plume_python/utils/transform.py +++ b/plume_python/utils/transform.py @@ -15,12 +15,24 @@ @dataclass(slots=True) class Transform: _guid: str - _local_position: np.ndarray = field(default_factory=lambda: np.array(3, dtype=np.float32)) - _local_rotation: quaternion.quaternion = field(default_factory=lambda: quaternion.quaternion(1, 0, 0, 0)) - _local_scale: np.ndarray = field(default_factory=lambda: np.array(4, dtype=np.float32)) - _local_T_mtx: np.ndarray = field(default_factory=lambda: np.eye(4, dtype=np.float32)) - _local_R_mtx: np.ndarray = field(default_factory=lambda: np.eye(4, dtype=np.float32)) - _local_S_mtx: np.ndarray = field(default_factory=lambda: np.eye(4, dtype=np.float32)) + _local_position: np.ndarray = field( + default_factory=lambda: np.array(3, dtype=np.float32) + ) + _local_rotation: quaternion.quaternion = field( + default_factory=lambda: quaternion.quaternion(1, 0, 0, 0) + ) + _local_scale: np.ndarray = field( + default_factory=lambda: np.array(4, dtype=np.float32) + ) + _local_T_mtx: np.ndarray = field( + default_factory=lambda: np.eye(4, dtype=np.float32) + ) + _local_R_mtx: np.ndarray = field( + default_factory=lambda: np.eye(4, dtype=np.float32) + ) + _local_S_mtx: np.ndarray = field( + default_factory=lambda: np.eye(4, dtype=np.float32) + ) _local_to_world_mtx: np.ndarray = None _parent: Optional[Transform] = None _dirty: bool = True @@ -66,13 +78,14 @@ def get_parent(self) -> Optional[Transform]: def get_local_to_world_matrix(self) -> np.ndarray: if self._is_dirty() or self._local_to_world_mtx is None: - trs_mtx = self._local_T_mtx @ self._local_R_mtx @ self._local_S_mtx if self._parent is None: self._local_to_world_mtx = trs_mtx else: - self._local_to_world_mtx = self._parent.get_local_to_world_matrix() @ trs_mtx + self._local_to_world_mtx = ( + self._parent.get_local_to_world_matrix() @ trs_mtx + ) self._dirty = False @@ -82,7 +95,9 @@ def get_world_position(self) -> np.ndarray: return self.get_local_to_world_matrix()[0:3, 3].transpose() def get_world_rotation(self) -> quaternion: - return quaternion.from_rotation_matrix(self.get_local_to_world_matrix()[0:3, 0:3]) + return quaternion.from_rotation_matrix( + self.get_local_to_world_matrix()[0:3, 0:3] + ) def get_world_scale(self) -> np.ndarray: local_to_world_mtx = self.get_local_to_world_matrix() @@ -118,27 +133,40 @@ def get_world_scale(self) -> np.ndarray: return scale -def compute_transform_time_series(record: Record, guid: str) -> list[TimestampedTransform]: +def compute_transform_time_series( + record: Record, guid: str +) -> list[TimestampedTransform]: transform_time_series = compute_transforms_time_series(record, {guid}) return transform_time_series.get(guid, {}) -def compute_transforms_time_series(record: Record, included_guids: set[str] = None) \ - -> dict[str, list[TimestampedTransform]]: +def compute_transforms_time_series( + record: Record, included_guids: set[str] = None +) -> dict[str, list[TimestampedTransform]]: result: dict[str, list[TimestampedTransform]] = {} current_transforms: dict[str, Transform] = {} - creation_samples: dict[int, list[FrameDataSample[transform_pb2.TransformCreate]]] = {} - destruction_samples: dict[int, list[FrameDataSample[transform_pb2.TransformDestroy]]] = {} + creation_samples: dict[ + int, list[FrameDataSample[transform_pb2.TransformCreate]] + ] = {} + destruction_samples: dict[ + int, list[FrameDataSample[transform_pb2.TransformDestroy]] + ] = {} update_samples: dict[int, list[FrameDataSample[transform_pb2.TransformUpdate]]] = {} - for frame_number, s in groupby(record[transform_pb2.TransformCreate], lambda x: x.frame_number): + for frame_number, s in groupby( + record[transform_pb2.TransformCreate], lambda x: x.frame_number + ): creation_samples[frame_number] = list(s) - for frame_number, s in groupby(record[transform_pb2.TransformDestroy], lambda x: x.frame_number): + for frame_number, s in groupby( + record[transform_pb2.TransformDestroy], lambda x: x.frame_number + ): destruction_samples[frame_number] = list(s) - for frame_number, s in groupby(record[transform_pb2.TransformUpdate], lambda x: x.frame_number): + for frame_number, s in groupby( + record[transform_pb2.TransformUpdate], lambda x: x.frame_number + ): update_samples[frame_number] = list(s) for frame in tqdm(record.frames_info, desc="Computing world positions"): @@ -161,17 +189,26 @@ def compute_transforms_time_series(record: Record, included_guids: set[str] = No for update_sample in update_samples[frame.frame_number]: guid = update_sample.payload.id.component_id local_transform = current_transforms[guid] - if update_sample.payload.HasField('local_position'): + if update_sample.payload.HasField("local_position"): local_position = update_sample.payload.local_position - local_transform.set_local_position(np.array([local_position.x, local_position.y, local_position.z])) - if update_sample.payload.HasField('local_rotation'): + local_transform.set_local_position( + np.array([local_position.x, local_position.y, local_position.z]) + ) + if update_sample.payload.HasField("local_rotation"): local_rotation = update_sample.payload.local_rotation - q = quaternion.quaternion(local_rotation.w, local_rotation.x, local_rotation.y, local_rotation.z) + q = quaternion.quaternion( + local_rotation.w, + local_rotation.x, + local_rotation.y, + local_rotation.z, + ) local_transform.set_local_rotation(q) - if update_sample.payload.HasField('local_scale'): + if update_sample.payload.HasField("local_scale"): local_scale = update_sample.payload.local_scale - local_transform.set_local_scale(np.array([local_scale.x, local_scale.y, local_scale.z])) - if update_sample.payload.HasField('parent_transform_id'): + local_transform.set_local_scale( + np.array([local_scale.x, local_scale.y, local_scale.z]) + ) + if update_sample.payload.HasField("parent_transform_id"): parent_guid = update_sample.payload.parent_transform_id.component_id if parent_guid == "00000000000000000000000000000000": # null guid local_transform.set_parent(None) @@ -185,17 +222,25 @@ def compute_transforms_time_series(record: Record, included_guids: set[str] = No if included_guids is None: included_transforms = current_transforms.values() else: - included_transforms = [current_transforms[guid] for guid in included_guids if guid in current_transforms] + included_transforms = [ + current_transforms[guid] + for guid in included_guids + if guid in current_transforms + ] for t in included_transforms: - timestamped_transform = TimestampedTransform(timestamp=frame.timestamp, - frame_number=frame.frame_number, - guid=t.get_guid(), - parent_guid=t.get_parent().get_guid(), - local_scale=t.get_local_scale(), - local_position=t.get_local_position(), - local_rotation=t.get_local_rotation(), - local_to_world_mtx=t.get_local_to_world_matrix()) - result.setdefault(t.get_guid(), list[TimestampedTransform]()).append(timestamped_transform) + timestamped_transform = TimestampedTransform( + timestamp=frame.timestamp, + frame_number=frame.frame_number, + guid=t.get_guid(), + parent_guid=t.get_parent().get_guid(), + local_scale=t.get_local_scale(), + local_position=t.get_local_position(), + local_rotation=t.get_local_rotation(), + local_to_world_mtx=t.get_local_to_world_matrix(), + ) + result.setdefault(t.get_guid(), list[TimestampedTransform]()).append( + timestamped_transform + ) return result diff --git a/poetry.lock b/poetry.lock index 10a2ebd..3a67d48 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,5 +1,16 @@ # This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. +[[package]] +name = "cfgv" +version = "3.4.0" +description = "Validate configuration and produce human readable error messages." +optional = false +python-versions = ">=3.8" +files = [ + {file = "cfgv-3.4.0-py2.py3-none-any.whl", hash = "sha256:b7265b1f29fd3316bfcd2b330d63d024f2bfd8bcb8b0272f8e19a504856c48f9"}, + {file = "cfgv-3.4.0.tar.gz", hash = "sha256:e52591d4c5f5dead8e0f673fb16db7949d2cfb3f7da4582893288f0ded8fe560"}, +] + [[package]] name = "click" version = "8.1.7" @@ -39,6 +50,17 @@ files = [ [package.dependencies] protobuf = "*" +[[package]] +name = "distlib" +version = "0.3.8" +description = "Distribution utilities" +optional = false +python-versions = "*" +files = [ + {file = "distlib-0.3.8-py2.py3-none-any.whl", hash = "sha256:034db59a0b96f8ca18035f36290806a9a6e6bd9d1ff91e45a7f172eb17e51784"}, + {file = "distlib-0.3.8.tar.gz", hash = "sha256:1530ea13e350031b6312d8580ddb6b27a104275a31106523b8f123787f494f64"}, +] + [[package]] name = "exceptiongroup" version = "1.2.1" @@ -53,6 +75,36 @@ files = [ [package.extras] test = ["pytest (>=6)"] +[[package]] +name = "filelock" +version = "3.15.4" +description = "A platform independent file lock." +optional = false +python-versions = ">=3.8" +files = [ + {file = "filelock-3.15.4-py3-none-any.whl", hash = "sha256:6ca1fffae96225dab4c6eaf1c4f4f28cd2568d3ec2a44e15a08520504de468e7"}, + {file = "filelock-3.15.4.tar.gz", hash = "sha256:2207938cbc1844345cb01a5a95524dae30f0ce089eba5b00378295a17e3e90cb"}, +] + +[package.extras] +docs = ["furo (>=2023.9.10)", "sphinx (>=7.2.6)", "sphinx-autodoc-typehints (>=1.25.2)"] +testing = ["covdefaults (>=2.3)", "coverage (>=7.3.2)", "diff-cover (>=8.0.1)", "pytest (>=7.4.3)", "pytest-asyncio (>=0.21)", "pytest-cov (>=4.1)", "pytest-mock (>=3.12)", "pytest-timeout (>=2.2)", "virtualenv (>=20.26.2)"] +typing = ["typing-extensions (>=4.8)"] + +[[package]] +name = "identify" +version = "2.5.36" +description = "File identification library for Python" +optional = false +python-versions = ">=3.8" +files = [ + {file = "identify-2.5.36-py2.py3-none-any.whl", hash = "sha256:37d93f380f4de590500d9dba7db359d0d3da95ffe7f9de1753faa159e71e7dfa"}, + {file = "identify-2.5.36.tar.gz", hash = "sha256:e5e00f54165f9047fbebeb4a560f9acfb8af4c88232be60a488e9b68d122745d"}, +] + +[package.extras] +license = ["ukkonen"] + [[package]] name = "iniconfig" version = "2.0.0" @@ -114,6 +166,17 @@ docs = ["sphinx (>=1.6.0)", "sphinx-bootstrap-theme"] flake8 = ["flake8"] tests = ["psutil", "pytest (!=3.3.0)", "pytest-cov"] +[[package]] +name = "nodeenv" +version = "1.9.1" +description = "Node.js virtual environment builder" +optional = false +python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7" +files = [ + {file = "nodeenv-1.9.1-py2.py3-none-any.whl", hash = "sha256:ba11c9782d29c27c70ffbdda2d7415098754709be8a7056d79a737cd901155c9"}, + {file = "nodeenv-1.9.1.tar.gz", hash = "sha256:6ec12890a2dab7946721edbfbcd91f3319c6ccc9aec47be7c7e6b7011ee6645f"}, +] + [[package]] name = "numpy" version = "1.26.4" @@ -310,6 +373,22 @@ sql-other = ["SQLAlchemy (>=2.0.0)", "adbc-driver-postgresql (>=0.8.0)", "adbc-d test = ["hypothesis (>=6.46.1)", "pytest (>=7.3.2)", "pytest-xdist (>=2.2.0)"] xml = ["lxml (>=4.9.2)"] +[[package]] +name = "platformdirs" +version = "4.2.2" +description = "A small Python package for determining appropriate platform-specific dirs, e.g. a `user data dir`." +optional = false +python-versions = ">=3.8" +files = [ + {file = "platformdirs-4.2.2-py3-none-any.whl", hash = "sha256:2d7a1657e36a80ea911db832a8a6ece5ee53d8de21edd5cc5879af6530b1bfee"}, + {file = "platformdirs-4.2.2.tar.gz", hash = "sha256:38b7b51f512eed9e84a22788b4bce1de17c0adb134d6becb09836e37d8654cd3"}, +] + +[package.extras] +docs = ["furo (>=2023.9.10)", "proselint (>=0.13)", "sphinx (>=7.2.6)", "sphinx-autodoc-typehints (>=1.25.2)"] +test = ["appdirs (==1.4.4)", "covdefaults (>=2.3)", "pytest (>=7.4.3)", "pytest-cov (>=4.1)", "pytest-mock (>=3.12)"] +type = ["mypy (>=1.8)"] + [[package]] name = "pluggy" version = "1.5.0" @@ -325,6 +404,24 @@ files = [ dev = ["pre-commit", "tox"] testing = ["pytest", "pytest-benchmark"] +[[package]] +name = "pre-commit" +version = "3.7.1" +description = "A framework for managing and maintaining multi-language pre-commit hooks." +optional = false +python-versions = ">=3.9" +files = [ + {file = "pre_commit-3.7.1-py2.py3-none-any.whl", hash = "sha256:fae36fd1d7ad7d6a5a1c0b0d5adb2ed1a3bda5a21bf6c3e5372073d7a11cd4c5"}, + {file = "pre_commit-3.7.1.tar.gz", hash = "sha256:8ca3ad567bc78a4972a3f1a477e94a79d4597e8140a6e0b651c5e33899c3654a"}, +] + +[package.dependencies] +cfgv = ">=2.0.0" +identify = ">=1.0.0" +nodeenv = ">=0.11.1" +pyyaml = ">=5.1" +virtualenv = ">=20.10.0" + [[package]] name = "protobuf" version = "5.27.1" @@ -392,6 +489,66 @@ files = [ {file = "pytz-2024.1.tar.gz", hash = "sha256:2a29735ea9c18baf14b448846bde5a48030ed267578472d8955cd0e7443a9812"}, ] +[[package]] +name = "pyyaml" +version = "6.0.1" +description = "YAML parser and emitter for Python" +optional = false +python-versions = ">=3.6" +files = [ + {file = "PyYAML-6.0.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:d858aa552c999bc8a8d57426ed01e40bef403cd8ccdd0fc5f6f04a00414cac2a"}, + {file = "PyYAML-6.0.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:fd66fc5d0da6d9815ba2cebeb4205f95818ff4b79c3ebe268e75d961704af52f"}, + {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:69b023b2b4daa7548bcfbd4aa3da05b3a74b772db9e23b982788168117739938"}, + {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:81e0b275a9ecc9c0c0c07b4b90ba548307583c125f54d5b6946cfee6360c733d"}, + {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ba336e390cd8e4d1739f42dfe9bb83a3cc2e80f567d8805e11b46f4a943f5515"}, + {file = "PyYAML-6.0.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:326c013efe8048858a6d312ddd31d56e468118ad4cdeda36c719bf5bb6192290"}, + {file = "PyYAML-6.0.1-cp310-cp310-win32.whl", hash = "sha256:bd4af7373a854424dabd882decdc5579653d7868b8fb26dc7d0e99f823aa5924"}, + {file = "PyYAML-6.0.1-cp310-cp310-win_amd64.whl", hash = "sha256:fd1592b3fdf65fff2ad0004b5e363300ef59ced41c2e6b3a99d4089fa8c5435d"}, + {file = "PyYAML-6.0.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6965a7bc3cf88e5a1c3bd2e0b5c22f8d677dc88a455344035f03399034eb3007"}, + {file = "PyYAML-6.0.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:f003ed9ad21d6a4713f0a9b5a7a0a79e08dd0f221aff4525a2be4c346ee60aab"}, + {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:42f8152b8dbc4fe7d96729ec2b99c7097d656dc1213a3229ca5383f973a5ed6d"}, + {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:062582fca9fabdd2c8b54a3ef1c978d786e0f6b3a1510e0ac93ef59e0ddae2bc"}, + {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d2b04aac4d386b172d5b9692e2d2da8de7bfb6c387fa4f801fbf6fb2e6ba4673"}, + {file = "PyYAML-6.0.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:e7d73685e87afe9f3b36c799222440d6cf362062f78be1013661b00c5c6f678b"}, + {file = "PyYAML-6.0.1-cp311-cp311-win32.whl", hash = "sha256:1635fd110e8d85d55237ab316b5b011de701ea0f29d07611174a1b42f1444741"}, + {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, + {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"}, + {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"}, + {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"}, + {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"}, + {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"}, + {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"}, + {file = "PyYAML-6.0.1-cp312-cp312-win_amd64.whl", hash = "sha256:0d3304d8c0adc42be59c5f8a4d9e3d7379e6955ad754aa9d6ab7a398b59dd1df"}, + {file = "PyYAML-6.0.1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:50550eb667afee136e9a77d6dc71ae76a44df8b3e51e41b77f6de2932bfe0f47"}, + {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1fe35611261b29bd1de0070f0b2f47cb6ff71fa6595c077e42bd0c419fa27b98"}, + {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:704219a11b772aea0d8ecd7058d0082713c3562b4e271b849ad7dc4a5c90c13c"}, + {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:afd7e57eddb1a54f0f1a974bc4391af8bcce0b444685d936840f125cf046d5bd"}, + {file = "PyYAML-6.0.1-cp36-cp36m-win32.whl", hash = "sha256:fca0e3a251908a499833aa292323f32437106001d436eca0e6e7833256674585"}, + {file = "PyYAML-6.0.1-cp36-cp36m-win_amd64.whl", hash = "sha256:f22ac1c3cac4dbc50079e965eba2c1058622631e526bd9afd45fedd49ba781fa"}, + {file = "PyYAML-6.0.1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:b1275ad35a5d18c62a7220633c913e1b42d44b46ee12554e5fd39c70a243d6a3"}, + {file = "PyYAML-6.0.1-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:18aeb1bf9a78867dc38b259769503436b7c72f7a1f1f4c93ff9a17de54319b27"}, + {file = "PyYAML-6.0.1-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:596106435fa6ad000c2991a98fa58eeb8656ef2325d7e158344fb33864ed87e3"}, + {file = "PyYAML-6.0.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:baa90d3f661d43131ca170712d903e6295d1f7a0f595074f151c0aed377c9b9c"}, + {file = "PyYAML-6.0.1-cp37-cp37m-win32.whl", hash = "sha256:9046c58c4395dff28dd494285c82ba00b546adfc7ef001486fbf0324bc174fba"}, + {file = "PyYAML-6.0.1-cp37-cp37m-win_amd64.whl", hash = "sha256:4fb147e7a67ef577a588a0e2c17b6db51dda102c71de36f8549b6816a96e1867"}, + {file = "PyYAML-6.0.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:1d4c7e777c441b20e32f52bd377e0c409713e8bb1386e1099c2415f26e479595"}, + {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a0cd17c15d3bb3fa06978b4e8958dcdc6e0174ccea823003a106c7d4d7899ac5"}, + {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:28c119d996beec18c05208a8bd78cbe4007878c6dd15091efb73a30e90539696"}, + {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7e07cbde391ba96ab58e532ff4803f79c4129397514e1413a7dc761ccd755735"}, + {file = "PyYAML-6.0.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:49a183be227561de579b4a36efbb21b3eab9651dd81b1858589f796549873dd6"}, + {file = "PyYAML-6.0.1-cp38-cp38-win32.whl", hash = "sha256:184c5108a2aca3c5b3d3bf9395d50893a7ab82a38004c8f61c258d4428e80206"}, + {file = "PyYAML-6.0.1-cp38-cp38-win_amd64.whl", hash = "sha256:1e2722cc9fbb45d9b87631ac70924c11d3a401b2d7f410cc0e3bbf249f2dca62"}, + {file = "PyYAML-6.0.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:9eb6caa9a297fc2c2fb8862bc5370d0303ddba53ba97e71f08023b6cd73d16a8"}, + {file = "PyYAML-6.0.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:c8098ddcc2a85b61647b2590f825f3db38891662cfc2fc776415143f599bb859"}, + {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5773183b6446b2c99bb77e77595dd486303b4faab2b086e7b17bc6bef28865f6"}, + {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b786eecbdf8499b9ca1d697215862083bd6d2a99965554781d0d8d1ad31e13a0"}, + {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bc1bf2925a1ecd43da378f4db9e4f799775d6367bdb94671027b73b393a7c42c"}, + {file = "PyYAML-6.0.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:04ac92ad1925b2cff1db0cfebffb6ffc43457495c9b3c39d3fcae417d7125dc5"}, + {file = "PyYAML-6.0.1-cp39-cp39-win32.whl", hash = "sha256:faca3bdcf85b2fc05d06ff3fbc1f83e1391b3e724afa3feba7d13eeab355484c"}, + {file = "PyYAML-6.0.1-cp39-cp39-win_amd64.whl", hash = "sha256:510c9deebc5c0225e8c96813043e62b680ba2f9c50a08d3724c7f28a747d1486"}, + {file = "PyYAML-6.0.1.tar.gz", hash = "sha256:bfdf460b1736c775f2ba9f6a92bca30bc2095067b8a9d77876d1fad6cc3b4a43"}, +] + [[package]] name = "six" version = "1.16.0" @@ -445,7 +602,27 @@ files = [ {file = "tzdata-2024.1.tar.gz", hash = "sha256:2674120f8d891909751c38abcdfd386ac0a5a1127954fbc332af6b5ceae07efd"}, ] +[[package]] +name = "virtualenv" +version = "20.26.3" +description = "Virtual Python Environment builder" +optional = false +python-versions = ">=3.7" +files = [ + {file = "virtualenv-20.26.3-py3-none-any.whl", hash = "sha256:8cc4a31139e796e9a7de2cd5cf2489de1217193116a8fd42328f1bd65f434589"}, + {file = "virtualenv-20.26.3.tar.gz", hash = "sha256:4c43a2a236279d9ea36a0d76f98d84bd6ca94ac4e0f4a3b9d46d05e10fea542a"}, +] + +[package.dependencies] +distlib = ">=0.3.7,<1" +filelock = ">=3.12.2,<4" +platformdirs = ">=3.9.1,<5" + +[package.extras] +docs = ["furo (>=2023.7.26)", "proselint (>=0.13)", "sphinx (>=7.1.2,!=7.3)", "sphinx-argparse (>=0.4)", "sphinxcontrib-towncrier (>=0.2.1a0)", "towncrier (>=23.6)"] +test = ["covdefaults (>=2.3)", "coverage (>=7.2.7)", "coverage-enable-subprocess (>=1)", "flaky (>=3.7)", "packaging (>=23.1)", "pytest (>=7.4)", "pytest-env (>=0.8.2)", "pytest-freezer (>=0.4.8)", "pytest-mock (>=3.11.1)", "pytest-randomly (>=3.12)", "pytest-timeout (>=2.1)", "setuptools (>=68)", "time-machine (>=2.10)"] + [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "62afe993ca87613b19b6253ce713518c998f3ad9b363f1bd29592e4d102dcdbd" +content-hash = "0389fcf30f061f817d308ced1e3155a12d1f6676196c6335bfff3b7ee5bcb70a" diff --git a/pyproject.toml b/pyproject.toml index 900e80a..d7f3dc6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,6 +23,7 @@ optional = true [tool.poetry.group.dev.dependencies] pytest = "^7.4.4" +pre-commit = "^3.7.1" [tool.setuptools.packages.find] exclude = ["tests"] diff --git a/tests/.gitignore b/tests/.gitignore index 9cf4042..69341fd 100644 --- a/tests/.gitignore +++ b/tests/.gitignore @@ -1 +1 @@ -test_outputs/ \ No newline at end of file +test_outputs/ diff --git a/tests/test_compute_world_transform.py b/tests/test_compute_world_transform.py index 95ac7f0..2ae32fb 100644 --- a/tests/test_compute_world_transform.py +++ b/tests/test_compute_world_transform.py @@ -1,11 +1,18 @@ import pytest -from plume_python.utils.transform import compute_transforms_time_series, compute_transform_time_series +from plume_python.utils.transform import ( + compute_transforms_time_series, + compute_transform_time_series, +) def test_compute_single_transform_time_series(): - transform_time_series = compute_transform_time_series(pytest.record, "4a3f40e37eaf4c0a9d5d88ac993c0ebc") + transform_time_series = compute_transform_time_series( + pytest.record, "4a3f40e37eaf4c0a9d5d88ac993c0ebc" + ) def test_compute_all_transforms_time_series(): - transform_time_series = compute_transforms_time_series(pytest.record, {"4a3f40e37eaf4c0a9d5d88ac993c0ebc"}) + transform_time_series = compute_transforms_time_series( + pytest.record, {"4a3f40e37eaf4c0a9d5d88ac993c0ebc"} + ) diff --git a/tests/test_export_xdf.py b/tests/test_export_xdf.py index 9b18686..4cd4993 100644 --- a/tests/test_export_xdf.py +++ b/tests/test_export_xdf.py @@ -7,6 +7,6 @@ def test_export_xdf(): record = pytest.record # create directory tests/test_outputs if it does not exist - Path('tests/test_outputs').mkdir(parents=True, exist_ok=True) - with open('tests/test_outputs/test.xdf', 'wb') as f: + Path("tests/test_outputs").mkdir(parents=True, exist_ok=True) + with open("tests/test_outputs/test.xdf", "wb") as f: export_xdf_from_record(f, record) diff --git a/tests/test_find_game_object_identifiers_by_name.py b/tests/test_find_game_object_identifiers_by_name.py index 9dccf8b..467be5b 100644 --- a/tests/test_find_game_object_identifiers_by_name.py +++ b/tests/test_find_game_object_identifiers_by_name.py @@ -1,6 +1,9 @@ import pytest -from plume_python.utils.game_object import find_identifiers_by_name, find_first_identifier_by_name +from plume_python.utils.game_object import ( + find_identifiers_by_name, + find_first_identifier_by_name, +) def test_find_first_game_object_identifier(): diff --git a/tests/test_parser.py b/tests/test_parser.py index eda071f..1ec67f4 100644 --- a/tests/test_parser.py +++ b/tests/test_parser.py @@ -22,7 +22,13 @@ def test_parse_record(): def test_simple_filtering(): record = cast(parser.Record, pytest.record) - filtered_lsl = [lsl_sample for lsl_sample in record[lsl_stream_pb2.StreamSample] if - 0 <= lsl_sample.timestamp <= 5_000_000_000] - filtered_transform_updates = [frame for frame in record[transform_pb2.TransformUpdate] if - 0 <= frame.timestamp <= 5_000_000_000] + filtered_lsl = [ + lsl_sample + for lsl_sample in record[lsl_stream_pb2.StreamSample] + if 0 <= lsl_sample.timestamp <= 5_000_000_000 + ] + filtered_transform_updates = [ + frame + for frame in record[transform_pb2.TransformUpdate] + if 0 <= frame.timestamp <= 5_000_000_000 + ]