Skip to content

Commit

Permalink
Merge pull request #1392 from alanlujan91/fix_cons_bequest
Browse files Browse the repository at this point in the history
Fix cons bequest
  • Loading branch information
mnwhite authored Mar 11, 2024
2 parents df9b60c + 2914149 commit de61ed1
Show file tree
Hide file tree
Showing 5 changed files with 272 additions and 329 deletions.
190 changes: 131 additions & 59 deletions HARK/ConsumptionSaving/ConsBequestModel.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
"""
Classes to solve consumption-saving models with a bequest motive and
"""Classes to solve consumption-saving models with a bequest motive and
idiosyncratic shocks to income and wealth. All models here assume
separable CRRA utility of consumption and Stone-Geary utility of
savings with geometric discounting of the continuation value and
Expand All @@ -11,7 +10,6 @@
"""

import numpy as np

from HARK.ConsumptionSaving.ConsIndShockModel import (
ConsIndShockSolver,
IndShockConsumerType,
Expand All @@ -37,51 +35,44 @@


class BequestWarmGlowConsumerType(IndShockConsumerType):
time_vary_ = IndShockConsumerType.time_vary_ + [
time_inv_ = IndShockConsumerType.time_inv_ + [
"BeqCRRA",
"BeqFac",
"BeqShift",
]

time_vary_ = IndShockConsumerType.time_vary_ + [
"BeqFac",
]

def __init__(self, **kwds):
params = init_wealth_in_utility.copy()
params.update(kwds)

super().__init__(**params)

self.solve_one_period = make_one_period_oo_solver(BequestWarmGlowConsumerSolver)
self.solve_one_period = make_one_period_oo_solver(
BequestWarmGlowConsumerSolver,
)

def update(self):
super().update()
self.update_parameters()

def update_parameters(self):
if isinstance(self.BeqCRRA, (int, float)):
self.BeqCRRA = [self.BeqCRRA] * self.T_cycle
elif len(self.BeqCRRA) == 1:
self.BeqCRRA *= self.T_cycle
elif len(self.BeqCRRA) != self.T_cycle:
raise ValueError(
"Bequest CRRA parameter must be a single value or a list of length T_cycle"
)
if not isinstance(self.BeqCRRA, (int, float)):
raise ValueError("Bequest CRRA parameter must be a single value.")

if isinstance(self.BeqFac, (int, float)):
self.BeqFac = [self.BeqFac] * self.T_cycle
elif len(self.BeqFac) == 1:
self.BeqFac *= self.T_cycle
elif len(self.BeqFac) != self.T_cycle:
raise ValueError(
"Bequest relative value parameter must be a single value or a list of length T_cycle"
"Bequest relative value parameter must be a single value or a list of length T_cycle",
)

if isinstance(self.BeqShift, (int, float)):
self.BeqShift = [self.BeqShift] * self.T_cycle
elif len(self.BeqShift) == 1:
self.BeqShift *= self.T_cycle
elif len(self.BeqShift) != self.T_cycle:
raise ValueError(
"Bequest Stone-Geary parameter must be a single value or a list of length T_cycle"
)
if not isinstance(self.BeqShift, (int, float)):
raise ValueError("Bequest Stone-Geary parameter must be a single value.")

def update_solution_terminal(self):
if self.TermBeqFac == 0.0: # No terminal bequest
Expand All @@ -90,7 +81,9 @@ def update_solution_terminal(self):
utility = UtilityFuncCRRA(self.CRRA)

warm_glow = UtilityFuncStoneGeary(
self.TermBeqCRRA, factor=self.TermBeqFac, shifter=self.TermBeqShift
self.TermBeqCRRA,
factor=self.TermBeqFac,
shifter=self.TermBeqShift,
)

aNrmGrid = (
Expand All @@ -117,7 +110,16 @@ def update_solution_terminal(self):
self.solution_terminal.mNrmMin = 0.0


class BequestWarmGlowPortfolioType(PortfolioConsumerType, BequestWarmGlowConsumerType):
class BequestWarmGlowPortfolioType(PortfolioConsumerType):
time_inv_ = IndShockConsumerType.time_inv_ + [
"BeqCRRA",
"BeqShift",
]

time_vary_ = IndShockConsumerType.time_vary_ + [
"BeqFac",
]

def __init__(self, **kwds):
params = init_portfolio_bequest.copy()
params.update(kwds)
Expand All @@ -127,46 +129,94 @@ def __init__(self, **kwds):
super().__init__(**params)

self.solve_one_period = make_one_period_oo_solver(
BequestWarmGlowPortfolioSolver
BequestWarmGlowPortfolioSolver,
)

def update(self):
PortfolioConsumerType.update(self)
super().update()
self.update_parameters()

def update_parameters(self):
if not isinstance(self.BeqCRRA, (int, float)):
raise ValueError("Bequest CRRA parameter must be a single value.")

if isinstance(self.BeqFac, (int, float)):
self.BeqFac = [self.BeqFac] * self.T_cycle
elif len(self.BeqFac) == 1:
self.BeqFac *= self.T_cycle
elif len(self.BeqFac) != self.T_cycle:
raise ValueError(
"Bequest relative value parameter must be a single value or a list of length T_cycle",
)

if not isinstance(self.BeqShift, (int, float)):
raise ValueError("Bequest Stone-Geary parameter must be a single value.")

def update_solution_terminal(self):
BequestWarmGlowConsumerType.update_solution_terminal(self)

# Consume all market resources: c_T = m_T
cFuncAdj_terminal = self.solution_terminal.cFunc
cFuncFxd_terminal = lambda m, s: self.solution_terminal.cFunc(m)

# Risky share is irrelevant-- no end-of-period assets; set to zero
ShareFuncAdj_terminal = ConstantFunction(0.0)
ShareFuncFxd_terminal = IdentityFunction(i_dim=1, n_dims=2)

# Value function is simply utility from consuming market resources
vFuncAdj_terminal = self.solution_terminal.vFunc
vFuncFxd_terminal = lambda m, s: self.solution_terminal.vFunc(m)

# Marginal value of market resources is marg utility at the consumption function
vPfuncAdj_terminal = self.solution_terminal.vPfunc
dvdmFuncFxd_terminal = lambda m, s: self.solution_terminal.vPfunc(m)
# No future, no marg value of Share
dvdsFuncFxd_terminal = ConstantFunction(0.0)

# Construct the terminal period solution
self.solution_terminal = PortfolioSolution(
cFuncAdj=cFuncAdj_terminal,
ShareFuncAdj=ShareFuncAdj_terminal,
vFuncAdj=vFuncAdj_terminal,
vPfuncAdj=vPfuncAdj_terminal,
cFuncFxd=cFuncFxd_terminal,
ShareFuncFxd=ShareFuncFxd_terminal,
vFuncFxd=vFuncFxd_terminal,
dvdmFuncFxd=dvdmFuncFxd_terminal,
dvdsFuncFxd=dvdsFuncFxd_terminal,
)
if self.TermBeqFac == 0.0: # No terminal bequest
super().update_solution_terminal()
else:
utility = UtilityFuncCRRA(self.CRRA)

warm_glow = UtilityFuncStoneGeary(
self.TermBeqCRRA,
factor=self.TermBeqFac,
shifter=self.TermBeqShift,
)

aNrmGrid = (
np.append(0.0, self.aXtraGrid)
if self.TermBeqShift != 0.0
else self.aXtraGrid
)
cNrmGrid = utility.derinv(warm_glow.der(aNrmGrid))
vGrid = utility(cNrmGrid) + warm_glow(aNrmGrid)
cNrmGridW0 = np.append(0.0, cNrmGrid)
mNrmGridW0 = np.append(0.0, aNrmGrid + cNrmGrid)
vNvrsGridW0 = np.append(0.0, utility.inv(vGrid))

cFunc_term = LinearInterp(mNrmGridW0, cNrmGridW0)
vNvrsFunc_term = LinearInterp(mNrmGridW0, vNvrsGridW0)
vFunc_term = ValueFuncCRRA(vNvrsFunc_term, self.CRRA)
vPfunc_term = MargValueFuncCRRA(cFunc_term, self.CRRA)
vPPfunc_term = MargMargValueFuncCRRA(cFunc_term, self.CRRA)

self.solution_terminal.cFunc = cFunc_term
self.solution_terminal.vFunc = vFunc_term
self.solution_terminal.vPfunc = vPfunc_term
self.solution_terminal.vPPfunc = vPPfunc_term
self.solution_terminal.mNrmMin = 0.0

# Consume all market resources: c_T = m_T
cFuncAdj_terminal = self.solution_terminal.cFunc
cFuncFxd_terminal = lambda m, s: self.solution_terminal.cFunc(m)

# Risky share is irrelevant-- no end-of-period assets; set to zero
ShareFuncAdj_terminal = ConstantFunction(0.0)
ShareFuncFxd_terminal = IdentityFunction(i_dim=1, n_dims=2)

# Value function is simply utility from consuming market resources
vFuncAdj_terminal = self.solution_terminal.vFunc
vFuncFxd_terminal = lambda m, s: self.solution_terminal.vFunc(m)

# Marginal value of market resources is marg utility at the consumption function
vPfuncAdj_terminal = self.solution_terminal.vPfunc
dvdmFuncFxd_terminal = lambda m, s: self.solution_terminal.vPfunc(m)
# No future, no marg value of Share
dvdsFuncFxd_terminal = ConstantFunction(0.0)

# Construct the terminal period solution
self.solution_terminal = PortfolioSolution(
cFuncAdj=cFuncAdj_terminal,
ShareFuncAdj=ShareFuncAdj_terminal,
vFuncAdj=vFuncAdj_terminal,
vPfuncAdj=vPfuncAdj_terminal,
cFuncFxd=cFuncFxd_terminal,
ShareFuncFxd=ShareFuncFxd_terminal,
vFuncFxd=vFuncFxd_terminal,
dvdmFuncFxd=dvdmFuncFxd_terminal,
dvdsFuncFxd=dvdsFuncFxd_terminal,
)


class BequestWarmGlowConsumerSolver(ConsIndShockSolver):
Expand Down Expand Up @@ -212,6 +262,28 @@ def def_utility_funcs(self):

self.warm_glow = UtilityFuncStoneGeary(self.BeqCRRA, BeqFacEff, self.BeqShift)

def def_BoroCnst(self, BoroCnstArt):
self.BoroCnstNat = (
(self.solution_next.mNrmMin - self.TranShkMinNext)
* (self.PermGroFac * self.PermShkMinNext)
/ self.Rfree
)

self.BoroCnstNat = np.max([self.BoroCnstNat, -self.BeqShift])

if BoroCnstArt is None:
self.mNrmMinNow = self.BoroCnstNat
else:
self.mNrmMinNow = np.max([self.BoroCnstNat, BoroCnstArt])
if self.BoroCnstNat < self.mNrmMinNow:
self.MPCmaxEff = 1.0
else:
self.MPCmaxEff = self.MPCmaxNow

self.cFuncNowCnst = LinearInterp(
np.array([self.mNrmMinNow, self.mNrmMinNow + 1]), np.array([0.0, 1.0])
)

def calc_EndOfPrdvP(self):
EndofPrdvP = super().calc_EndOfPrdvP()

Expand Down
27 changes: 13 additions & 14 deletions HARK/ConsumptionSaving/ConsIndShockModel.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,6 @@
from copy import copy, deepcopy

import numpy as np
from scipy import sparse as sp
from scipy.optimize import newton

from HARK import (
AgentType,
NullFunc,
_log,
make_one_period_oo_solver,
set_verbosity_level,
)
from HARK.Calibration.Income.IncomeTools import (
Cagetti_income,
parse_income_spec,
Expand All @@ -44,7 +34,6 @@
combine_indep_dstns,
expected,
)
from HARK.interpolation import CubicHermiteInterp as CubicInterp
from HARK.interpolation import (
CubicInterp,
LinearInterp,
Expand Down Expand Up @@ -72,6 +61,16 @@
jump_to_grid_2D,
make_grid_exp_mult,
)
from scipy import sparse as sp
from scipy.optimize import newton

from HARK import (
AgentType,
NullFunc,
_log,
make_one_period_oo_solver,
set_verbosity_level,
)

__all__ = [
"ConsumerSolution",
Expand Down Expand Up @@ -2095,8 +2094,8 @@ def calc_stable_points(self):


# Make a dictionary to specify an idiosyncratic income shocks consumer
init_idiosyncratic_shocks = dict(
init_perfect_foresight,
init_idiosyncratic_shocks = {
**init_perfect_foresight,
**{ # assets above grid parameters
"aXtraMin": 0.001, # Minimum end-of-period "assets above minimum" value
"aXtraMax": 20, # Maximum end-of-period "assets above minimum" value
Expand Down Expand Up @@ -2129,7 +2128,7 @@ def calc_stable_points(self):
# Whether Newborns have transitory shock. The default is False.
"NewbornTransShk": False,
},
)
}


class IndShockConsumerType(PerfForesightConsumerType):
Expand Down
117 changes: 62 additions & 55 deletions examples/ConsBequestModel/example_AccidentalBequest.ipynb

Large diffs are not rendered by default.

Loading

0 comments on commit de61ed1

Please sign in to comment.