diff --git a/deeprvat/deeprvat/train.py b/deeprvat/deeprvat/train.py index fd047868..c2e2abd4 100644 --- a/deeprvat/deeprvat/train.py +++ b/deeprvat/deeprvat/train.py @@ -285,7 +285,7 @@ def __getitem__(self, index): start_idx = index * self.batch_size end_idx = min(self.total_samples, start_idx + self.batch_size) batch_samples = self.sample_order.iloc[start_idx:end_idx] - samples_by_pheno = batch_samples.groupby("phenotype") + samples_by_pheno = batch_samples.groupby("phenotype", observed=True) result = dict() for pheno, df in samples_by_pheno: