Skip to content

Commit

Permalink
[Prediction] Fixed prediction tf
Browse files Browse the repository at this point in the history
  • Loading branch information
YanSte committed Aug 6, 2023
1 parent 1adef35 commit c4cbf96
Showing 1 changed file with 16 additions and 13 deletions.
29 changes: 16 additions & 13 deletions src/skit/prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,23 +32,30 @@ def tf_predictions(dataset, model, num_take='all', labels='default', verbosity=0
- y_test (list): The actual labels of the test data.
- y_pred (list): The predicted labels.
"""
x_test = []
y_test = []
y_pred = []
if not isinstance(dataset, tf.data.Dataset):
raise ValueError("The provided dataset is not an instance of tf.data.Dataset.")

# Labels
# ----
if labels == "default":
labels = dataset.class_names
try:
labels = dataset.class_names
except AttributeError:
raise AttributeError("The dataset does not have an attribute 'class_names'. Please provide explicit labels.")

# Num take
# ----
if num_take != 'all':
dataset_size = dataset.cardinality().numpy()
if num_take > dataset_size:
raise Exception(f"The num_take is bigger than the dataset size: {dataset_size}.")
else:
dataset = dataset.take(num_take)
raise ValueError(f"The value of num_take ({num_take}) exceeds the dataset size: {dataset_size}.")
dataset = dataset.take(num_take)

# Setup returns
# ----
x_test = []
y_test = []
y_pred = []

for images, true_labels in dataset:
# Get feature
Expand All @@ -57,18 +64,14 @@ def tf_predictions(dataset, model, num_take='all', labels='default', verbosity=0

# Get true label
# ----
true_indices = np.argmax(true_labels.numpy(), axis=1)
true_labels_list = [labels[idx] for idx in true_indices]
y_test.extend(true_labels_list)
y_test.extend(labels[idx] for idx in np.argmax(true_labels.numpy(), axis=1))

# Predict
# ----
predictions = model.predict(images, verbose=verbosity)

# Get predicted labels
# ----
predicted_indices = np.argmax(predictions, axis=1)
predicted_labels = [labels[idx] for idx in predicted_indices]
y_pred.extend(predicted_labels)
y_pred.extend(labels[idx] for idx in np.argmax(predictions, axis=1))

return x_test, y_test, y_pred

0 comments on commit c4cbf96

Please sign in to comment.