Skip to content
This repository has been archived by the owner on Jan 26, 2024. It is now read-only.

Commit

Permalink
a fix for the output files of prediction
Browse files Browse the repository at this point in the history
  • Loading branch information
cbaakman committed Dec 21, 2023
1 parent 10e0500 commit 7a25f6b
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 4 deletions.
5 changes: 3 additions & 2 deletions deeprank/learn/NeuralNet.py
Original file line number Diff line number Diff line change
Expand Up @@ -736,9 +736,10 @@ def _epoch(self, epoch_number, pass_name, data_loader, train_model):

if targets is not None:
target_values += targets.tolist()
else:
target_values += [-1] * outputs.shape[0]

if len(target_values) > 0:
self._metrics_output.process(pass_name, epoch_number, entry_names, output_values, target_values)
self._metrics_output.process(pass_name, epoch_number, entry_names, output_values, target_values)

if count_data_entries > 0:
epoch_loss = sum_of_losses / count_data_entries
Expand Down
7 changes: 5 additions & 2 deletions test/test_learn.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,11 +68,14 @@ def test_predict():

metrics_directory = os.path.join(work_dir_path, "runs")

output_exporter = OutputExporter(metrics_directory)

neural_net = NeuralNet(dataset, cnn_class, model_type='3d',task='class',
pretrained_model="test/data/models/best_valid_model.pth.tar",
cuda=False, metrics_exporters=[OutputExporter(metrics_directory),
TensorboardBinaryClassificationExporter(metrics_directory)])
cuda=False, metrics_exporters=[output_exporter])
neural_net.test()

assert os.path.isfile(output_exporter.get_filename("test", 0))
finally:
rmtree(work_dir_path)

Expand Down

0 comments on commit 7a25f6b

Please sign in to comment.