diff --git a/yt/frontends/rockstar/data_structures.py b/yt/frontends/rockstar/data_structures.py index e93ec9c0b56..3d613d3b5fd 100644 --- a/yt/frontends/rockstar/data_structures.py +++ b/yt/frontends/rockstar/data_structures.py @@ -1,5 +1,7 @@ import glob import os +from functools import cached_property +from typing import Any, Optional import numpy as np @@ -9,22 +11,58 @@ from yt.geometry.particle_geometry_handler import ParticleIndex from yt.utilities import fortran_utils as fpu from yt.utilities.cosmology import Cosmology +from yt.utilities.exceptions import YTFieldNotFound from .definitions import header_dt from .fields import RockstarFieldInfo class RockstarBinaryFile(HaloCatalogFile): + header: dict + _position_offset: int + _member_offset: int + _Npart: "np.ndarray[Any, np.dtype[np.int64]]" + _ids_halos: list[int] + _file_size: int + def __init__(self, ds, io, filename, file_id, range): with open(filename, "rb") as f: self.header = fpu.read_cattrs(f, header_dt, "=") self._position_offset = f.tell() + pcount = self.header["num_halos"] + + halos = np.fromfile(f, dtype=io._halo_dt, count=pcount) + self._member_offset = f.tell() + self._ids_halos = list(halos["particle_identifier"]) + self._Npart = halos["num_p"] + f.seek(0, os.SEEK_END) self._file_size = f.tell() + expected_end = self._member_offset + 8 * self._Npart.sum() + if expected_end != self._file_size: + raise RuntimeError( + f"File size {self._file_size} does not match expected size {expected_end}." + ) + super().__init__(ds, io, filename, file_id, range) - def _read_particle_positions(self, ptype, f=None): + def _read_member( + self, ihalo: int + ) -> Optional["np.ndarray[Any, np.dtype[np.int64]]"]: + if ihalo not in self._ids_halos: + return None + + ind_halo = self._ids_halos.index(ihalo) + + ipos = self._member_offset + 8 * self._Npart[:ind_halo].sum() + + with open(self.filename, "rb") as f: + f.seek(ipos, os.SEEK_SET) + ids = np.fromfile(f, dtype=np.int64, count=self._Npart[ind_halo]) + return ids + + def _read_particle_positions(self, ptype: str, f=None): """ Read all particle positions in this file. """ @@ -48,8 +86,18 @@ def _read_particle_positions(self, ptype, f=None): return pos +class RockstarIndex(ParticleIndex): + def get_member(self, ihalo: int): + for df in self.data_files: + members = df._read_member(ihalo) + if members is not None: + return members + + raise RuntimeError(f"Could not find halo {ihalo} in any data file.") + + class RockstarDataset(ParticleDataset): - _index_class = ParticleIndex + _index_class = RockstarIndex _file_class = RockstarBinaryFile _field_info_class = RockstarFieldInfo _suffix = ".bin" @@ -122,3 +170,71 @@ def _is_valid(cls, filename: str, *args, **kwargs) -> bool: return False else: return header["magic"] == 18077126535843729616 + + def halo(self, ptype, particle_identifier): + return RockstarHaloContainer( + ptype, + particle_identifier, + parent_ds=None, + halo_ds=self, + ) + + +class RockstarHaloContainer: + def __init__(self, ptype, particle_identifier, *, parent_ds, halo_ds): + if ptype not in halo_ds.particle_types_raw: + raise RuntimeError( + f'Possible halo types are {halo_ds.particle_types_raw}, supplied "{ptype}".' + ) + + self.ds = parent_ds + self.halo_ds = halo_ds + self.ptype = ptype + self.particle_identifier = particle_identifier + + def __repr__(self): + return f"{self.halo_ds}_{self.ptype}_{self.particle_identifier:09d}" + + def __getitem__(self, key): + if isinstance(key, tuple): + ptype, field = key + else: + ptype = self.ptype + field = key + + data = { + "mass": self.mass, + "position": self.position, + "velocity": self.velocity, + "member_ids": self.member_ids, + } + if ptype == "halos" and field in data: + return data[field] + + raise YTFieldNotFound((ptype, field), dataset=self.ds) + + @cached_property + def ihalo(self): + halo_id = self.particle_identifier + halo_ids = list(self.halo_ds.r["halos", "particle_identifier"].astype("i8")) + ihalo = halo_ids.index(halo_id) + + assert halo_ids[ihalo] == halo_id + + return ihalo + + @property + def mass(self): + return self.halo_ds.r["halos", "particle_mass"][self.ihalo] + + @property + def position(self): + return self.halo_ds.r["halos", "particle_position"][self.ihalo] + + @property + def velocity(self): + return self.halo_ds.r["halos", "particle_velocity"][self.ihalo] + + @property + def member_ids(self): + return self.halo_ds.index.get_member(self.particle_identifier) diff --git a/yt/frontends/rockstar/tests/test_outputs.py b/yt/frontends/rockstar/tests/test_outputs.py index 920881ec1cc..0acf4dc1d72 100644 --- a/yt/frontends/rockstar/tests/test_outputs.py +++ b/yt/frontends/rockstar/tests/test_outputs.py @@ -38,3 +38,23 @@ def test_particle_selection(): ds = data_dir_load(r1) psc = ParticleSelectionComparison(ds) psc.run_defaults() + + +@requires_file(r1) +def test_halo_loading(): + ds = data_dir_load(r1) + + for halo_id, Npart in zip( + ds.r["halos", "particle_identifier"], + ds.r["halos", "num_p"], + ): + halo = ds.halo("halos", halo_id) + assert halo is not None + + # Try accessing properties + halo.position + halo.velocity + halo.mass + + # Make sure we can access the member particles + assert_equal(len(halo.member_ids), Npart)