Skip to content

Commit

Permalink
tests
Browse files Browse the repository at this point in the history
  • Loading branch information
JoschD committed Aug 20, 2024
1 parent b4a5f64 commit b3848b2
Show file tree
Hide file tree
Showing 3 changed files with 107 additions and 8 deletions.
98 changes: 98 additions & 0 deletions tests/test_doros.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@

from datetime import datetime
from pathlib import Path

import numpy as np
import pandas as pd
import pytest
import h5py

from turn_by_turn.constants import PRINT_PRECISION
from turn_by_turn.errors import DataTypeError
from turn_by_turn.structures import TbtData, TransverseData
from tests.test_lhc_and_general import create_data, compare_tbt

from turn_by_turn.doros import N_ORBIT_SAMPLES, read_tbt, write_tbt, DEFAULT_BUNCH_ID, POSITIONS

INPUTS_DIR = Path(__file__).parent / "inputs"


def test_write_read(tmp_path):
tbt = _tbt_data()
file_path = tmp_path / "test_file.h5"
write_tbt(tbt, file_path)
new = read_tbt(file_path)
compare_tbt(tbt, new, no_binary=False)


def test_read_raises_different_bpm_lengths(tmp_path):
tbt = _tbt_data()
file_path = tmp_path / "test_file.h5"
write_tbt(tbt, file_path)

bpm = tbt.matrices[0].X.index[0]

# modify the BPM lengths in the file
with h5py.File(file_path, "r+") as h5f:
delta = 10
del h5f[bpm][N_ORBIT_SAMPLES]
h5f[bpm][N_ORBIT_SAMPLES] = [tbt.matrices[0].X.shape[1] - delta]
for key in POSITIONS.values():
data = h5f[bpm][key][:-delta]
del h5f[bpm][key]
h5f[bpm][key] = data

with pytest.raises(ValueError) as e:
read_tbt(file_path)
assert "Not all BPMs have the same number of turns!" in str(e)


def test_read_raises_on_different_bpm_lengths_in_data(tmp_path):
tbt = _tbt_data()
file_path = tmp_path / "test_file.h5"
write_tbt(tbt, file_path)

bpms = [tbt.matrices[0].X.index[i] for i in (0, 2)]

# modify the BPM lengths in the file
with h5py.File(file_path, "r+") as h5f:
for bpm in bpms:
del h5f[bpm][N_ORBIT_SAMPLES]
h5f[bpm][N_ORBIT_SAMPLES] = [tbt.matrices[0].X.shape[1] + 10]

with pytest.raises(ValueError) as e:
read_tbt(file_path)
assert "Found BPMs with different data lengths" in str(e)
assert all(bpm in str(e) for bpm in bpms)


def _tbt_data() -> TbtData:
"""TbT data for testing. Adding random noise, so that the data is different per BPM."""
nturns = 2000
bpms = ["TBPM1", "TBPM2", "TBPM3", "TBPM4"]

return TbtData(
matrices=[
TransverseData(
X=pd.DataFrame(
index=bpms,
data=create_data(
np.linspace(-np.pi, np.pi, nturns, endpoint=False),
nbpm=len(bpms), function=np.sin, noise=0.02
),
dtype=float,
),
Y=pd.DataFrame(
index=bpms,
data=create_data(
np.linspace(-np.pi, np.pi, nturns, endpoint=False),
nbpm=len(bpms), function=np.cos, noise=0.015
),
dtype=float,
),
)
],
date=datetime.now(),
bunch_ids=[DEFAULT_BUNCH_ID],
nturns=nturns,
)
4 changes: 2 additions & 2 deletions tests/test_lhc_and_general.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,8 @@ def compare_tbt(origin: TbtData, new: TbtData, no_binary: bool, max_deviation =
assert np.all(origin_mat == new_mat)


def create_data(phases, nbpm, function) -> np.ndarray:
return np.ones((nbpm, len(phases))) * function(phases)
def create_data(phases, nbpm, function, noise: float = 0) -> np.ndarray:
return np.ones((nbpm, len(phases))) * function(phases) + noise * np.random.randn(nbpm, len(phases))


@pytest.fixture()
Expand Down
13 changes: 7 additions & 6 deletions turn_by_turn/doros.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,8 @@ def read_tbt(file_path: str|Path, bunch_id: int = DEFAULT_BUNCH_ID) -> TbtData:
file_path = Path(file_path)
LOGGER.debug(f"Reading DOROS file at path: '{file_path.absolute()}'")
with h5py.File(file_path, "r") as hdf_file:
bpm_names = [name for name in hdf_file.keys() if name != METADATA]
# use "/" to keep track of bpm order, see https://github.com/h5py/h5py/issues/1471
bpm_names = [name for name in hdf_file["/"].keys() if name != METADATA]

_check_data_lengths(hdf_file, bpm_names)

Expand Down Expand Up @@ -117,16 +118,16 @@ def write_tbt(tbt_data: TbtData, file_path: str|Path) -> None:
LOGGER.debug(f"Writing DOROS file at path: '{file_path.absolute()}'")

data = tbt_data.matrices[0]
with h5py.File(file_path, "w") as hdf_file:
with h5py.File(file_path, "w", track_order=True) as hdf_file:
hdf_file.create_group(METADATA)
for bpm in tbt_data.matrices[0].X.index:
for bpm in data.X.index:
hdf_file.create_group(bpm)
hdf_file[bpm].create_dataset(ACQ_STAMP, data=[tbt_data.date.timestamp() * 1e6])
hdf_file[bpm].create_dataset(BST_TIMESTAMP, data=[tbt_data.date.timestamp() * 1e6])

hdf_file[bpm].create_dataset(N_ORBIT_SAMPLES, data=[tbt_data.nturns])
hdf_file[bpm].create_dataset(POSITIONS["X"], data=data.X[bpm].values)
hdf_file[bpm].create_dataset(POSITIONS["Y"], data=data.Y[bpm].values)
hdf_file[bpm].create_dataset(POSITIONS["X"], data=data.X.loc[bpm, :].values)
hdf_file[bpm].create_dataset(POSITIONS["Y"], data=data.Y.loc[bpm, :].values)

hdf_file[bpm].create_dataset(N_OSCILLATION_SAMPLES, data=0)
hdf_file[bpm].create_dataset(OSCILLATIONS["X"], data=[DEFAULT_OSCILLATION_DATA])
Expand All @@ -150,6 +151,6 @@ def _check_data_lengths(hdf_file: h5py.File, bpm_names: str) -> None:
msg = f"Found BPMs with different data lengths than defined in '{N_ORBIT_SAMPLES}': {suspicious_bpms}"
raise ValueError(msg)

if not all_elements_equal(hdf_file[bpm][N_ORBIT_SAMPLES] for bpm in bpm_names):
if not all_elements_equal(hdf_file[bpm][N_ORBIT_SAMPLES][0] for bpm in bpm_names):
msg = "Not all BPMs have the same number of turns!"
raise ValueError(msg)

0 comments on commit b3848b2

Please sign in to comment.