
Commit

Added gumbel-softmax
Alexander März committed Aug 23, 2023
1 parent 88be6aa commit d7d70ab
Showing 1 changed file with 87 additions and 27 deletions.
xgboostlss/utils.py (114 changes: 87 additions & 27 deletions)
@@ -1,9 +1,9 @@
 import torch
+from torch.nn.functional import softplus, gumbel_softmax, softmax


-def identity_fn(predt: torch.tensor) -> torch.tensor:
+def nan_to_num(predt: torch.tensor) -> torch.tensor:
     """
-    Identity mapping of predt.
+    Replace nan, inf and -inf with the mean of predt.
     Arguments
     ---------
@@ -15,12 +15,18 @@ def identity_fn(predt: torch.tensor) -> torch.tensor:
     predt: torch.tensor
         Predicted values.
     """
+    predt = torch.nan_to_num(predt,
+                             nan=float(torch.nanmean(predt)),
+                             posinf=float(torch.nanmean(predt)),
+                             neginf=float(torch.nanmean(predt))
+                             )
+
     return predt


-def exp_fn(predt: torch.tensor) -> torch.tensor:
+def identity_fn(predt: torch.tensor) -> torch.tensor:
     """
-    Exponential function used to ensure predt is strictly positive.
+    Identity mapping of predt.
     Arguments
     ---------
@@ -32,15 +38,14 @@ def exp_fn(predt: torch.tensor) -> torch.tensor:
     predt: torch.tensor
         Predicted values.
     """
-    predt = torch.exp(predt)
-    predt = torch.nan_to_num(predt, nan=float(torch.nanmean(predt))) + torch.tensor(1e-06, dtype=predt.dtype)
+    predt = nan_to_num(predt) + torch.tensor(0, dtype=predt.dtype)

     return predt


-def exp_fn_df(predt: torch.tensor) -> torch.tensor:
+def exp_fn(predt: torch.tensor) -> torch.tensor:
     """
-    Exponential function used for Student-T distribution.
+    Exponential function used to ensure predt is strictly positive.
     Arguments
     ---------
@@ -52,15 +57,14 @@ def exp_fn_df(predt: torch.tensor) -> torch.tensor:
     predt: torch.tensor
         Predicted values.
     """
-    predt = torch.exp(predt)
-    predt = torch.nan_to_num(predt, nan=float(torch.nanmean(predt))) + torch.tensor(1e-06, dtype=predt.dtype)
+    predt = torch.exp(nan_to_num(predt)) + torch.tensor(1e-06, dtype=predt.dtype)

-    return predt + torch.tensor(2.0, dtype=predt.dtype)
+    return predt


-def log_fn(predt: torch.tensor) -> torch.tensor:
+def exp_fn_df(predt: torch.tensor) -> torch.tensor:
     """
-    Inverse of exp_fn function.
+    Exponential function used for Student-T distribution.
     Arguments
     ---------
@@ -72,10 +76,9 @@ def log_fn(predt: torch.tensor) -> torch.tensor:
     predt: torch.tensor
         Predicted values.
     """
-    predt = torch.log(predt)
-    predt = torch.nan_to_num(predt, nan=float(torch.nanmean(predt))) + float(1e-06)
+    predt = torch.exp(nan_to_num(predt)) + torch.tensor(1e-06, dtype=predt.dtype)

-    return predt
+    return predt + torch.tensor(2.0, dtype=predt.dtype)


 def softplus_fn(predt: torch.tensor) -> torch.tensor:
@@ -92,9 +95,7 @@ def softplus_fn(predt: torch.tensor) -> torch.tensor:
     predt: torch.tensor
         Predicted values.
     """
-    predt = torch.log1p(torch.exp(-torch.abs(predt))) + torch.maximum(predt, torch.tensor(0.))
-    predt[predt == 0] = torch.tensor(1e-06, dtype=predt.dtype)
-    predt = torch.nan_to_num(predt, nan=float(torch.nanmean(predt))) + torch.tensor(1e-06, dtype=predt.dtype)
+    predt = softplus(nan_to_num(predt)) + torch.tensor(1e-06, dtype=predt.dtype)

     return predt

@@ -113,9 +114,7 @@ def softplus_fn_df(predt: torch.tensor) -> torch.tensor:
     predt: torch.tensor
         Predicted values.
     """
-    predt = torch.log1p(torch.exp(-torch.abs(predt))) + torch.maximum(predt, torch.tensor(0.))
-    predt[predt == 0] = torch.tensor(1e-06, dtype=predt.dtype)
-    predt = torch.nan_to_num(predt, nan=float(torch.nanmean(predt))) + torch.tensor(1e-06, dtype=predt.dtype)
+    predt = softplus(nan_to_num(predt)) + torch.tensor(1e-06, dtype=predt.dtype)

     return predt + torch.tensor(2.0, dtype=predt.dtype)

@@ -134,8 +133,7 @@ def sigmoid_fn(predt: torch.tensor) -> torch.tensor:
     predt: torch.tensor
         Predicted values.
     """
-    predt = torch.sigmoid(predt)
-    predt = torch.nan_to_num(predt, nan=float(torch.nanmean(predt))) + torch.tensor(1e-06, dtype=predt.dtype)
+    predt = torch.sigmoid(nan_to_num(predt)) + torch.tensor(1e-06, dtype=predt.dtype)
     predt = torch.clamp(predt, 1e-03, 1-1e-03)

     return predt
@@ -155,7 +153,69 @@ def relu_fn(predt: torch.tensor) -> torch.tensor:
     predt: torch.tensor
         Predicted values.
     """
-    predt = torch.relu(predt)
-    predt = torch.nan_to_num(predt, nan=float(torch.nanmean(predt))) + torch.tensor(1e-6, dtype=predt.dtype)
+    predt = torch.relu(nan_to_num(predt)) + torch.tensor(1e-06, dtype=predt.dtype)

     return predt
+
+
+def softmax_fn(predt: torch.tensor) -> torch.tensor:
+    """
+    Softmax function used to ensure predt sums to one.
+    Arguments
+    ---------
+    predt: torch.tensor
+        Predicted values.
+    Returns
+    -------
+    predt: torch.tensor
+        Predicted values.
+    """
+    predt = softmax(nan_to_num(predt), dim=1) + torch.tensor(0, dtype=predt.dtype)
+
+    return predt
+
+
+def gumbel_softmax_fn(predt: torch.tensor,
+                      tau: float = 1.0
+                      ) -> torch.tensor:
+    """
+    Gumbel-softmax function used to ensure predt sums to one.
+
+    The Gumbel-softmax distribution is a continuous distribution over the simplex that can be thought of as a
+    "soft" version of a categorical distribution: it allows drawing samples from a categorical distribution in a
+    differentiable way, which is useful in gradient-based optimization. Sampling uses the Gumbel-max trick: add
+    Gumbel noise to the logits and apply the softmax. Formally, given a vector z, the Gumbel-softmax function
+    s(z, \tau)_i for component i at temperature \tau is defined as
+
+        s(z, \tau)_i = \frac{e^{(z_i + g_i) / \tau}}{\sum_{j=1}^M e^{(z_j + g_j) / \tau}},
+
+    where g_i is a sample from the Gumbel(0, 1) distribution. The temperature \tau controls the sharpness of the
+    output distribution: as \tau -> 0, the output becomes more discrete; as \tau -> inf, it becomes more uniform.
+    For details, see Jang, E., Gu, S. and Poole, B., "Categorical Reparameterization with Gumbel-Softmax", ICLR, 2017.
+    Arguments
+    ---------
+    predt: torch.tensor
+        Predicted values.
+    tau: float
+        Non-negative scalar temperature of the Gumbel-softmax distribution. As tau -> 0, the output becomes more
+        discrete; as tau -> inf, it becomes more uniform.
+    Returns
+    -------
+    predt: torch.tensor
+        Predicted values.
+    """
+    torch.manual_seed(123)
+    predt = gumbel_softmax(nan_to_num(predt), tau=tau, dim=1) + torch.tensor(0, dtype=predt.dtype)
+
+    return predt
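
For readers checking the docstring's formula against the call to torch.nn.functional.gumbel_softmax, here is a minimal hand-rolled sketch of the same idea (illustrative only, not the commit's code; gumbel_softmax_by_hand is a hypothetical name):

import torch

def gumbel_softmax_by_hand(z: torch.Tensor, tau: float = 1.0) -> torch.Tensor:
    # g_i ~ Gumbel(0, 1) via inverse transform: g = -log(-log(U)), U ~ Uniform(0, 1)
    g = -torch.log(-torch.log(torch.rand_like(z)))
    # softmax of the noise-perturbed, temperature-scaled logits, row-wise
    return torch.softmax((z + g) / tau, dim=1)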
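
And a quick usage sketch of the new response functions, assuming xgboostlss.utils exposes the functions added above and that predt is a 2D tensor of raw scores with one row per observation (the NaN entry is illustrative):

import torch
from xgboostlss.utils import nan_to_num, softmax_fn, gumbel_softmax_fn

predt = torch.tensor([[0.5, float("nan"), -1.2],
                      [2.0, 1.0, 0.3]])

print(nan_to_num(predt))                             # NaN replaced by the nanmean of predt
print(softmax_fn(predt).sum(dim=1))                  # each row sums to (approximately) one
print(gumbel_softmax_fn(predt, tau=1.0).sum(dim=1))  # rows land on the simplex as well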
