Skip to content

Commit

Permalink
Merge pull request #42 from ai2es/sreiner
Browse files Browse the repository at this point in the history
Bug fixes and improvements to plotting
  • Loading branch information
charlie-becker authored Aug 20, 2024
2 parents 216f522 + 3e3f179 commit a7deda8
Show file tree
Hide file tree
Showing 10 changed files with 11,068 additions and 342 deletions.
4 changes: 2 additions & 2 deletions applications/evaluate_mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def evaluate(conf, reevaluate=False, data_split=0, mc_forward_passes=0):
for name in data.keys():
x = scaled_data[f"{name}_x"]
if use_uncertainty:
pred_probs, u, ale, epi = mlp.predict_uncertainty(x)
pred_probs, u, ale, epi = mlp.predict(x, return_uncertainties=True, batch_size=10000)
pred_probs = pred_probs.numpy()
u = u.numpy()
ale = ale.numpy()
Expand Down Expand Up @@ -212,8 +212,8 @@ def evaluate(conf, reevaluate=False, data_split=0, mc_forward_passes=0):
for name in data.keys():
rocs = []
for i in range(len(output_features)):
forecasts = data[name]["pred_conf"]
obs = np.where(data[name]["true_label"] == i, 1, 0)
forecasts = data[name][f"pred_conf{i+1}"]
roc = DistributedROC(
thresholds=np.arange(0.0, 1.01, 0.01), obs_threshold=0.5
)
Expand Down
2 changes: 1 addition & 1 deletion applications/train_mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import pandas as pd
from ptype.callbacks import MetricsCallback
from ptype.data import load_ptype_uq, preprocess_data
from pytpe.trainer import trainer
from ptype.trainer import trainer
from mlguess.keras.callbacks import get_callbacks
from mlguess.keras.models import CategoricalDNN
from mlguess.pbs import launch_pbs_jobs
Expand Down
474 changes: 442 additions & 32 deletions notebooks/confusion_plots.ipynb

Large diffs are not rendered by default.

5,632 changes: 5,375 additions & 257 deletions notebooks/mping_visualization.ipynb

Large diffs are not rendered by default.

Loading

0 comments on commit a7deda8

Please sign in to comment.