Skip to content

Commit

Permalink
Format files using ruff
Browse files Browse the repository at this point in the history
  • Loading branch information
imenelydiaker authored and cjaverliat committed Jun 23, 2024
1 parent 2700003 commit f317ad8
Show file tree
Hide file tree
Showing 22 changed files with 552 additions and 157 deletions.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -165,4 +165,4 @@ cython_debug/
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
#.idea/
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,4 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Fixed

- Fixed a bug where extracting samples by time range would throw an exception if the record contained timeless samples.
- Fixed a bug where extracting samples by time range would throw an exception if the record contained timeless samples.
2 changes: 1 addition & 1 deletion LICENSE
Original file line number Diff line number Diff line change
Expand Up @@ -671,4 +671,4 @@ into proprietary programs. If your program is a subroutine library, you
may consider it more useful to permit linking proprietary applications with
the library. If this is what you want to do, use the GNU Lesser General
Public License instead of this License. But first, please read
<https://www.gnu.org/licenses/why-not-lgpl.html>.
<https://www.gnu.org/licenses/why-not-lgpl.html>.
5 changes: 5 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,8 @@ install:
tests:
@echo "--- 🧪 Running tests ---"
poetry run pytest

.PHONY: lint
lint:
@echo "--- 🧹 Linting code ---"
poetry run pre-commit run --all-files
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ For more advanced usage, the package can be imported in a Python script:
import plume_python as plm
from plume_python.utils.dataframe import samples_to_dataframe, record_to_dataframes
from plume_python.samples.unity import transform_pb2
from plume_python.export import xdf_exporter
from plume_python.export import xdf_exporter
from plume_python.utils.game_object import find_names_by_guid, find_first_identifier_by_name

# Load a record file
Expand Down Expand Up @@ -136,4 +136,4 @@ Sophie VILLENAVE - sophie.villenave@ec-lyon.fr
```

[Button Docs]: https://img.shields.io/badge/Explore%20the%20docs-%E2%86%92-brightgreen
[Explore the docs]: https://liris-xr.github.io/PLUME/
[Explore the docs]: https://liris-xr.github.io/PLUME/
7 changes: 0 additions & 7 deletions plume_python/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +0,0 @@
from . import file_reader
from . import parser
from . import record
from . import utils
from . import samples
from . import export
from . import cli
86 changes: 56 additions & 30 deletions plume_python/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,16 @@
from plume_python import parser
from plume_python.export.xdf_exporter import export_xdf_from_record
from plume_python.samples import sample_types_from_names
from plume_python.utils.dataframe import record_to_dataframes, samples_to_dataframe, world_transforms_to_dataframe
from plume_python.utils.game_object import find_names_by_guid, find_identifiers_by_name, \
find_identifier_by_game_object_id
from plume_python.utils.dataframe import (
record_to_dataframes,
samples_to_dataframe,
world_transforms_to_dataframe,
)
from plume_python.utils.game_object import (
find_names_by_guid,
find_identifiers_by_name,
find_identifier_by_game_object_id,
)
from plume_python.utils.transform import compute_transform_time_series


Expand All @@ -17,33 +24,37 @@ def cli():


@click.command()
@click.argument('record_path', type=click.Path(exists=True, readable=True))
@click.option('--xdf_output_path', type=click.Path(writable=True))
@click.argument("record_path", type=click.Path(exists=True, readable=True))
@click.option("--xdf_output_path", type=click.Path(writable=True))
def export_xdf(record_path: str, xdf_output_path: str | None):
"""Export a XDF file including LSL samples and markers."""
if not record_path.endswith('.plm'):
if not record_path.endswith(".plm"):
click.echo(err=True, message="Input file must be a .plm file")
return

if xdf_output_path is None:
xdf_output_path = record_path.replace('.plm', '.xdf')
xdf_output_path = record_path.replace(".plm", ".xdf")

if os.path.exists(xdf_output_path):
if not click.confirm(f"File '{xdf_output_path}' already exists, do you want to overwrite it?"):
if not click.confirm(
f"File '{xdf_output_path}' already exists, do you want to overwrite it?"
):
return

with open(xdf_output_path, "wb") as xdf_output_file:
record = parser.parse_record_from_file(record_path)
export_xdf_from_record(xdf_output_file, record)
click.echo('Exported xdf from record: ' + record_path + ' to ' + xdf_output_path)
click.echo(
"Exported xdf from record: " + record_path + " to " + xdf_output_path
)


@click.command()
@click.argument('record_path', type=click.Path(exists=True, readable=True))
@click.argument('guid', type=click.STRING)
@click.argument("record_path", type=click.Path(exists=True, readable=True))
@click.argument("guid", type=click.STRING)
def find_name(record_path: str, guid: str):
"""Find the name(s) of a GameObject with the given GUID in the record."""
if not record_path.endswith('.plm'):
if not record_path.endswith(".plm"):
click.echo(err=True, message="Input file must be a .plm file")
return

Expand All @@ -58,11 +69,11 @@ def find_name(record_path: str, guid: str):


@click.command()
@click.argument('record_path', type=click.Path(exists=True, readable=True))
@click.argument('name', type=click.STRING)
@click.argument("record_path", type=click.Path(exists=True, readable=True))
@click.argument("name", type=click.STRING)
def find_guid(record_path: str, name: str):
"""Find the GUID(s) of a GameObject by the given name."""
if not record_path.endswith('.plm'):
if not record_path.endswith(".plm"):
click.echo(err=True, message="Input file must be a .plm file")
return

Expand All @@ -77,11 +88,11 @@ def find_guid(record_path: str, name: str):


@click.command()
@click.argument('record_path', type=click.Path(exists=True, readable=True))
@click.argument('guid', type=click.STRING)
@click.argument("record_path", type=click.Path(exists=True, readable=True))
@click.argument("guid", type=click.STRING)
def export_world_transforms(record_path: str, guid: str):
"""Export world transforms of a GameObject with the given GUID to a CSV file."""
if not record_path.endswith('.plm'):
if not record_path.endswith(".plm"):
click.echo(err=True, message="Input file must be a .plm file")
return

Expand All @@ -94,38 +105,53 @@ def export_world_transforms(record_path: str, guid: str):

time_series = compute_transform_time_series(record, identifier.transform_id)
df = world_transforms_to_dataframe(time_series)
file_path = record_path.replace('.plm', f'_{guid}_world_transform.csv')
file_path = record_path.replace(".plm", f"_{guid}_world_transform.csv")
df.to_csv(file_path)


@click.command()
@click.argument('record_path', type=click.Path(exists=True, readable=True))
@click.argument('output_dir', type=click.Path(exists=True, writable=True))
@click.option('--filter', default="all", show_default=True, type=click.STRING,
help="Comma separated list of sample types to export (eg. 'TransformUpdate,GameObjectUpdate')")
@click.argument("record_path", type=click.Path(exists=True, readable=True))
@click.argument("output_dir", type=click.Path(exists=True, writable=True))
@click.option(
"--filter",
default="all",
show_default=True,
type=click.STRING,
help="Comma separated list of sample types to export (eg. 'TransformUpdate,GameObjectUpdate')",
)
def export_csv(record_path: str, output_dir: str | None, filter: str):
"""Export samples from the record to CSV files."""
if not record_path.endswith('.plm'):
if not record_path.endswith(".plm"):
click.echo(err=True, message="Input file must be a .plm file")
return

record = parser.parse_record_from_file(record_path)

filters = [d.strip() for d in filter.split(',')]
filters = [d.strip() for d in filter.split(",")]

if filters == ['all'] or filters == ['*']:
if filters == ["all"] or filters == ["*"]:
dataframes = record_to_dataframes(record)
for sample_type, df in dataframes.items():
file_path = os.path.join(output_dir, sample_type.__name__ + '.csv')
file_path = os.path.join(output_dir, sample_type.__name__ + ".csv")
df.to_csv(file_path)
click.echo('Exported CSV for sample type: ' + sample_type.__name__ + ' to ' + file_path)
click.echo(
"Exported CSV for sample type: "
+ sample_type.__name__
+ " to "
+ file_path
)
else:
sample_types = sample_types_from_names(filters)
for sample_type in sample_types:
df = samples_to_dataframe(record.get_samples_by_type(sample_type))
file_path = os.path.join(output_dir, sample_type.__name__ + '.csv')
file_path = os.path.join(output_dir, sample_type.__name__ + ".csv")
df.to_csv(file_path)
click.echo('Exported CSV for sample type: ' + sample_type.__name__ + ' to ' + file_path)
click.echo(
"Exported CSV for sample type: "
+ sample_type.__name__
+ " to "
+ file_path
)


cli.add_command(export_csv)
Expand Down
82 changes: 65 additions & 17 deletions plume_python/export/xdf_exporter.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,26 @@
from plume_python.export.xdf_writer import *
from plume_python.export.xdf_writer import (
STR_ENCODING,
write_file_header,
write_stream_header,
write_stream_sample,
write_stream_footer,
)
from plume_python.record import Record
from plume_python.samples.common import marker_pb2
from plume_python.samples.lsl import lsl_stream_pb2
from typing import BinaryIO

import xml.etree.ElementTree as ET
import numpy as np


def export_xdf_from_record(output_file: BinaryIO, record: Record):
datetime_str = record.get_metadata().start_time.ToDatetime().astimezone().strftime('%Y-%m-%dT%H:%M:%S%z')
datetime_str = (
record.get_metadata()
.start_time.ToDatetime()
.astimezone()
.strftime("%Y-%m-%dT%H:%M:%S%z")
)
# Add a colon separator to the offset segment
datetime_str = "{0}:{1}".format(datetime_str[:-2], datetime_str[-2:])

Expand All @@ -26,7 +40,9 @@ def export_xdf_from_record(output_file: BinaryIO, record: Record):

for lsl_open_stream in record[lsl_stream_pb2.StreamOpen]:
xml_header = ET.fromstring(lsl_open_stream.payload.xml_header)
stream_id = np.uint64(lsl_open_stream.payload.stream_id) + 1 # reserve id = 1 for the marker stream
stream_id = (
np.uint64(lsl_open_stream.payload.stream_id) + 1
) # reserve id = 1 for the marker stream
channel_format = xml_header.find("channel_format").text
stream_channel_format[stream_id] = channel_format
stream_min_time[stream_id] = None
Expand All @@ -50,19 +66,33 @@ def export_xdf_from_record(output_file: BinaryIO, record: Record):
stream_sample_count[stream_id] += 1

if channel_format == "string":
val = np.array([x for x in lsl_sample.payload.string_values.value], dtype=np.str_)
val = np.array(
[x for x in lsl_sample.payload.string_values.value], dtype=np.str_
)
elif channel_format == "int8":
val = np.array([x for x in lsl_sample.payload.int8_values.value], dtype=np.int8)
val = np.array(
[x for x in lsl_sample.payload.int8_values.value], dtype=np.int8
)
elif channel_format == "int16":
val = np.array([x for x in lsl_sample.payload.int16_values.value], dtype=np.int16)
val = np.array(
[x for x in lsl_sample.payload.int16_values.value], dtype=np.int16
)
elif channel_format == "int32":
val = np.array([x for x in lsl_sample.payload.int32_values.value], dtype=np.int32)
val = np.array(
[x for x in lsl_sample.payload.int32_values.value], dtype=np.int32
)
elif channel_format == "int64":
val = np.array([x for x in lsl_sample.payload.int64_values.value], dtype=np.int64)
val = np.array(
[x for x in lsl_sample.payload.int64_values.value], dtype=np.int64
)
elif channel_format == "float32":
val = np.array([x for x in lsl_sample.payload.float_values.value], dtype=np.float32)
val = np.array(
[x for x in lsl_sample.payload.float_values.value], dtype=np.float32
)
elif channel_format == "double64":
val = np.array([x for x in lsl_sample.payload.double_values.value], dtype=np.float64)
val = np.array(
[x for x in lsl_sample.payload.double_values.value], dtype=np.float64
)
else:
raise ValueError(f"Unsupported channel format: {channel_format}")

Expand All @@ -71,9 +101,15 @@ def export_xdf_from_record(output_file: BinaryIO, record: Record):
for marker_sample in record[marker_pb2.Marker]:
t = marker_sample.timestamp / 1_000_000_000.0 # convert time to seconds

if stream_min_time[marker_stream_id] is None or t < stream_min_time[marker_stream_id]:
if (
stream_min_time[marker_stream_id] is None
or t < stream_min_time[marker_stream_id]
):
stream_min_time[marker_stream_id] = t
if stream_max_time[marker_stream_id] is None or t > stream_max_time[marker_stream_id]:
if (
stream_max_time[marker_stream_id] is None
or t > stream_max_time[marker_stream_id]
):
stream_max_time[marker_stream_id] = t

if marker_stream_id not in stream_sample_count:
Expand All @@ -87,13 +123,23 @@ def export_xdf_from_record(output_file: BinaryIO, record: Record):
for lsl_close_stream in record[lsl_stream_pb2.StreamClose]:
stream_id = np.uint64(lsl_close_stream.payload.stream_id) + 1
sample_count = stream_sample_count[stream_id]
write_stream_footer(output_file, stream_min_time[stream_id], stream_max_time[stream_id], sample_count,
stream_id)
write_stream_footer(
output_file,
stream_min_time[stream_id],
stream_max_time[stream_id],
sample_count,
stream_id,
)

# Write marker stream footer
# stream_id = 1 is reserved for the marker stream
write_stream_footer(output_file, stream_min_time[marker_stream_id], stream_max_time[marker_stream_id],
stream_sample_count[marker_stream_id], marker_stream_id)
write_stream_footer(
output_file,
stream_min_time[marker_stream_id],
stream_max_time[marker_stream_id],
stream_sample_count[marker_stream_id],
marker_stream_id,
)


def write_marker_stream_header(output_buf, marker_stream_id):
Expand All @@ -109,4 +155,6 @@ def write_marker_stream_header(output_buf, marker_stream_id):
channel_count_el.text = "1"
nominal_srate_el.text = "0.0"
xml = ET.tostring(info_el, encoding=STR_ENCODING, xml_declaration=True)
write_stream_header(output_buf, xml, marker_stream_id) # stream_id = 1 is reserved for the marker stream
write_stream_header(
output_buf, xml, marker_stream_id
) # stream_id = 1 is reserved for the marker stream
Loading

0 comments on commit f317ad8

Please sign in to comment.