Skip to content

Commit

Permalink
benchmark_lss example
Browse files Browse the repository at this point in the history
  • Loading branch information
AnFreTh committed Jul 18, 2024
1 parent bcce1da commit 2f39f11
Showing 1 changed file with 70 additions and 0 deletions.
70 changes: 70 additions & 0 deletions examples/benchmark_lss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
from mambular.models import MambularLSS
from sklearn.preprocessing import StandardScaler
import pandas as pd
import numpy as np
import properscoring as ps
from sklearn.model_selection import KFold

datasets = ["regression_datasets/abalone.csv", "regression_datasets/ca_housing.csv"]

crps = lambda y, pred: np.mean(
[
ps.crps_gaussian(y[i], mu=pred[i, 0], sig=np.sqrt(pred[i, 1]))
for i in range(len(y))
]
)

kf = KFold(n_splits=2, shuffle=True, random_state=42)


# Function to compute NLL
def compute_nll(y, pred):
means = pred[:, 0]
variances = pred[:, 1]
nll = 0.5 * (np.log(2 * np.pi * variances) + ((y - means) ** 2) / variances)
return np.mean(nll)


results = []

for dataset_name in datasets:
data = pd.read_csv(dataset_name)
data = data.dropna().reset_index(drop=True)
y_data = data.pop("Targets")
scaler = StandardScaler()
y_data = scaler.fit_transform(y_data.values.reshape(-1, 1)).squeeze(-1)

crps_vals = []
nll_vals = []
mse_vals = []

for fold, (train_index, val_index) in enumerate(kf.split(data)):
X_train, X_test = data.iloc[train_index], data.iloc[val_index]
y_train, y_test = y_data[train_index], y_data[val_index]

model = MambularLSS()

model.fit(X_train, y_train, family="normal", max_epochs=200, lr=5e-04)

print(model.evaluate(X_test, y_test))

predictions = model.predict(X_test)

crps_vals.append(crps(y_test, predictions))
nll_vals.append(compute_nll(y_test, predictions))
mse_vals.append(model.evaluate(X_test, y_test)["MSE"])

results.append(
{
"Dataset": dataset_name,
"Mean CRPS": np.mean(crps_vals),
"Std CRPS": np.std(crps_vals),
"Mean NLL": np.mean(nll_vals),
"Std NLL": np.std(nll_vals),
"MSE": np.mean(mse_vals),
}
)

results_df = pd.DataFrame(results)

print(results_df)

0 comments on commit 2f39f11

Please sign in to comment.