Skip to content

Commit

Permalink
Refactor latex printing
Browse files Browse the repository at this point in the history
Fix failing tests

Pin sympy<1.13
  • Loading branch information
jessegrabowski committed Oct 16, 2024
1 parent f6c073e commit 7298e04
Show file tree
Hide file tree
Showing 9 changed files with 48 additions and 36 deletions.
42 changes: 24 additions & 18 deletions cge_modeling/base/primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,18 @@ def _pretty_print_dim_flags(s, dims, dim_vals):
return s


def _process_subscript_chunks(chunks):
return ", ".join(chunks)
def _tokenize_latex_subscript(subscript):
"""
Given input of the form "_{...}", where ... is a comma separated list of subscript values,
return a list of the individual subscript values.
"""
if subscript.startswith("_{"):
subscript = subscript[2:]
if subscript.endswith("}"):
subscript = subscript[:-1]

tokens = [x.strip() for x in subscript.split(",")]
return tokens if tokens != [""] else []


@dataclass(slots=True, order=True, frozen=False, repr=False)
Expand Down Expand Up @@ -129,28 +139,24 @@ def _initialize_latex_name(self):

def _initialize_full_latex_name(self):
idx_subscript = self._make_latex_subscript()
idx_subscript_tokens = _tokenize_latex_subscript(idx_subscript)
base_latex_name = self.latex_name
int_extend = int(self.extend_subscript)
has_subscript = "_{" in base_latex_name and "}" in base_latex_name

if self.extend_subscript and has_subscript:
tokens = base_latex_name.split("_")
base_chunks, subscript_chunks = tokens[:-int_extend], tokens[-int_extend:]
subscript = _process_subscript_chunks(subscript_chunks)
if self.extend_subscript:
tokens = [re.sub(r"[\{\}]", "", x.strip()) for x in base_latex_name.split("_")]
base_tokens, subscript_tokens = tokens[:-int_extend], tokens[-int_extend:]

if len(base_chunks) == 0:
base_latex_name, subscript = subscript, base_chunks
else:
base_latex_name = "_".join(base_chunks)
if len(base_tokens) == 0:
base_tokens, subscript_tokens = subscript_tokens, base_tokens
base_latex_name = "_".join(base_tokens)

if len(subscript) > 0:
subscript = subscript.replace("{", "").replace("}", "")
idx_subscript = "_{" + subscript
if len(idx_subscript) > 0:
subscript += ", " + idx_subscript[2:]
subscript_tokens = subscript_tokens + idx_subscript_tokens

if not idx_subscript.endswith("}"):
idx_subscript += "}"
if len(subscript_tokens) == 0:
idx_subscript = ""
else:
idx_subscript = "_{" + ", ".join(subscript_tokens) + "}"

if "_" in base_latex_name:
base_latex_name = r"\text{" + base_latex_name + "}"
Expand Down
2 changes: 1 addition & 1 deletion conda_envs/environment.yaml → conda_envs/environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ dependencies:
- pip
- numpy
- scipy
- sympy
- sympy<1.13
- pandas
- numba
- pytensor
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ dependencies:
- pip
- numpy
- scipy
- sympy
- sympy<1.13
- pandas
- numba
- IPython
Expand Down
10 changes: 8 additions & 2 deletions tests/test_cge_model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import sys

from collections.abc import Callable
from importlib.util import find_spec
from unittest.mock import patch

import numpy as np
Expand All @@ -23,6 +24,8 @@
model_2_data,
)

JAX_INSTALLED = find_spec("JAX") is not None


@pytest.fixture()
def model():
Expand Down Expand Up @@ -351,7 +354,10 @@ def test_long_unpack():
"backend", ["numba", "sympytensor", "pytensor"], ids=["numba", "sympytensor", "pytensor"]
)
def test_model_gradients(model_function, jac_function, backend):
mode = "FAST_COMPILE" if backend in ["pytensor", "sympytensor"] else None
mode = "FAST_RUN" if backend in ["pytensor", "sympytensor"] else None
if mode == "FAST_RUN" and JAX_INSTALLED:
mode = "JAX"

mod = model_function(
backend=backend,
mode=mode,
Expand Down Expand Up @@ -391,7 +397,7 @@ def test_model_gradients(model_function, jac_function, backend):
def test_pytensor_from_sympy(model_function, calibrate_model, f_expected_jac, data, sparse):
mod = model_function(
backend="sympytensor",
mode="FAST_COMPILE",
mode="FAST_RUN" if not JAX_INSTALLED else "JAX",
use_sparse_matrices=sparse,
)

Expand Down
2 changes: 1 addition & 1 deletion tests/test_compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def f_analytic(v3):
v2 = 2 - v1
return np.array([v1, v2])

mode = "FAST_COMPILE"
mode = "FAST_RUN"

theta_final, n_steps, result = symbolic_euler_approximation(
equations, variables=variables, parameters=[parameters]
Expand Down
4 changes: 2 additions & 2 deletions tests/test_optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def f_model(Y, C, L_d, K_d, P, r, resid, K_s, L_s, A, alpha, w):
[1.16392629e04, 1.16392629e04, 7000.0, 4000.0, 0.897630824, 0.861940299, 0.0]
)
f_model_compiled = pytensor.function(
inputs=variables + params, outputs=equations, mode="FAST_COMPILE"
inputs=variables + params, outputs=equations, mode="FAST_RUN"
)

# Check optimizer converges to the scipy result
Expand Down Expand Up @@ -151,7 +151,7 @@ def test_small_model_from_compile():


def test_sector_model_from_compile():
mod = load_model_2(mode="FAST_COMPILE", backend="pytensor")
mod = load_model_2(mode="FAST_RUN", backend="pytensor")
calib_dict = calibrate_model_2(**model_2_data)

x0 = {var.name: calib_dict[var.name] for var in mod.variables}
Expand Down
2 changes: 1 addition & 1 deletion tests/test_primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def test_to_dict(cls):
("x_F", ["i", "j"], True, "x_{F, i=\\text{A}, j}"),
("x_{Fish}", ["i", "i"], True, "x_{Fish, i=\\text{A}, i=\\text{A}}"),
("x_K_d", ["i", "j"], 2, "x_{K, d, i=\\text{A}, j}"),
("var_with_underscore", ["i"], False, "var_with_underscore_{i=\\text{A}}"),
("var_with_underscore", ["i"], False, "\\text{var_with_underscore}_{i=\\text{A}}"),
("P_CE", None, True, "P_{CE}"),
],
)
Expand Down
18 changes: 9 additions & 9 deletions tests/test_production_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ def compile_Y():
return locals()["Y"], [K, L, A, alpha]

Y, inputs = compile_Y()
f_Y = pytensor.function(inputs, Y, on_unused_input="ignore", mode="FAST_COMPILE")
f_Y = pytensor.function(inputs, Y, on_unused_input="ignore", mode="FAST_RUN")
expected = expected_Y(L=1, K=1, A=1, alpha=0.5)
np.testing.assert_allclose(f_Y(1, 1, 1, 0.5), expected)

Expand All @@ -160,7 +160,7 @@ def compile_L():
return locals()["L"], [Y, w, P, alpha]

L, inputs = compile_L()
f_L = pytensor.function(inputs, L, on_unused_input="ignore", mode="FAST_COMPILE")
f_L = pytensor.function(inputs, L, on_unused_input="ignore", mode="FAST_RUN")
expected = expected_L(Y=1, w=1, P=1, alpha=0.5)
np.testing.assert_allclose(f_L(1, 1, 1, 0.5), expected)

Expand All @@ -174,7 +174,7 @@ def compile_K():
return locals()["K"], [Y, r, P, alpha]

K, inputs = compile_K()
f_K = pytensor.function(inputs, K, on_unused_input="ignore", mode="FAST_COMPILE")
f_K = pytensor.function(inputs, K, on_unused_input="ignore", mode="FAST_RUN")
expected = expected_K(Y=1, r=1, P=1, alpha=0.5)
np.testing.assert_allclose(f_K(1, 1, 1, 0.5), expected)

Expand Down Expand Up @@ -394,7 +394,7 @@ def compile_Y():
return locals()["Y"], [X, P_X, P_Y, alpha, A, epsilon]

Y, inputs = compile_Y()
f_Y = pytensor.function(inputs, Y, on_unused_input="ignore", mode="FAST_COMPILE")
f_Y = pytensor.function(inputs, Y, on_unused_input="ignore", mode="FAST_RUN")
expected = dx_Y(X=np.arange(7), A=1, alpha=np.full(7, 1 / 7), epsilon=3)
np.testing.assert_allclose(
f_Y(np.arange(7), np.ones(7), 1, np.full(7, 1 / 7), 1, 3), expected
Expand All @@ -412,7 +412,7 @@ def compile_X():
return locals()["X"], [Y, P_Y, P_X, A, alpha, epsilon]

X, inputs = compile_X()
f_X = pytensor.function(inputs, X, on_unused_input="ignore", mode="FAST_COMPILE")
f_X = pytensor.function(inputs, X, on_unused_input="ignore", mode="FAST_RUN")
expected = dx_X(Y=1, A=1, alpha=np.full(7, 1 / 7), P_Y=1, P_X=np.ones(7), epsilon=3)
np.testing.assert_allclose(f_X(1, 1, np.ones(7), 1, np.full(7, 1 / 7), 3), expected)

Expand Down Expand Up @@ -624,7 +624,7 @@ def compile_Y():
return locals()["Y"], [VA, VC, P_VA, P_VC, P_Y]

Y, inputs = compile_Y()
f_Y = pytensor.function(inputs, Y, on_unused_input="ignore", mode="FAST_COMPILE")
f_Y = pytensor.function(inputs, Y, on_unused_input="ignore", mode="FAST_RUN")
expected = leontief_Y(
VA=np.arange(7), VC=np.arange(7), P_VA=np.ones(7), P_VC=np.ones(7), P_Y=1
)
Expand All @@ -640,7 +640,7 @@ def compile_VA():
return locals()["VA"], [Y, phi_VA]

VA, inputs = compile_VA()
f_VA = pytensor.function(inputs, VA, on_unused_input="ignore", mode="FAST_COMPILE")
f_VA = pytensor.function(inputs, VA, on_unused_input="ignore", mode="FAST_RUN")
expected = leontief_VA(Y=1, phi_VA=np.full(7, 1 / 7))
np.testing.assert_allclose(f_VA(1, np.full(7, 1 / 7)), expected)

Expand All @@ -652,7 +652,7 @@ def compile_VC():
return locals()["VC"], [Y, phi_VC]

VC, inputs = compile_VC()
f_VC = pytensor.function(inputs, VC, on_unused_input="ignore", mode="FAST_COMPILE")
f_VC = pytensor.function(inputs, VC, on_unused_input="ignore", mode="FAST_RUN")
expected = leontief_VC(Y=1, phi_VC=np.full(7, 1 / 7))
np.testing.assert_allclose(f_VC(1, np.full(7, 1 / 7)), expected)

Expand Down Expand Up @@ -754,6 +754,6 @@ def compile_X():
return locals()["X"], [Y, phi_X]

Y, inputs = compile_Y()
f_Y = pytensor.function(inputs, Y, on_unused_input="ignore", mode="FAST_COMPILE")
f_Y = pytensor.function(inputs, Y, on_unused_input="ignore", mode="FAST_RUN")
expected = leontief_Y(X=X_val, P_X=X_prices, P_Y=P_Y)
np.testing.assert_allclose(f_Y(P_Y, X_val, X_prices), expected)
2 changes: 1 addition & 1 deletion tests/test_utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def test_wrap_pytensor_func_for_scipy():
f = a * x + b * y + c * z
mse = (f**2).sum()

f_mse = pytensor.function([x, y, z, a, b, c], mse, mode="FAST_COMPILE")
f_mse = pytensor.function([x, y, z, a, b, c], mse, mode="FAST_RUN")
test_value_dict = {"a": 1.0, "b": 2.0, "c": 3.0, "x": 1.0, "y": 2.0, "z": 3.0}
assert f_mse(**test_value_dict) == 196.0

Expand Down

0 comments on commit 7298e04

Please sign in to comment.