Skip to content

Commit

Permalink
make process type an enum
Browse files Browse the repository at this point in the history
  • Loading branch information
TjarkMiener committed Sep 18, 2024
1 parent a99bd9b commit cf12cd0
Showing 1 changed file with 20 additions and 16 deletions.
36 changes: 20 additions & 16 deletions dl1_data_handler/reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

from abc import abstractmethod
from collections import OrderedDict
from enum import Enum
import numpy as np
import tables
import threading
Expand Down Expand Up @@ -45,6 +46,9 @@

lock = threading.Lock()

class ProcessType(Enum):
    """Kind of data processing that produced an input file.

    Each member's value is the exact string expected in the file's
    ``CTA PROCESS TYPE`` attribute, so the attribute can be converted with
    ``ProcessType(value)``. As with any ``Enum``, a string that matches no
    member raises ``ValueError``.
    """

    # Real observational data.
    Observation = "Observation"
    # Simulated data.
    Simulation = "Simulation"

class TableQualityQuery(QualityQuery):
"""Quality criteria for table-wise dl1b parameters."""
Expand Down Expand Up @@ -78,8 +82,8 @@ class DLDataReader(Component):
The first file in the list of input files, which is used as reference.
_v_attrs : dict
Attributes and useful information retrieved from the first file.
process_type : str
The type of data processing (i.e. ``Observation`` or ``Simulation``).
process_type : ProcessType
The type of data processing (i.e. ``ProcessType.Observation`` or ``ProcessType.Simulation``).
data_format_version : str
The version of the ctapipe data format.
instrument_id : str
Expand Down Expand Up @@ -229,7 +233,7 @@ def __init__(
self.first_file = list(self.files)[0]
# Save the user attributes and useful information retrieved from the first file as a reference
self._v_attrs = self.files[self.first_file].root._v_attrs
self.process_type = self._v_attrs["CTA PROCESS TYPE"]
self.process_type = ProcessType(self._v_attrs["CTA PROCESS TYPE"])
self.data_format_version = self._v_attrs["CTA PRODUCT DATA MODEL VERSION"]
self.instrument_id = self._v_attrs["CTA INSTRUMENT ID"]

Expand All @@ -239,7 +243,7 @@ def __init__(
f"Provided ctapipe data format version is '{self.data_format_version}' (must be >= v.6.0.0)."
)
# Check for real data processing that only a single file is provided.
if self.process_type == "Observation" and len(self.files) != 1:
if self.process_type == ProcessType.Observation and len(self.files) != 1:
raise ValueError(
f"When processing real observational data, please provide a single file (currently: '{len(self.files)}')."
)
Expand Down Expand Up @@ -318,7 +322,7 @@ def __init__(
# Telescope pointings
self.telescope_pointings = {}
self.tel_trigger_table = None
if self.process_type == "Observation":
if self.process_type == ProcessType.Observation:
for tel_id in self.tel_ids:
with lock:
self.telescope_pointings[f"tel_{tel_id:03d}"] = read_table(
Expand Down Expand Up @@ -354,7 +358,7 @@ def __init__(
# Scaling by total/2 helps keep the loss to a similar magnitude.
# The sum of the weights of all examples stays the same.
self.class_weight = None
if self.process_type == "Simulation":
if self.process_type == ProcessType.Simulation:
if self.bkg_input_files is not None:
self.class_weight = {
0: (1.0 / self.n_bkg_events) * (self._get_n_events() / 2.0),
Expand Down Expand Up @@ -385,15 +389,15 @@ def _construct_mono_example_identifiers(self):
# These are the basic columns one needs to do a
# conventional IACT analysis with CNNs
self.example_ids_keep_columns = ["table_index", "obs_id", "event_id", "tel_id"]
if self.process_type == "Simulation":
if self.process_type == ProcessType.Simulation:
self.example_ids_keep_columns.extend(
["true_energy", "true_alt", "true_az", "true_shower_primary_id"]
)

simulation_info = []
example_identifiers = []
for file_idx, (filename, f) in enumerate(self.files.items()):
if self.process_type == "Simulation":
if self.process_type == ProcessType.Simulation:
# Read simulation information for each observation
simulation_info.append(read_table(f, "/configuration/simulation/run"))
# Construct the shower simulation table
Expand All @@ -410,7 +414,7 @@ def _construct_mono_example_identifiers(self):
tel_table.add_column(
np.arange(len(tel_table)), name="table_index", index=0
)
if self.process_type == "Simulation":
if self.process_type == ProcessType.Simulation:
tel_table = join(
left=tel_table,
right=simshower_table,
Expand Down Expand Up @@ -455,7 +459,7 @@ def _construct_mono_example_identifiers(self):
# Construct the example identifiers for all files
self.example_identifiers = vstack(example_identifiers)
# Construct simulation information for all files
if self.process_type == "Simulation":
if self.process_type == ProcessType.Simulation:
self.simulation_info = vstack(simulation_info)
self.n_signal_events = np.count_nonzero(
self.example_identifiers["true_shower_primary_class"] == 1
Expand Down Expand Up @@ -491,24 +495,24 @@ def _construct_stereo_example_identifiers(self):
"tel_id",
"hillas_intensity",
]
if self.process_type == "Simulation":
if self.process_type == ProcessType.Simulation:
self.example_ids_keep_columns.extend(
["true_energy", "true_alt", "true_az", "true_shower_primary_id"]
)
elif self.process_type == "Observation":
elif self.process_type == ProcessType.Observation:
self.example_ids_keep_columns.extend(["time", "event_type"])

simulation_info = []
example_identifiers = []
for file_idx, (filename, f) in enumerate(self.files.items()):
if self.process_type == "Simulation":
if self.process_type == ProcessType.Simulation:
# Read simulation information for each observation
simulation_info.append(read_table(f, "/configuration/simulation/run"))
# Construct the shower simulation table
simshower_table = read_table(f, "/simulation/event/subarray/shower")
# Read the trigger table.
trigger_table = read_table(f, "/dl1/event/subarray/trigger")
if self.process_type == "Simulation":
if self.process_type == ProcessType.Simulation:
# The shower simulation table is joined with the subarray trigger table.
trigger_table = join(
left=trigger_table,
Expand Down Expand Up @@ -545,7 +549,7 @@ def _construct_stereo_example_identifiers(self):

table_per_type = table_per_type.group_by(["obs_id", "event_id"])
table_per_type.keep_columns(self.example_ids_keep_columns)
if self.process_type == "Simulation":
if self.process_type == ProcessType.Simulation:
tel_pointing = self._get_tel_pointing(f, self.tel_ids)
table_per_type = join(
left=table_per_type,
Expand Down Expand Up @@ -591,7 +595,7 @@ def _multiplicity_cut_subarray(table, key_colnames):
self.example_identifiers, keys=["obs_id", "event_id"]
)
# Construct simulation information for all files
if self.process_type == "Simulation":
if self.process_type == ProcessType.Simulation:
self.simulation_info = vstack(simulation_info)
self.n_signal_events = np.count_nonzero(
self.unique_example_identifiers["true_shower_primary_class"] == 1
Expand Down

0 comments on commit cf12cd0

Please sign in to comment.