Skip to content

Commit

Permalink
Updates related to an outdated model (switch saving to the `.keras` format and update `predict` calls)
Browse files Browse the repository at this point in the history
  • Loading branch information
sophiaanr committed Aug 2, 2024
1 parent 1fc7583 commit 26f5027
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions ptype/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def trainer(conf, evaluate=True, data_split=0, mc_forward_passes=0):
mlp = CategoricalDNN(**conf["model"], callbacks=callbacks)
# train the model
print(scaled_data["train_y"])
raise
#raise
history = mlp.fit(scaled_data["train_x"], scaled_data["train_y"])

if conf["ensemble"]["n_splits"] > 1:
Expand All @@ -111,21 +111,21 @@ def trainer(conf, evaluate=True, data_split=0, mc_forward_passes=0):
if evaluate:
# Save the best model when not using ECHO
if conf["ensemble"]["n_splits"] == 1:
mlp.model.save(os.path.join(conf["save_loc"], "models", "best.h5"))
mlp.save(os.path.join(conf["save_loc"], "models", "best.keras"))
else:
mlp.model.save(
os.path.join(conf["save_loc"], "models", f"model_{data_split}.h5")
mlp.save(
os.path.join(conf["save_loc"], "models", f"model_{data_split}.keras")
)
for name in data.keys():
x = scaled_data[f"{name}_x"]
if use_uncertainty:
pred_probs, u, ale, epi = mlp.predict_uncertainty(x)
pred_probs, u, ale, epi = mlp.predict(x, return_uncertainties=True, batch_size=10000)
pred_probs = pred_probs.numpy()
u = u.numpy()
ale = ale.numpy()
epi = epi.numpy()
elif mc_forward_passes > 0: # Compute epistemic uncertainty with MC dropout
pred_probs = mlp.predict(x)
pred_probs = mlp.predict(x, return_uncertainties=False)
_, ale, epi, entropy, mutual_info = mlp.predict_monte_carlo(
x, mc_forward_passes=mc_forward_passes
)
Expand Down

0 comments on commit 26f5027

Please sign in to comment.