Skip to content

Commit

Permalink
update tweedie
Browse files Browse the repository at this point in the history
  • Loading branch information
davidruegamer committed Apr 19, 2024
1 parent aa4af50 commit cbc19d3
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 19 deletions.
23 changes: 17 additions & 6 deletions R/families.R
Original file line number Diff line number Diff line change
Expand Up @@ -544,14 +544,15 @@ tfd_mvr <- function(loc, scale,
}

# Implementation of a distribution-like layer for (Quasi-)Tweedie
tfd_tweedie <- function(loc, p = 1.5, quasi = FALSE,
tfd_tweedie <- function(loc, phi, p = 1.5, quasi = FALSE,
validate_args = FALSE,
allow_nan_stats = TRUE,
name = "QuasiTweedie")
name = "Tweedie")
{

args <- list(
loc = loc,
scale = phi,
var_power = p,
quasi = quasi,
validate_args = validate_args,
Expand All @@ -567,12 +568,22 @@ tfd_tweedie <- function(loc, p = 1.5, quasi = FALSE,
}

# tfd_distfun for (Quasi-)Tweedie to allow for flexible p
tweedie <- function(p, quasi = FALSE)
tweedie <- function(p, quasi = FALSE, output_dim = 1L)
{

tfd_dist <- function(l) tfd_tweedie(loc = l, p = p, quasi = quasi)
ret_fun <- function(x) tfd_dist(tf$add(1e-8, tfe(x)))
attr(ret_fun, "nrparams_dist") <- 1L
tfd_dist <- function(l, s) tfd_tweedie(loc = l, phi = s, p = p, quasi = quasi)
trafo_list <- list(function(x) tf$add(1e-8, tfe(x)),
function(x) tf$add(1e-8, tfe(x)))
dist_dim <- 2L
ret_fun <- function(x)
do.call(tfd_dist,
lapply(1:(x$shape[[2]]/output_dim),
function(i)
trafo_list[[i]](
tf_stride_cols(x,(i-1L)*output_dim+1L,
(i-1L)*output_dim+output_dim)))
)
attr(ret_fun, "nrparams_dist") <- 2L

return(ret_fun)

Expand Down
52 changes: 39 additions & 13 deletions inst/python/distributions/tweedie.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,16 @@
from tensorflow.experimental import numpy as tnp

class Tweedie(distribution.AutoCompositeTensorDistribution):
"""Mean-Variance Regression (https://arxiv.org/pdf/1804.01631.pdf)
"""Tweedie
"""

def __init__(self,
loc,
scale,
var_power=1.,
quasi=False,
a=1.01,
b=1.99,
validate_args=False,
allow_nan_stats=True,
name='Tweedie'):
Expand All @@ -26,6 +29,8 @@ def __init__(self,
broadcasting.
Args:
loc: Floating point tensor; the means of the distribution(s).
scale: Floating point tensor; the scale of the distribution for Quasi,
phi for non-Quasi
var_power: The variance power, also referred to as "p". The default is 1.
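quasi: Python `bool`, default `False`. When `True` quasi log-likelihood is used.
a: Lower bound of the interval the variance power is mapped into when
`quasi` is `False`. The default is 1.01.
b: Upper bound of that interval. The default is 1.99.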
validate_args: Python `bool`, default `False`. When `True` distribution
Expand All @@ -40,11 +45,15 @@ def __init__(self,
"""
parameters = dict(locals())
with tf.name_scope(name) as name:
dtype = dtype_util.common_dtype([loc], dtype_hint=tf.float32)
dtype = dtype_util.common_dtype([loc, scale], dtype_hint=tf.float32)
self._loc = tensor_util.convert_nonref_to_tensor(
loc, dtype=dtype, name='loc')
self._scale = tensor_util.convert_nonref_to_tensor(
scale, dtype=dtype, name='scale')
self._p = var_power
self.quasi = quasi
self.a = a
self.b = b
super(Tweedie, self).__init__(
dtype=dtype,
reparameterization_type=reparameterization.FULLY_REPARAMETERIZED,
Expand All @@ -67,10 +76,15 @@ def _parameter_properties(cls, dtype, num_classes=None):
def loc(self):
"""Parameter for the mean."""
return self._loc

@property
def scale(self):
"""Parameter for the scale (quasi case) or phi (non-quasi case)."""
return self._scale

@property
def p(self):
"""Parameter for standard deviation."""
"""Parameter for power."""
return self._p

def _event_shape_tensor(self):
Expand All @@ -81,7 +95,8 @@ def _event_shape(self):

def _sample_n(self, n, seed=None):
loc = tf.convert_to_tensor(self.loc)
shape = ps.concat([[n], self._batch_shape_tensor(loc=loc, scale=1)],
scale = tf.convert_to_tensor(self.scale)
shape = ps.concat([[n], self._batch_shape_tensor(loc=loc, scale=scale)],
axis=0)
sampled = samplers.normal(
shape=shape, mean=0., stddev=1., dtype=self.dtype, seed=seed)
Expand All @@ -91,25 +106,36 @@ def _sample_n(self, n, seed=None):
def _log_prob(self, x):
"""Used for the loss of the model -- not an actual log prob"""
if self.quasi: # from https://www.statsmodels.org/stable/_modules/statsmodels/genmod/families/family.html#Tweedie
llf = log(2 * tnp.pi) + self.p * log(x)
llf = log(2 * tnp.pi * self.scale) + self.p * log(x)
llf /= -2
u = (x ** (2 - self.p) - (2 - self.p) * x * self.loc ** (1 - self.p) + (1 - self.p) * self.loc ** (2 - self.p))
u *= 1 / ((1 - self.p) * (2 - self.p))
u *= 1 / (self.scale * (1 - self.p) * (2 - self.p))
return llf - u

else: # from https://github.com/cran/statmod/blob/master/R/tweedie.R negative deviance residuals
x1 = x + 0.1 * tf.cast(tf.equal(x, 0), tf.float32)
theta = (tf.pow(x1, 1 - self.p) - tf.pow(self.loc, 1 - self.p)) / (1 - self.p)
kappa = (tf.pow(x, 2 - self.p) - tf.pow(self.loc, 2 - self.p)) / (2 - self.p)

return - 2 * (x * theta - kappa)
else:
# from https://github.com/cran/statmod/blob/master/R/tweedie.R negative deviance residuals
# x1 = x + 0.1 * tf.cast(tf.equal(x, 0), tf.float32)
# theta = (tf.pow(x1, 1 - self.p) - tf.pow(self.loc, 1 - self.p)) / (1 - self.p)
# kappa = (tf.pow(x, 2 - self.p) - tf.pow(self.loc, 2 - self.p)) / (2 - self.p)
# return - 2 * (x * theta - kappa)
# from https://github.com/cran/mgcv/blob/aff4560d187dfd7d98c7bd367f5a0076faf129b7/R/gamlss.r#L2474
ethi = tf.exp(-self.p) # assuming p > 0
p = (self.b + self.a * ethi)/(1+ethi)
x1 = x + tf.cast(x == 0, tf.float32)
theta = (tf.pow(x1, 1 - p) - tf.pow(self.loc, 1 - p)) / (1 - p)
kappa = (tf.pow(x, 2 - p) - tf.pow(self.loc, 2 - p)) / (2 - p)
return tf.sign(x - self.loc) * tf.sqrt(tf.nn.relu(2 * (x * theta - kappa) * 1 / self.scale))



def _mean(self):
# The mean is `loc`; multiplying by ones shaped like `scale` broadcasts
# `loc` against `scale`.
return self.loc * tf.ones_like(self.scale)

def _stddev(self):
return self.scale * tf.ones_like(self.loc)
if self.quasi:
return self._scale
else:
return tf.sqrt(self._scale * tf.pow(self.loc, self.p))

def _default_event_space_bijector(self):
# Returns an identity bijector, forwarding this distribution's
# `validate_args` setting.
return identity_bijector.Identity(validate_args=self.validate_args)
Expand Down

0 comments on commit cbc19d3

Please sign in to comment.