Skip to content

Commit

Permalink
Expanded RingCounting functionality (#258)
Browse files Browse the repository at this point in the history
- Add RC data flow from boost stores, not requiring a pre-computed CNNImage file anymore
- Add "load_from_file" config parameter. Config setting: 0/1
- Add Execute() method in Tool class
- Add get_next_event() method in Tool class
- Change functionality of other methods to support both "load_from_file" settings
- The Tool expects a CNNImage formatted input (std::vector<double>) called "CNNImageCharge" within the "RecoEvent" boost store when not using the input from file
- The predictions are stored in the "RecoEvent" boost store as "RingCountingSRPrediction" and "RingCountingMRPrediction"

Ringcounting/README.md:
- Specify where the predictions are stored (RecoEvent)
- Complete list of config parameters
  • Loading branch information
s4294967296 authored Apr 5, 2024
1 parent 036d1ad commit d348aa0
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 33 deletions.
21 changes: 10 additions & 11 deletions UserTools/RingCounting/README.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# RingCounting

The RingCounting Tool is used to classify events into single- and multi-Cherenkov-ring-events.
For this a machine learning approach is used.
For this a CNN-based machine learning approach is used.
---
This Tool uses PMT data in the CNNImage format (see UserTools/CNNImage). For further details on the tool and
the models used (including performance etc.) see the documentation found on the anniegpvm-machines at (**Todo**)
Expand All @@ -14,13 +14,11 @@ All models can be found at
```
/pnfs/annie/persistent/users/dschmid/RingCountingStore/models/
```
and (some) at
```
/annie/app/users/dschmid/RingCountingStore/models/
```

## Data

- Currently does not add predictions to any BoostStore. This is planned in a future update (**Todo**)
- In the "load_from_file" mode, this tool adds single- and multi-ring (SR/MR) predictions to the RecoEvent BoostStore. When theh "load_from_file" config parameter is set to 0, the tool instead outputs the predictions to a csv file.
- The predictions are stored in the "RingCountingSRPrediction" and "RingCountingMRPrediction" variables

---
## Configuration
Expand All @@ -30,9 +28,10 @@ are found at the top of the RingCounting.py file.
---
Exemplary configuration:
```
files_to_load configfiles/RingCounting/files_to_load.txt
version 1_0_0
model_path /annie/app/users/dschmid/RingCountingStore/models/
pmt_mask november_22
save_to RC_output.csv
load_from_file 0 # If set to 1, load CNNImage formatted csv files
files_to_load configfiles/RingCounting/files_to_load.txt # txt file containing files to load in case load_from_file == 1
version 1_0_0 # Model version
model_path /annie/app/users/dschmid/RingCountingStore/models/ # Model path
pmt_mask november_22 # PMT mask (zeroed out)
save_to RC_output.csv # if load_from_file == 1, save predictions as csv
```
89 changes: 67 additions & 22 deletions UserTools/RingCounting/RingCounting.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,14 +72,16 @@ class RingCounting(Tool, RingCountingGlobals):

# ----------------------------------------------------------------------------------------------------
# Config stuff
files_to_load = std.string() # List of files to be loaded (must be in CNNImage format)
load_from_file = std.string() # if 1, load a CNNImage formatted csv file instead of using the tool chain
files_to_load = std.string() # List of files to be loaded (must be in CNNImage format,
# load_from_file has to be true)
version = std.string() # Model version
model_path = std.string() # Path to model directory
pmt_mask = std.string() # See RingCountingGlobals
save_to = std.string() # Where to save the predictions to

# ----------------------------------------------------------------------------------------------------
# Model stuff
# Model variables
model = None
predicted = None

Expand All @@ -100,6 +102,8 @@ def Initialise(self):
# Config area
self.m_variables.Get("files_to_load", self.files_to_load)
self.files_to_load = str(self.files_to_load) # cast to str since std.string =/= str
self.m_variables.Get("load_from_file", self.load_from_file)
self.load_from_file = "1" == str(self.load_from_file)
self.m_variables.Get("version", self.version)
self.m_variables.Get("model_path", self.model_path)
self.m_variables.Get("pmt_mask", self.pmt_mask)
Expand All @@ -109,15 +113,45 @@ def Initialise(self):

# ----------------------------------------------------------------------------------------------------
# Loading data
self.load_data()
self.mask_pmts()
if self.load_from_file:
self.load_data()
else:
self.m_log.Log(__file__ + " Not loading data from csv file.", self.v_message, self.m_verbosity)
self.cnn_image_pmt = np.array([])

# ----------------------------------------------------------------------------------------------------
# Loading model
self.load_model()

return 1

def Execute(self):
""" Execute the tool by generating model predictions on the supplied data. """
self.m_log.Log(__file__ + " Executing", self.v_debug, self.m_verbosity)

self.get_next_event()
self.mask_pmts()
self.predict()

if not self.load_from_file:
predicted_sr = float(self.predicted[0][1])
predicted_mr = float(self.predicted[0][0])

reco_event_bs = self.m_data.Stores.at("RecoEvent")

reco_event_bs.Set("RingCountingSRPrediction", predicted_sr)
reco_event_bs.Set("RingCountingMRPrediction", predicted_mr)

return 1

def Finalise(self):
""" Finalise the tool by saving the predictions. """
self.m_log.Log(__file__ + " Finalising", self.v_debug, self.m_verbosity)
if self.load_from_file:
self.save_data()

return 1

def load_data(self):
""" Load data in the CNNImage format.
Expand Down Expand Up @@ -168,35 +202,46 @@ def mask_pmts(self):
""" Mask PMTs to 0. The PMTs to be masked is given as a list of indices, defined by setting [[pmt_mask]].
For further details check the RingCountingGlobals class.
"""
for event in self.cnn_image_pmt:
np.put(event, self.pmt_mask, 0, mode='raise')
if self.load_from_file:
for event in self.cnn_image_pmt:
np.put(event, self.pmt_mask, 0, mode='raise')
else:
np.put(self.cnn_image_pmt, self.pmt_mask, 0, mode='raise')

def load_model(self):
""" Load the specified model [[version]]."""
self.model = tf.keras.models.load_model(self.model_path + f"RC_model_v{self.version}.model")

def get_next_event(self):
""" Get the next event from the BoostStore. """
if self.load_from_file:
return

reco_event_bs = self.m_data.Stores.at("RecoEvent")
get_ok = reco_event_bs.Has("CNNImageCharge")
self.cnn_image_pmt = std.vector['double'](range(160))

if get_ok:
reco_event_bs.Get("CNNImageCharge", self.cnn_image_pmt)
else:
self.m_log.Log(__file__ + " ERROR: CNNImageCharge not present in RecoEvent boost store.",
self.v_error, self.m_verbosity)

# loop over std::vector to convert to list
self.cnn_image_pmt = [x for x in self.cnn_image_pmt]
# Explicitly adding the extra dimension using np.array([]) is done to ensure the data is properly
# reshaped into the shape (-1, 160).
self.cnn_image_pmt = np.array([self.cnn_image_pmt])
self.cnn_image_pmt = np.reshape(self.cnn_image_pmt, (-1, 160))

def predict(self):
"""
Classify events in single- and multi-ring events using a keras model. Save a list of 2-dimensional predictions
(same order as input) to self.predicted. Predictions are given as [MR prediction, SR prediction].
"""
print("Predicting...")
print(self.cnn_image_pmt)
self.predicted = self.model.predict(np.reshape(self.cnn_image_pmt, newshape=(-1, 10, 16, 1)))

def Execute(self):
""" Execute the tool by generating model predictions on the supplied data. """
self.m_log.Log(__file__ + " Executing", self.v_debug, self.m_verbosity)
self.predict()

return 1

def Finalise(self):
""" Finalise the tool by saving the predictions. """
self.m_log.Log(__file__ + " Finalising", self.v_debug, self.m_verbosity)
self.save_data()

return 1
self.m_log.Log(__file__ + " PREDICTING", self.v_message, self.m_verbosity)
self.predicted = self.model.predict(np.reshape(self.cnn_image_pmt, newshape=(-1, 10, 16, 1)))


###################
Expand Down

0 comments on commit d348aa0

Please sign in to comment.