diff --git a/xgboostlss/utils.py b/xgboostlss/utils.py
index 021802fe..a2b41376 100644
--- a/xgboostlss/utils.py
+++ b/xgboostlss/utils.py
@@ -1,9 +1,9 @@
 import torch
+from torch.nn.functional import softplus, gumbel_softmax, softmax
 
-
-def identity_fn(predt: torch.tensor) -> torch.tensor:
+def nan_to_num(predt: torch.tensor) -> torch.tensor:
     """
-    Identity mapping of predt.
+    Replace nan, inf and -inf with the mean of predt.
 
     Arguments
     ---------
@@ -15,12 +15,18 @@ def identity_fn(predt: torch.tensor) -> torch.tensor:
     predt: torch.tensor
         Predicted values.
     """
+    predt = torch.nan_to_num(predt,
+                             nan=float(torch.nanmean(predt)),
+                             posinf=float(torch.nanmean(predt)),
+                             neginf=float(torch.nanmean(predt))
+                             )
+
     return predt
 
 
-def exp_fn(predt: torch.tensor) -> torch.tensor:
+def identity_fn(predt: torch.tensor) -> torch.tensor:
     """
-    Exponential function used to ensure predt is strictly positive.
+    Identity mapping of predt.
 
     Arguments
     ---------
@@ -32,15 +38,14 @@ def exp_fn(predt: torch.tensor) -> torch.tensor:
     predt: torch.tensor
         Predicted values.
     """
-    predt = torch.exp(predt)
-    predt = torch.nan_to_num(predt, nan=float(torch.nanmean(predt))) + torch.tensor(1e-06, dtype=predt.dtype)
+    predt = nan_to_num(predt) + torch.tensor(0, dtype=predt.dtype)
 
     return predt
 
 
-def exp_fn_df(predt: torch.tensor) -> torch.tensor:
+def exp_fn(predt: torch.tensor) -> torch.tensor:
     """
-    Exponential function used for Student-T distribution.
+    Exponential function used to ensure predt is strictly positive.
 
     Arguments
     ---------
@@ -52,15 +57,14 @@ def exp_fn_df(predt: torch.tensor) -> torch.tensor:
     predt: torch.tensor
         Predicted values.
     """
-    predt = torch.exp(predt)
-    predt = torch.nan_to_num(predt, nan=float(torch.nanmean(predt))) + torch.tensor(1e-06, dtype=predt.dtype)
+    predt = torch.exp(nan_to_num(predt)) + torch.tensor(1e-06, dtype=predt.dtype)
 
-    return predt + torch.tensor(2.0, dtype=predt.dtype)
+    return predt
 
 
-def log_fn(predt: torch.tensor) -> torch.tensor:
+def exp_fn_df(predt: torch.tensor) -> torch.tensor:
     """
-    Inverse of exp_fn function.
+    Exponential function used for Student-T distribution.
 
     Arguments
     ---------
@@ -72,10 +76,9 @@ def log_fn(predt: torch.tensor) -> torch.tensor:
     predt: torch.tensor
         Predicted values.
     """
-    predt = torch.log(predt)
-    predt = torch.nan_to_num(predt, nan=float(torch.nanmean(predt))) + float(1e-06)
+    predt = torch.exp(nan_to_num(predt)) + torch.tensor(1e-06, dtype=predt.dtype)
 
-    return predt
+    return predt + torch.tensor(2.0, dtype=predt.dtype)
 
 
 def softplus_fn(predt: torch.tensor) -> torch.tensor:
@@ -92,9 +95,7 @@ def softplus_fn(predt: torch.tensor) -> torch.tensor:
     predt: torch.tensor
         Predicted values.
     """
-    predt = torch.log1p(torch.exp(-torch.abs(predt))) + torch.maximum(predt, torch.tensor(0.))
-    predt[predt == 0] = torch.tensor(1e-06, dtype=predt.dtype)
-    predt = torch.nan_to_num(predt, nan=float(torch.nanmean(predt))) + torch.tensor(1e-06, dtype=predt.dtype)
+    predt = softplus(nan_to_num(predt)) + torch.tensor(1e-06, dtype=predt.dtype)
 
     return predt
 
@@ -113,9 +114,7 @@ def softplus_fn_df(predt: torch.tensor) -> torch.tensor:
     predt: torch.tensor
         Predicted values.
     """
-    predt = torch.log1p(torch.exp(-torch.abs(predt))) + torch.maximum(predt, torch.tensor(0.))
-    predt[predt == 0] = torch.tensor(1e-06, dtype=predt.dtype)
-    predt = torch.nan_to_num(predt, nan=float(torch.nanmean(predt))) + torch.tensor(1e-06, dtype=predt.dtype)
+    predt = softplus(nan_to_num(predt)) + torch.tensor(1e-06, dtype=predt.dtype)
 
     return predt + torch.tensor(2.0, dtype=predt.dtype)
 
@@ -134,8 +133,7 @@ def sigmoid_fn(predt: torch.tensor) -> torch.tensor:
     predt: torch.tensor
         Predicted values.
     """
-    predt = torch.sigmoid(predt)
-    predt = torch.nan_to_num(predt, nan=float(torch.nanmean(predt))) + torch.tensor(1e-06, dtype=predt.dtype)
+    predt = torch.sigmoid(nan_to_num(predt)) + torch.tensor(1e-06, dtype=predt.dtype)
     predt = torch.clamp(predt, 1e-03, 1-1e-03)
 
     return predt
@@ -155,7 +153,69 @@ def relu_fn(predt: torch.tensor) -> torch.tensor:
     predt: torch.tensor
         Predicted values.
     """
-    predt = torch.relu(predt)
-    predt = torch.nan_to_num(predt, nan=float(torch.nanmean(predt))) + torch.tensor(1e-6, dtype=predt.dtype)
+    predt = torch.relu(nan_to_num(predt)) + torch.tensor(1e-06, dtype=predt.dtype)
+
+    return predt
+
+
+def softmax_fn(predt: torch.tensor) -> torch.tensor:
+    """
+    Softmax function used to ensure predt sums to one.
+
+
+    Arguments
+    ---------
+    predt: torch.tensor
+        Predicted values.
+
+    Returns
+    -------
+    predt: torch.tensor
+        Predicted values.
+    """
+    predt = softmax(nan_to_num(predt), dim=1) + torch.tensor(0, dtype=predt.dtype)
+
+    return predt
+
+
+def gumbel_softmax_fn(predt: torch.tensor,
+                      tau: float = 1.0
+                      ) -> torch.tensor:
+    """
+    Gumbel-softmax function used to ensure predt sums to one.
+
+    The Gumbel-softmax distribution is a continuous distribution over the simplex, which can be thought of as a "soft"
+    version of a categorical distribution. It's a way to draw samples from a categorical distribution in a
+    differentiable way. The motivation behind using the Gumbel-Softmax is to make the discrete sampling process of
+    categorical variables differentiable, which is useful in gradient-based optimization problems. To sample from a
+    Gumbel-Softmax distribution, one would use the Gumbel-max trick: add Gumbel noise to the logits and apply the
+    softmax. Formally, given a vector z, the Gumbel-softmax function s(z, \tau)_i for a component i at temperature
+    \tau is defined as:
+
+        s(z, \tau)_i = \frac{e^{(z_i + g_i) / \tau}}{\sum_{j=1}^M e^{(z_j + g_j) / \tau}}
+
+    where g_i is a sample from the Gumbel(0, 1) distribution. The parameter \tau (temperature) controls the sharpness
+    of the output distribution. As \tau \to 0, the output becomes more discrete, and as \tau \to \infty, the output
+    becomes more uniform. For more information we refer to
+
+        Jang, E., Gu, Shixiang and Poole, B. "Categorical Reparameterization with Gumbel-Softmax", ICLR, 2017.
+
+
+    Arguments
+    ---------
+    predt: torch.tensor
+        Predicted values.
+    tau: float, non-negative scalar temperature.
+        Temperature parameter for the Gumbel-softmax distribution. As tau -> 0, the output becomes more discrete,
+        and as tau -> inf, the output becomes more uniform.
+
+    Returns
+    -------
+    predt: torch.tensor
+        Predicted values.
+    """
+    torch.manual_seed(123)
+    predt = gumbel_softmax(nan_to_num(predt), tau=tau, dim=1) + torch.tensor(0, dtype=predt.dtype)
+
 
     return predt