Skip to content

Commit

Permalink
Update plots in dist_select
Browse files Browse the repository at this point in the history
  • Loading branch information
Alexander März committed Aug 24, 2023
1 parent 5a5e45f commit 87f50d9
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 43 deletions.
24 changes: 6 additions & 18 deletions docs/examples/SplineFlow_Regression.ipynb

Large diffs are not rendered by default.

36 changes: 11 additions & 25 deletions xgboostlss/distributions/flow_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@
from tqdm import tqdm

from typing import Any, Dict, Optional, List, Tuple
from plotnine import *
import matplotlib.pyplot as plt
import seaborn as sns
import warnings


Expand Down Expand Up @@ -623,7 +624,6 @@ def flow_select(self,
target: np.ndarray,
candidate_flows: List,
max_iter: int = 100,
n_samples: int = 1000,
plot: bool = False,
figure_size: tuple = (10, 5),
) -> pd.DataFrame:
Expand All @@ -639,8 +639,6 @@ def flow_select(self,
List of candidate normalizing flow specifications.
max_iter: int
Maximum number of iterations for the optimization.
n_samples: int
Number of samples drawn from the fitted distribution.
plot: bool
If True, a density plot of the actual and fitted distribution is created.
figure_size: tuple
Expand Down Expand Up @@ -699,29 +697,17 @@ def flow_select(self,
flow_params = torch.tensor(best_flow["params"][0]).reshape(1, -1)
flow_dist_sel = best_flow_sel.create_spline_flow(input_dim=1)
_, flow_dist_sel = best_flow_sel.replace_parameters(flow_params, flow_dist_sel)
flow_samples = pd.DataFrame(flow_dist_sel.sample((n_samples,)).squeeze().detach().numpy().T)
n_samples = np.max([10000, target.shape[0]])
n_samples = np.where(n_samples > 500000, 100000, n_samples)
flow_samples = pd.DataFrame(flow_dist_sel.sample((n_samples,)).squeeze().detach().numpy().T).values

# Plot actual and fitted distribution
flow_samples["type"] = f"Best-Fit: {best_flow['NormFlow'].values[0]}"

df_actual = pd.DataFrame(target)
df_actual["type"] = "Data"

plot_df = pd.concat([df_actual, flow_samples]).rename(columns={0: "variable"})

print(
ggplot(plot_df,
aes(x="variable",
color="type")) +
geom_density(size=1.1) +
theme_bw(base_size=15) +
theme(figure_size=figure_size,
legend_position="right",
legend_title=element_blank(),
plot_title=element_text(hjust=0.5)) +
labs(title=f"Actual vs. Fitted Density",
x="")
)
plt.figure(figsize=figure_size)
sns.kdeplot(target.reshape(-1, ), label="Actual")
sns.kdeplot(flow_samples.reshape(-1, ), label=f"Best-Fit: {best_flow['NormFlow'].values[0]}")
plt.legend()
plt.title("Actual vs. Best-Fit Density")
plt.show()

fit_df.drop(columns=["rank", "params"], inplace=True)

Expand Down

0 comments on commit 87f50d9

Please sign in to comment.