Skip to content

Commit

Permalink
temp fix keras/TF bug
Browse files Browse the repository at this point in the history
  • Loading branch information
TjarkMiener committed Oct 17, 2024
1 parent d92c906 commit 6d37efa
Showing 1 changed file with 13 additions and 5 deletions.
18 changes: 13 additions & 5 deletions dl1_data_handler/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,21 +58,29 @@ def __getitem__(self, index):
batch_indices=batch_indices,
)
# Generate the labels for each task
label = {}
labels = {}
if "type" in self.tasks:
label["type"] = to_categorical(
labels["type"] = to_categorical(
batch["true_shower_primary_class"].data,
num_classes=2,
)
label = to_categorical(
batch["true_shower_primary_class"].data,
num_classes=2,
)
if "energy" in self.tasks:
label["energy"] = batch["log_true_energy"].data
labels["energy"] = batch["log_true_energy"].data
if "direction" in self.tasks:
label["direction"] = np.stack(
labels["direction"] = np.stack(
(
batch["spherical_offset_az"].data,
batch["spherical_offset_alt"].data,
batch["angular_separation"].data,
),
axis=1,
)
return features, label
# Temp fix till keras support class weights for multiple outputs or I wrote custom loss
# https://github.com/keras-team/keras/issues/11735
if len(labels) == 1 and labels[0] == "type":
labels = label
return features, labels

0 comments on commit 6d37efa

Please sign in to comment.