Skip to content

Commit

Permalink
ENH: set data keys as first positional arguments (#488)
Browse files Browse the repository at this point in the history
  • Loading branch information
redeboer authored Nov 8, 2023
1 parent 8ebaccf commit aa36e66
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 3 deletions.
5 changes: 4 additions & 1 deletion src/tensorwaves/function/sympy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,10 @@ def create_parametrized_function(
[0.0, 0.0, 0.0, 0.0, 0.0]
"""
free_symbols = _get_free_symbols(expression)
sorted_symbols = sorted(free_symbols, key=lambda s: s.name)
parameter_set = set(parameters)
parameter_symbols = sorted(free_symbols & parameter_set, key=lambda s: s.name)
data_symbols = sorted(free_symbols - parameter_set, key=lambda s: s.name)
sorted_symbols = tuple(data_symbols + parameter_symbols) # for partial+gradient
lambdified_function = _lambdify_normal_or_fast(
expression=expression,
symbols=sorted_symbols,
Expand Down
3 changes: 2 additions & 1 deletion tests/function/test_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@ def function(self) -> ParametrizedBackendFunction:
return create_parametrized_function(expression, parameters, backend="numpy")

def test_argument_order(self, function: ParametrizedBackendFunction):
assert function.argument_order == ("c_1", "c_2", "c_3", "c_4", "x")
"""Test whether data arguments come before parameters."""
assert function.argument_order == ("x", "c_1", "c_2", "c_3", "c_4")

@pytest.mark.parametrize(
("test_data", "expected_results"),
Expand Down
2 changes: 1 addition & 1 deletion tests/test_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def test_create_cached_function(backend):

assert isinstance(cached_function, ParametrizedBackendFunction)
assert isinstance(cache_transformer, SympyDataTransformer)
assert cached_function.argument_order == ("a", "c", "f0", "x")
assert cached_function.argument_order == ("f0", "x", "a", "c") # data args first
assert set(cached_function.parameters) == {"a", "c"}
assert set(cache_transformer.functions) == {"f0", "x"}

Expand Down

0 comments on commit aa36e66

Please sign in to comment.