diff --git a/deeprank/generate/NormalizeData.py b/deeprank/generate/NormalizeData.py index 1586f0a..9cc2ea6 100644 --- a/deeprank/generate/NormalizeData.py +++ b/deeprank/generate/NormalizeData.py @@ -127,22 +127,25 @@ def _extract_data(self): self.parameters['features'][feat_types][name].add( np.mean(mat), np.var(mat)) - # get the target groups - target_group = f5.get(mol + '/targets') + mol_group = f5[mol] + if "targets" in mol_group: - # loop over all the targets - for tname, tval in target_group.items(): + # get the target groups + target_group = mol_group["targets"] - # we skip the already computed target - if tname in self.skip_target: - continue + # loop over all the targets + for tname, tval in target_group.items(): - # create a new item if needed - if tname not in self.parameters['targets']: - self.parameters['targets'][tname] = MinMaxParam() + # we skip the already computed target + if tname in self.skip_target: + continue + + # create a new item if needed + if tname not in self.parameters['targets']: + self.parameters['targets'][tname] = MinMaxParam() - # update the value - self.parameters['targets'][tname].update(tval[()]) + # update the value + self.parameters['targets'][tname].update(tval[()]) f5.close() diff --git a/deeprank/learn/DataSet.py b/deeprank/learn/DataSet.py index 74bdb64..ed2bd98 100644 --- a/deeprank/learn/DataSet.py +++ b/deeprank/learn/DataSet.py @@ -761,7 +761,10 @@ def _read_norm(self): mean, var) # handle the target - if self.select_target is not None: + if self.select_target is not None and \ + "targets" in data and \ + self.select_target in data["targets"]: + minv = data['targets'][self.select_target].min maxv = data['targets'][self.select_target].max self.param_norm['targets'][self.select_target].update(minv) diff --git a/test/test_learn.py b/test/test_learn.py index af75f8f..2c4f292 100644 --- a/test/test_learn.py +++ b/test/test_learn.py @@ -31,19 +31,19 @@ def test_predict_without_target(): 'number_of_points': [30,30,30], 'resolution': [1.,1.,1.], 'atomic_densities': atomic_densities, - } + } environment = Environment(pdb_root="test/data/pdb", pssm_root="test/data/pssm") variants = [PdbVariantSelection("101m", "A", 10, valine, cysteine, protein_accession="P02144", protein_residue_number=10, - variant_class=VariantClass.BENIGN), + ), PdbVariantSelection("101m", "A", 8, glutamine, cysteine, protein_accession="P02144", - variant_class=VariantClass.PATHOGENIC), + ), PdbVariantSelection("101m", "A", 9, glutamine, cysteine, protein_accession="P02144", protein_residue_number=9, - variant_class=VariantClass.PATHOGENIC)] + )] augmentation = 5 work_dir_path = mkdtemp() @@ -64,8 +64,7 @@ def test_predict_without_target(): dataset = DataSet(hdf5_path, grid_info=grid_info, select_feature='all', - select_target='target1', - normalize_features=False) + normalize_features=True) eq_(len(dataset), len(variants) * (augmentation + 1)) ok_(dataset[0] is not None)