diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index d898dfa..d5b41c4 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -11,7 +11,7 @@ repos: - id: pyupgrade - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.8.0 + rev: v0.8.1 hooks: - id: ruff args: [--fix] diff --git a/config/_templates/dataset/carla.yaml b/config/_templates/dataset/carla.yaml index 05f940b..4245c12 100644 --- a/config/_templates/dataset/carla.yaml +++ b/config/_templates/dataset/carla.yaml @@ -127,13 +127,4 @@ inputs: _target_: rbyte.io.DataFrameFilter predicate: | `control.throttle` > 0.5 - - - _target_: pipefunc.PipeFunc - renames: - input: data_filtered - output_name: samples - func: - _target_: rbyte.RollingWindowSampleBuilder - index_column: _idx_ - period: 1i #@ end diff --git a/config/_templates/dataset/mimicgen.yaml b/config/_templates/dataset/mimicgen.yaml index 2c321e9..b897072 100644 --- a/config/_templates/dataset/mimicgen.yaml +++ b/config/_templates/dataset/mimicgen.yaml @@ -58,15 +58,5 @@ inputs: func: _target_: rbyte.io.DataFrameConcater method: vertical - - - _target_: pipefunc.PipeFunc - renames: - input: data_concated - output_name: samples - func: - _target_: rbyte.RollingWindowSampleBuilder - index_column: _idx_ - period: 1i - #@ end #@ end diff --git a/config/_templates/dataset/nuscenes/mcap.yaml b/config/_templates/dataset/nuscenes/mcap.yaml index beb0380..2988694 100644 --- a/config/_templates/dataset/nuscenes/mcap.yaml +++ b/config/_templates/dataset/nuscenes/mcap.yaml @@ -105,13 +105,4 @@ inputs: _target_: rbyte.io.DataFrameFilter predicate: | `/odom/vel.x` >= 8 - - - _target_: pipefunc.PipeFunc - renames: - input: data_filtered - output_name: samples - func: - _target_: rbyte.RollingWindowSampleBuilder - index_column: (@=camera_topics.values()[0]@)/_idx_ - period: 1i #@ end diff --git a/config/_templates/dataset/nuscenes/rrd.yaml b/config/_templates/dataset/nuscenes/rrd.yaml index 308412a..a0af322 100644 --- a/config/_templates/dataset/nuscenes/rrd.yaml +++ b/config/_templates/dataset/nuscenes/rrd.yaml @@ -99,13 +99,4 @@ inputs: _target_: rbyte.io.DataFrameFilter predicate: | `/world/ego_vehicle/CAM_FRONT/timestamp` between '2018-07-24 03:28:48' and '2018-07-24 03:28:50' - - - _target_: pipefunc.PipeFunc - renames: - input: data_filtered - output_name: samples - func: - _target_: rbyte.RollingWindowSampleBuilder - index_column: (@=camera_entities.values()[0]@)/_idx_ - period: 1i #@ end diff --git a/pyproject.toml b/pyproject.toml index 5b5498d..c4b2fa2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "rbyte" -version = "0.9.1" +version = "0.10.0" description = "Multimodal PyTorch dataset library" authors = [{ name = "Evgenii Gorchakov", email = "evgenii@yaak.ai" }] maintainers = [{ name = "Evgenii Gorchakov", email = "evgenii@yaak.ai" }] @@ -8,14 +8,13 @@ dependencies = [ "tensordict>=0.6.2", "torch", "numpy", - "polars>=1.15.0", + "polars>=1.16.0", "pydantic>=2.10.2", "more-itertools>=10.5.0", "hydra-core>=1.3.2", "optree>=0.13.1", "cachetools>=5.5.0", "diskcache>=5.6.3", - "jaxtyping>=0.2.34", "parse>=1.20.2", "structlog>=24.4.0", "xxhash>=3.5.0", @@ -40,7 +39,7 @@ repo = "https://github.com/yaak-ai/rbyte" [project.optional-dependencies] build = ["hatchling>=1.25.0", "grpcio-tools>=1.62.0", "protoletariat==3.2.19"] -visualize = ["rerun-sdk[notebook]>=0.20.0"] +visualize = ["rerun-sdk[notebook]>=0.20.2"] mcap = [ "mcap>=1.2.1", "mcap-ros2-support>=0.5.5", @@ -54,7 +53,7 @@ video = [ "video-reader-rs>=0.2.1", ] hdf5 = ["h5py>=3.12.1"] -rrd = ["rerun-sdk>=0.20.0", "pyarrow-stubs"] +rrd = ["rerun-sdk>=0.20.2", "pyarrow-stubs"] [project.scripts] rbyte-visualize = 'rbyte.scripts.visualize:main' @@ -72,7 +71,7 @@ dev-dependencies = [ "wat-inspector>=0.4.3", "lovely-tensors>=0.1.18", "pudb>=2024.1.2", - "ipython>=8.29.0", + "ipython>=8.30.0", "ipython-autoimport>=0.5", "pytest>=8.3.3", "testbook>=0.4.2", @@ -129,7 +128,18 @@ skip-magic-trailing-comma = true preview = true select = ["ALL"] fixable = ["ALL"] -ignore = ["A001", "A002", "D", "CPY", "COM812", "F722", "PD901", "ISC001", "TD"] +ignore = [ + "A001", + "A002", + "D", + "CPY", + "COM812", + "F722", + "PD901", + "ISC001", + "TD", + "TC006", +] [tool.ruff.lint.isort] split-on-trailing-comma = false diff --git a/src/rbyte/io/_mcap/tensor_source.py b/src/rbyte/io/_mcap/tensor_source.py index d0c75a2..394e60a 100644 --- a/src/rbyte/io/_mcap/tensor_source.py +++ b/src/rbyte/io/_mcap/tensor_source.py @@ -2,12 +2,12 @@ from dataclasses import dataclass from functools import cached_property from mmap import ACCESS_READ, mmap -from typing import IO, override +from operator import itemgetter +from typing import IO, final, override import more_itertools as mit import numpy.typing as npt import torch -from jaxtyping import Shaped from mcap.data_stream import ReadDataStream from mcap.decoder import DecoderFactory from mcap.opcode import Opcode @@ -32,6 +32,7 @@ class MessageIndex: message_length: int +@final class McapTensorSource(TensorSource): @validate_call(config=BaseModel.model_config) def __init__( @@ -47,8 +48,8 @@ def __init__( with bound_contextvars( path=path.as_posix(), topic=topic, message_decoder_factory=decoder_factory ): - self._path: FilePath = path - self._validate_crcs: bool = validate_crcs + self._path = path + self._validate_crcs = validate_crcs summary = SeekingReader( stream=self._file, validate_crcs=self._validate_crcs @@ -73,13 +74,14 @@ def __init__( logger.error(msg := "missing message decoder") raise RuntimeError(msg) - self._message_decoder: Callable[[bytes], object] = message_decoder - self._chunk_indexes: tuple[ChunkIndex, ...] = tuple( + self._message_decoder = message_decoder + self._chunk_indexes = tuple( chunk_index for chunk_index in summary.chunk_indexes if self._channel.id in chunk_index.message_index_offsets ) - self._decoder: Callable[[bytes], npt.ArrayLike] = decoder + self._decoder = decoder + self._mmap = None @property def _file(self) -> IO[bytes]: @@ -89,9 +91,7 @@ def _file(self) -> IO[bytes]: case None | mmap(closed=True): with self._path.open("rb") as f: - self._mmap: mmap = mmap( - fileno=f.fileno(), length=0, access=ACCESS_READ - ) + self._mmap = mmap(fileno=f.fileno(), length=0, access=ACCESS_READ) case _: raise RuntimeError @@ -99,32 +99,47 @@ def _file(self) -> IO[bytes]: return self._mmap # pyright: ignore[reportReturnType] @override - def __getitem__(self, indexes: Iterable[int]) -> Shaped[Tensor, "b h w c"]: - frames: Mapping[int, npt.ArrayLike] = {} - - message_indexes_by_chunk_start_offset: Mapping[ - int, Iterable[tuple[int, MessageIndex]] - ] = mit.map_reduce( - zip(indexes, (self._message_indexes[idx] for idx in indexes), strict=True), - keyfunc=lambda x: x[1].chunk_start_offset, - ) + def __getitem__(self, indexes: int | Iterable[int]) -> Tensor: + match indexes: + case Iterable(): + arrays: Mapping[int, npt.ArrayLike] = {} + message_indexes = (self._message_indexes[idx] for idx in indexes) + indexes_by_chunk_start_offset = mit.map_reduce( + zip(indexes, message_indexes, strict=True), + keyfunc=lambda x: x[1].chunk_start_offset, + ) + + for chunk_start_offset, chunk_indexes in sorted( + indexes_by_chunk_start_offset.items(), key=itemgetter(0) + ): + _ = self._file.seek(chunk_start_offset + 1 + 8) + chunk = Chunk.read(ReadDataStream(self._file)) + stream, _ = get_chunk_data_stream( + chunk, validate_crc=self._validate_crcs + ) + for index, message_index in sorted( + chunk_indexes, key=lambda x: x[1].message_start_offset + ): + stream.read(message_index.message_start_offset - stream.count) # pyright: ignore[reportUnusedCallResult] + message = Message.read(stream, message_index.message_length) + decoded_message = self._message_decoder(message.data) + arrays[index] = self._decoder(decoded_message.data) - for ( - chunk_start_offset, - chunk_message_indexes, - ) in message_indexes_by_chunk_start_offset.items(): - self._file.seek(chunk_start_offset + 1 + 8) # pyright: ignore[reportUnusedCallResult] - chunk = Chunk.read(ReadDataStream(self._file)) - stream, _ = get_chunk_data_stream(chunk, validate_crc=self._validate_crcs) - for frame_index, message_index in sorted( - chunk_message_indexes, key=lambda x: x[1].message_start_offset - ): - stream.read(message_index.message_start_offset - stream.count) # pyright: ignore[reportUnusedCallResult] - message = Message.read(stream, message_index.message_length) + tensors = [torch.from_numpy(arrays[idx]) for idx in indexes] # pyright: ignore[reportUnknownMemberType] + + return torch.stack(tensors) + + case _: + message_index = self._message_indexes[indexes] + _ = self._file.seek(message_index.chunk_start_offset + 1 + 8) + chunk = Chunk.read(ReadDataStream(self._file)) + stream, _ = get_chunk_data_stream(chunk, self._validate_crcs) + _ = stream.read(message_index.message_start_offset - stream.count) + message = Message.read(stream, length=message_index.message_length) decoded_message = self._message_decoder(message.data) - frames[frame_index] = self._decoder(decoded_message.data) # pyright: ignore[reportUnknownMemberType, reportUnknownArgumentType, reportAttributeAccessIssue] + array = self._decoder(decoded_message.data) - return torch.stack([torch.from_numpy(frames[idx]) for idx in indexes]) # pyright: ignore[reportUnknownMemberType] + return torch.from_numpy(array) # pyright: ignore[reportUnknownMemberType] @override def __len__(self) -> int: diff --git a/src/rbyte/io/_numpy/tensor_source.py b/src/rbyte/io/_numpy/tensor_source.py index 2984eaa..550c3a3 100644 --- a/src/rbyte/io/_numpy/tensor_source.py +++ b/src/rbyte/io/_numpy/tensor_source.py @@ -2,7 +2,7 @@ from functools import cached_property from os import PathLike from pathlib import Path -from typing import TYPE_CHECKING, override +from typing import final, override import numpy as np import torch @@ -16,10 +16,8 @@ from rbyte.io.base import TensorSource from rbyte.utils.tensor import pad_sequence -if TYPE_CHECKING: - from types import EllipsisType - +@final class NumpyTensorSource(TensorSource): @validate_call(config=BaseModel.model_config) def __init__( @@ -27,23 +25,28 @@ def __init__( ) -> None: super().__init__() - self._path: Path = Path(path) - self._select: Sequence[str] | EllipsisType = select or ... + self._path = Path(path) + self._select = select or ... @cached_property def _path_posix(self) -> str: return self._path.resolve().as_posix() + def _getitem(self, index: object) -> Tensor: + path = self._path_posix.format(index) + array = structured_to_unstructured(np.load(path)[self._select]) # pyright: ignore[reportUnknownVariableType] + return torch.from_numpy(np.ascontiguousarray(array)) # pyright: ignore[reportUnknownMemberType, reportUnknownArgumentType] + @override - def __getitem__(self, indexes: Iterable[object]) -> Tensor: - tensors: list[Tensor] = [] - for index in indexes: - path = self._path_posix.format(index) - array = structured_to_unstructured(np.load(path)[self._select]) # pyright: ignore[reportUnknownVariableType] - tensor = torch.from_numpy(np.ascontiguousarray(array)) # pyright: ignore[reportUnknownMemberType, reportUnknownArgumentType] - tensors.append(tensor) - - return pad_sequence(list(tensors), dim=0, value=torch.nan) + def __getitem__(self, indexes: object | Iterable[object]) -> Tensor: + match indexes: + case Iterable(): + tensors = map(self._getitem, indexes) # pyright: ignore[reportUnknownArgumentType] + + return pad_sequence(list(tensors), dim=0, value=torch.nan) + + case _: + return self._getitem(indexes) @override def __len__(self) -> int: diff --git a/src/rbyte/io/base.py b/src/rbyte/io/base.py index a67a9d6..6ac02a3 100644 --- a/src/rbyte/io/base.py +++ b/src/rbyte/io/base.py @@ -1,10 +1,10 @@ -from collections.abc import Iterable -from typing import Any, Protocol, runtime_checkable +from collections.abc import Sequence +from typing import Protocol, runtime_checkable from torch import Tensor @runtime_checkable class TensorSource(Protocol): - def __getitem__(self, indexes: Iterable[Any]) -> Tensor: ... + def __getitem__[T](self, indexes: T | Sequence[T]) -> Tensor: ... def __len__(self) -> int: ... diff --git a/src/rbyte/io/dataframe/aligner.py b/src/rbyte/io/dataframe/aligner.py index 7c10dcd..ac0e34b 100644 --- a/src/rbyte/io/dataframe/aligner.py +++ b/src/rbyte/io/dataframe/aligner.py @@ -40,7 +40,7 @@ class MergeConfig(BaseModel): columns: OrderedDict[str, ColumnMergeConfig] = Field(default_factory=OrderedDict) -type Fields = MergeConfig | OrderedDict[str, "Fields"] +type Fields = MergeConfig | OrderedDict[str, Fields] @final diff --git a/src/rbyte/io/hdf5/dataframe_builder.py b/src/rbyte/io/hdf5/dataframe_builder.py index 7d9e44f..3f59a7e 100644 --- a/src/rbyte/io/hdf5/dataframe_builder.py +++ b/src/rbyte/io/hdf5/dataframe_builder.py @@ -13,7 +13,7 @@ ) from pydantic import ConfigDict, validate_call -type Fields = Mapping[str, PolarsDataType | None] | Mapping[str, "Fields"] +type Fields = Mapping[str, PolarsDataType | None] | Mapping[str, Fields] @final diff --git a/src/rbyte/io/hdf5/tensor_source.py b/src/rbyte/io/hdf5/tensor_source.py index 592bb1c..15f2edf 100644 --- a/src/rbyte/io/hdf5/tensor_source.py +++ b/src/rbyte/io/hdf5/tensor_source.py @@ -1,23 +1,22 @@ -from collections.abc import Iterable -from typing import cast, override +from collections.abc import Sequence +from typing import cast, final, override import torch from h5py import Dataset, File -from jaxtyping import UInt8 from pydantic import FilePath, validate_call from torch import Tensor from rbyte.io.base import TensorSource +@final class Hdf5TensorSource(TensorSource): @validate_call def __init__(self, path: FilePath, key: str) -> None: - file = File(path) - self._dataset: Dataset = cast(Dataset, file[key]) + self._dataset = cast(Dataset, File(path)[key]) @override - def __getitem__(self, indexes: Iterable[int]) -> UInt8[Tensor, "b h w c"]: + def __getitem__(self, indexes: int | Sequence[int]) -> Tensor: return torch.from_numpy(self._dataset[indexes]) # pyright: ignore[reportUnknownMemberType, reportUnknownArgumentType] @override diff --git a/src/rbyte/io/path/tensor_source.py b/src/rbyte/io/path/tensor_source.py index 8f74a19..821e31f 100644 --- a/src/rbyte/io/path/tensor_source.py +++ b/src/rbyte/io/path/tensor_source.py @@ -2,17 +2,17 @@ from functools import cached_property from os import PathLike from pathlib import Path -from typing import Any, override +from typing import final, override import numpy.typing as npt import torch -from jaxtyping import UInt8 from pydantic import validate_call from torch import Tensor from rbyte.io.base import TensorSource +@final class PathTensorSource(TensorSource): @validate_call def __init__( @@ -20,8 +20,8 @@ def __init__( ) -> None: super().__init__() - self._path: Path = Path(path) - self._decoder: Callable[[bytes], npt.ArrayLike] = decoder + self._path = Path(path) + self._decoder = decoder @cached_property def _path_posix(self) -> str: @@ -31,13 +31,21 @@ def _decode(self, path: str) -> npt.ArrayLike: with Path(path).open("rb") as f: return self._decoder(f.read()) + def _getitem(self, index: object) -> Tensor: + path = self._path_posix.format(index) + array = self._decode(path) + return torch.from_numpy(array) # pyright: ignore[reportUnknownMemberType] + @override - def __getitem__(self, indexes: Iterable[Any]) -> UInt8[Tensor, "b h w c"]: - paths = map(self._path_posix.format, indexes) - arrays = map(self._decode, paths) - tensors = map(torch.from_numpy, arrays) # pyright: ignore[reportUnknownArgumentType, reportUnknownMemberType] + def __getitem__(self, indexes: object | Iterable[object]) -> Tensor: + match indexes: + case Iterable(): + tensors = map(self._getitem, indexes) # pyright: ignore[reportUnknownArgumentType] + + return torch.stack(list(tensors)) - return torch.stack(list(tensors)) + case _: + return self._getitem(indexes) @override def __len__(self) -> int: diff --git a/src/rbyte/io/rrd/frame_source.py b/src/rbyte/io/rrd/frame_source.py index 1857fb3..c87aa51 100644 --- a/src/rbyte/io/rrd/frame_source.py +++ b/src/rbyte/io/rrd/frame_source.py @@ -1,11 +1,10 @@ from collections.abc import Callable, Iterable -from typing import cast, override +from typing import cast, final, override import numpy.typing as npt import polars as pl import rerun as rr import torch -from jaxtyping import UInt8 from pydantic import FilePath, validate_call from rerun.components import Blob from torch import Tensor @@ -14,6 +13,7 @@ from rbyte.io.base import TensorSource +@final class RrdFrameSource(TensorSource): @validate_call(config=BaseModel.model_config) def __init__( @@ -29,7 +29,7 @@ def __init__( reader = view.select(columns=[index, f"{entity_path}:{Blob.__name__}"]) # WARN: RecordBatchReader does not support random seeking => storing in memory - self._series: pl.Series = ( + self._series = ( cast( pl.DataFrame, pl.from_arrow(reader.read_all(), rechunk=True), # pyright: ignore[reportUnknownMemberType] @@ -40,15 +40,23 @@ def __init__( .to_series(0) ) - self._decoder: Callable[[bytes], npt.ArrayLike] = decoder + self._decoder = decoder + + def _getitem(self, index: int) -> Tensor: + array = self._series[index].to_numpy(allow_copy=False) + array = self._decoder(array) + return torch.from_numpy(array) # pyright: ignore[reportUnknownMemberType] @override - def __getitem__(self, indexes: Iterable[int]) -> UInt8[Tensor, "b h w c"]: - arrays = (self._series[i].to_numpy(allow_copy=False) for i in indexes) - frames_np = map(self._decoder, arrays) - frames_tch = map(torch.from_numpy, frames_np) # pyright: ignore[reportUnknownArgumentType, reportUnknownMemberType] + def __getitem__(self, indexes: int | Iterable[int]) -> Tensor: + match indexes: + case Iterable(): + tensors = map(self._getitem, indexes) + + return torch.stack(list(tensors)) - return torch.stack(list(frames_tch)) + case _: + return self._getitem(indexes) @override def __len__(self) -> int: diff --git a/src/rbyte/io/video/ffmpeg_source.py b/src/rbyte/io/video/ffmpeg_source.py index 754352e..64c9c7a 100644 --- a/src/rbyte/io/video/ffmpeg_source.py +++ b/src/rbyte/io/video/ffmpeg_source.py @@ -1,10 +1,8 @@ from collections.abc import Iterable from pathlib import Path -from typing import cast, override +from typing import final, override -import numpy.typing as npt import torch -from jaxtyping import UInt8 from pydantic import FilePath, NonNegativeInt, validate_call from torch import Tensor from video_reader import ( @@ -14,6 +12,7 @@ from rbyte.io.base import TensorSource +@final class FfmpegFrameSource(TensorSource): @validate_call def __init__( @@ -31,10 +30,15 @@ def __init__( ) @override - def __getitem__(self, indexes: Iterable[int]) -> UInt8[Tensor, "b h w c"]: - batch = cast(npt.ArrayLike, self._reader.get_batch(indexes)) # pyright: ignore[reportUnknownMemberType] + def __getitem__(self, indexes: int | Iterable[int]) -> Tensor: + match indexes: + case Iterable(): + array = self._reader.get_batch(indexes) # pyright: ignore[reportUnknownVariableType, reportUnknownMemberType] - return torch.from_numpy(batch) # pyright: ignore[reportUnknownMemberType] + case _: + array = self._reader.get_batch([indexes])[0] # pyright: ignore[reportUnknownVariableType, reportUnknownMemberType] + + return torch.from_numpy(array) # pyright: ignore[reportUnknownArgumentType, reportUnknownMemberType] @override def __len__(self) -> int: diff --git a/src/rbyte/io/video/vali_source.py b/src/rbyte/io/video/vali_source.py index 8f95a3a..047b689 100644 --- a/src/rbyte/io/video/vali_source.py +++ b/src/rbyte/io/video/vali_source.py @@ -1,12 +1,11 @@ from collections.abc import Iterable, Mapping from functools import cached_property from itertools import pairwise -from typing import Annotated, override +from typing import Annotated, final, override import more_itertools as mit import python_vali as vali import torch -from jaxtyping import Shaped from pydantic import BeforeValidator, FilePath, NonNegativeInt, validate_call from structlog import get_logger from torch import Tensor @@ -24,6 +23,7 @@ ] +@final class ValiGpuFrameSource(TensorSource): @validate_call(config=BaseModel.model_config) def __init__( @@ -37,13 +37,13 @@ def __init__( ) -> None: super().__init__() - self._gpu_id: int = gpu_id + self._gpu_id = gpu_id - self._decoder: vali.PyDecoder = vali.PyDecoder( + self._decoder = vali.PyDecoder( input=path.resolve().as_posix(), opts={}, gpu_id=self._gpu_id ) - self._pixel_format_chain: tuple[PixelFormat, ...] = ( + self._pixel_format_chain = ( (self._decoder.Format, *pixel_format_chain) if mit.first(pixel_format_chain, default=None) != self._decoder.Format else pixel_format_chain @@ -72,9 +72,7 @@ def _surfaces(self) -> Mapping[vali.PixelFormat, vali.Surface]: for pixel_format in self._pixel_format_chain } - def _read_frame( - self, index: int - ) -> Shaped[Tensor, "c h w"] | Shaped[Tensor, "h w c"]: + def _read_frame(self, index: int) -> Tensor: seek_ctx = vali.SeekContext(seek_frame=index) success, details = self._decoder.DecodeSingleSurface( # pyright: ignore[reportUnknownMemberType] self._surfaces[self._decoder.Format], seek_ctx @@ -103,10 +101,15 @@ def _read_frame( return torch.from_dlpack(surface).clone().detach() # pyright: ignore[reportPrivateImportUsage] @override - def __getitem__( - self, indexes: Iterable[int] - ) -> Shaped[Tensor, "b h w c"] | Shaped[Tensor, "b c h w"]: - return torch.stack([self._read_frame(index) for index in indexes]) + def __getitem__(self, indexes: int | Iterable[int]) -> Tensor: + match indexes: + case Iterable(): + frames = map(self._read_frame, indexes) + + return torch.stack(list(frames)) + + case int(): + return self._read_frame(indexes) @override def __len__(self) -> int: diff --git a/src/rbyte/utils/tensor.py b/src/rbyte/utils/tensor.py index fa58a40..b4bf0ee 100644 --- a/src/rbyte/utils/tensor.py +++ b/src/rbyte/utils/tensor.py @@ -3,18 +3,17 @@ import more_itertools as mit import torch import torch.nn.functional as F # noqa: N812 -from jaxtyping import Float from torch import Tensor def pad_dim( - input: Float[Tensor, "..."], + input: Tensor, *, pad: tuple[int, int], dim: int, mode: str = "constant", value: float | None = None, -) -> Float[Tensor, "..."]: +) -> Tensor: _pad = [(0, 0) for _ in input.shape] _pad[dim] = pad _pad = list(mit.flatten(reversed(_pad))) @@ -22,9 +21,7 @@ def pad_dim( return F.pad(input, _pad, mode=mode, value=value) -def pad_sequence( - sequences: Sequence[Float[Tensor, "..."]], dim: int, value: float = 0.0 -) -> Float[Tensor, "..."]: +def pad_sequence(sequences: Sequence[Tensor], dim: int, value: float = 0.0) -> Tensor: max_length = max(sequence.shape[dim] for sequence in sequences) padded = ( diff --git a/src/rbyte/viz/loggers/rerun_logger.py b/src/rbyte/viz/loggers/rerun_logger.py index 588320e..1d21811 100644 --- a/src/rbyte/viz/loggers/rerun_logger.py +++ b/src/rbyte/viz/loggers/rerun_logger.py @@ -1,5 +1,6 @@ from collections.abc import Iterable, Mapping, Sequence from functools import cache, cached_property +from math import prod from typing import Annotated, Any, Protocol, Self, cast, override, runtime_checkable import more_itertools as mit @@ -126,13 +127,13 @@ def _build_components( case rr.Points3D: match shape := array.shape: - case (n, 3): + case (3,): batch = rr.components.Position3DBatch(array) - case (s, n, 3): + case (*batch_dims, n, 3): batch = rr.components.Position3DBatch( array.reshape(-1, 3) - ).partition([n] * s) + ).partition([n] * prod(batch_dims)) case _: logger.debug("not implemented", shape=shape) @@ -151,11 +152,11 @@ def _build_components( image_format.color_model, array.shape, ): - case None, rr.ColorModel(), (_batch, height, width, _): + case None, rr.ColorModel(), (*batch_dims, height, width, _): pass - case rr.PixelFormat.NV12, None, (_batch, dim, width): - height = int(dim / 1.5) + case rr.PixelFormat.NV12, None, (*batch_dims, batch_dim, width): + height = int(batch_dim / 1.5) case _: logger.error("not implemented") @@ -169,11 +170,14 @@ def _build_components( color_model=image_format.color_model, channel_datatype=rr.ChannelDatatype.from_np_dtype(array.dtype), ) + + batch_dim = prod(batch_dims) + return [ mit.one(schema).indicator(), - rr.components.ImageFormatBatch([image_format] * _batch), + rr.components.ImageFormatBatch([image_format] * batch_dim), rr.components.ImageBufferBatch( - array.reshape(_batch, -1).view(np.uint8) + array.reshape(batch_dim, -1).view(np.uint8) ), ] @@ -187,7 +191,10 @@ def log(self, batch_idx: int, batch: Batch) -> None: for i, sample in enumerate(batch.data): # pyright: ignore[reportUnknownVariableType] with self._get_recording(batch.meta.input_id[i]): # pyright: ignore[reportUnknownArgumentType, reportIndexIssue] times: Sequence[TimeColumn] = [ - column(timeline=timeline, times=sample.get(timeline).numpy()) # pyright: ignore[reportUnknownMemberType, reportCallIssue] + column( + timeline=timeline, + times=np.atleast_1d(sample.get(timeline).numpy()), # pyright: ignore[reportUnknownArgumentType, reportUnknownMemberType, reportCallIssue] + ) for timeline, column in self._schema.times.items() ] diff --git a/tests/test_dataloader.py b/tests/test_dataloader.py index fcb3a9a..d5f8193 100644 --- a/tests/test_dataloader.py +++ b/tests/test_dataloader.py @@ -31,9 +31,9 @@ def test_mimicgen() -> None: match batch.to_dict(): case { "data": { - "obs/agentview_image": Tensor(shape=[c.B, _, *_]), - "_idx_": Tensor(shape=[c.B, _]), - "obs/robot0_eef_pos": Tensor(shape=[c.B, _, *_]), + "obs/agentview_image": Tensor(shape=[c.B, *_]), + "_idx_": Tensor(shape=[c.B, *_]), + "obs/robot0_eef_pos": Tensor(shape=[c.B, *_]), **data_rest, }, "meta": { @@ -77,14 +77,14 @@ def test_nuscenes_mcap() -> None: match batch.to_dict(): case { "data": { - "CAM_FRONT": Tensor(shape=[c.B, _, *_]), - "CAM_FRONT_LEFT": Tensor(shape=[c.B, _, *_]), - "CAM_FRONT_RIGHT": Tensor(shape=[c.B, _, *_]), - "/CAM_FRONT/image_rect_compressed/_idx_": Tensor(shape=[c.B, _]), - "/CAM_FRONT/image_rect_compressed/log_time": Tensor(shape=[c.B, _]), - "/CAM_FRONT_LEFT/image_rect_compressed/_idx_": Tensor(shape=[c.B, _]), - "/CAM_FRONT_RIGHT/image_rect_compressed/_idx_": Tensor(shape=[c.B, _]), - "/odom/vel.x": Tensor(shape=[c.B, _]), + "CAM_FRONT": Tensor(shape=[c.B, *_]), + "CAM_FRONT_LEFT": Tensor(shape=[c.B, *_]), + "CAM_FRONT_RIGHT": Tensor(shape=[c.B, *_]), + "/CAM_FRONT/image_rect_compressed/_idx_": Tensor(shape=[c.B, *_]), + "/CAM_FRONT/image_rect_compressed/log_time": Tensor(shape=[c.B, *_]), + "/CAM_FRONT_LEFT/image_rect_compressed/_idx_": Tensor(shape=[c.B, *_]), + "/CAM_FRONT_RIGHT/image_rect_compressed/_idx_": Tensor(shape=[c.B, *_]), + "/odom/vel.x": Tensor(shape=[c.B, *_]), **data_rest, }, "meta": { @@ -128,14 +128,14 @@ def test_nuscenes_rrd() -> None: match batch.to_dict(): case { "data": { - "CAM_FRONT": Tensor(shape=[c.B, _, *_]), - "CAM_FRONT_LEFT": Tensor(shape=[c.B, _, *_]), - "CAM_FRONT_RIGHT": Tensor(shape=[c.B, _, *_]), - "/world/ego_vehicle/CAM_FRONT/timestamp": Tensor(shape=[c.B, _, *_]), - "/world/ego_vehicle/CAM_FRONT/_idx_": Tensor(shape=[c.B, _, *_]), - "/world/ego_vehicle/CAM_FRONT_LEFT/_idx_": Tensor(shape=[c.B, _, *_]), - "/world/ego_vehicle/CAM_FRONT_RIGHT/_idx_": Tensor(shape=[c.B, _, *_]), - "/world/ego_vehicle/LIDAR_TOP/Position3D": Tensor(shape=[c.B, _, *_]), + "CAM_FRONT": Tensor(shape=[c.B, *_]), + "CAM_FRONT_LEFT": Tensor(shape=[c.B, *_]), + "CAM_FRONT_RIGHT": Tensor(shape=[c.B, *_]), + "/world/ego_vehicle/CAM_FRONT/timestamp": Tensor(shape=[c.B, *_]), + "/world/ego_vehicle/CAM_FRONT/_idx_": Tensor(shape=[c.B, *_]), + "/world/ego_vehicle/CAM_FRONT_LEFT/_idx_": Tensor(shape=[c.B, *_]), + "/world/ego_vehicle/CAM_FRONT_RIGHT/_idx_": Tensor(shape=[c.B, *_]), + "/world/ego_vehicle/LIDAR_TOP/Position3D": Tensor(shape=[c.B, *_]), **data_rest, }, "meta": { @@ -179,21 +179,21 @@ def test_yaak() -> None: match batch.to_dict(): case { "data": { - "cam_front_left": Tensor(shape=[c.B, _, *_]), - "cam_left_backward": Tensor(shape=[c.B, _, *_]), - "cam_right_backward": Tensor(shape=[c.B, _, *_]), - "meta/ImageMetadata.cam_front_left/frame_idx": Tensor(shape=[c.B, _]), - "meta/ImageMetadata.cam_front_left/time_stamp": Tensor(shape=[c.B, _]), + "cam_front_left": Tensor(shape=[c.B, *_]), + "cam_left_backward": Tensor(shape=[c.B, *_]), + "cam_right_backward": Tensor(shape=[c.B, *_]), + "meta/ImageMetadata.cam_front_left/frame_idx": Tensor(shape=[c.B, *_]), + "meta/ImageMetadata.cam_front_left/time_stamp": Tensor(shape=[c.B, *_]), "meta/ImageMetadata.cam_left_backward/frame_idx": Tensor( - shape=[c.B, _] + shape=[c.B, *_] ), "meta/ImageMetadata.cam_right_backward/frame_idx": Tensor( - shape=[c.B, _] + shape=[c.B, *_] ), - "meta/VehicleMotion/gear": Tensor(shape=[c.B, _]), - "meta/VehicleMotion/speed": Tensor(shape=[c.B, _]), - "mcap//ai/safety_score/clip.end_timestamp": Tensor(shape=[c.B, _]), - "mcap//ai/safety_score/score": Tensor(shape=[c.B, _]), + "meta/VehicleMotion/gear": Tensor(shape=[c.B, *_]), + "meta/VehicleMotion/speed": Tensor(shape=[c.B, *_]), + "mcap//ai/safety_score/clip.end_timestamp": Tensor(shape=[c.B, *_]), + "mcap//ai/safety_score/score": Tensor(shape=[c.B, *_]), **data_rest, }, "meta": { @@ -233,21 +233,21 @@ def test_zod() -> None: match batch.to_dict(): case { "data": { - "camera_front_blur": Tensor(shape=[c.B, _, *_]), - "camera_front_blur/timestamp": Tensor(shape=[c.B, _, *_]), - "lidar_velodyne": Tensor(shape=[c.B, _, *_]), - "lidar_velodyne/timestamp": Tensor(shape=[c.B, _, *_]), + "camera_front_blur": Tensor(shape=[c.B, *_]), + "camera_front_blur/timestamp": Tensor(shape=[c.B, *_]), + "lidar_velodyne": Tensor(shape=[c.B, *_]), + "lidar_velodyne/timestamp": Tensor(shape=[c.B, *_]), "vehicle_data/ego_vehicle_controls/acceleration_pedal/ratio/unitless/value": Tensor( # noqa: E501 - shape=[c.B, _, *_] + shape=[c.B, *_] ), "vehicle_data/ego_vehicle_controls/steering_wheel_angle/angle/radians/value": Tensor( # noqa: E501 - shape=[c.B, _, *_] + shape=[c.B, *_] ), "vehicle_data/ego_vehicle_controls/timestamp/nanoseconds/value": Tensor( - shape=[c.B, _, *_] + shape=[c.B, *_] ), "vehicle_data/satellite/speed/meters_per_second/value": Tensor( - shape=[c.B, _, *_] + shape=[c.B, *_] ), **data_rest, },