diff --git a/dl1_data_handler/loader.py b/dl1_data_handler/loader.py index f21cc9c..3d9ac56 100644 --- a/dl1_data_handler/loader.py +++ b/dl1_data_handler/loader.py @@ -13,7 +13,9 @@ def __init__( tasks, batch_size=64, random_seed=0, + **kwargs, ): + super().__init__(**kwargs) "Initialization" self.DLDataReader = DLDataReader self.indices = indices @@ -82,4 +84,7 @@ def __getitem__(self, index): ), axis=1, ) + # Temp fix for supporting keras2 & keras3 + if int(keras.__version__.split(".")[0]) >= 3: + features = features["input"] return features, labels