From 01635e6a3268f09721b181c15dd1796e06f6190c Mon Sep 17 00:00:00 2001 From: seitzdom Date: Tue, 2 Jul 2024 16:45:06 +0200 Subject: [PATCH 01/17] [Feature] Main Embedding and ConcretizedCallable logic --- README.md | 4 +- docs/index.md | 0 mkdocs.yml | 71 +++++++++++++++++++++ pyproject.toml | 107 ++++++++++++++++++++++++++++++++ qadence_embeddings/__init__.py | 4 ++ qadence_embeddings/callable.py | 83 +++++++++++++++++++++++++ qadence_embeddings/embedding.py | 38 ++++++++++++ tests/test_callable.py | 0 tests/test_embedding.py | 0 9 files changed, 306 insertions(+), 1 deletion(-) create mode 100644 docs/index.md create mode 100644 mkdocs.yml create mode 100644 pyproject.toml create mode 100644 qadence_embeddings/__init__.py create mode 100644 qadence_embeddings/callable.py create mode 100644 qadence_embeddings/embedding.py create mode 100644 tests/test_callable.py create mode 100644 tests/test_embedding.py diff --git a/README.md b/README.md index 528dffa..c13e049 100644 --- a/README.md +++ b/README.md @@ -1 +1,3 @@ -# qadence-embeddings \ No newline at end of file +# qadence-embeddings + +**qadence-embeddings** is an engine-agnostic parameter embedding library. 
diff --git a/docs/index.md b/docs/index.md new file mode 100644 index 0000000..e69de29 diff --git a/mkdocs.yml b/mkdocs.yml new file mode 100644 index 0000000..4c5a14c --- /dev/null +++ b/mkdocs.yml @@ -0,0 +1,71 @@ +site_name: qadence-embeddings +repo_url: "https://github.com/pasqal-io/qadence-embeddings" +repo_name: "qadence-embeddings" + +nav: + + - Embeddings in a nutshell: index.md + + +theme: + name: material + features: + - content.code.annotate + - navigation.indexes + - navigation.sections + + palette: + - media: "(prefers-color-scheme: light)" + scheme: default + primary: light green + accent: purple + toggle: + icon: material/weather-sunny + name: Switch to dark mode + - media: "(prefers-color-scheme: dark)" + scheme: slate + primary: black + accent: light green + toggle: + icon: material/weather-night + name: Switch to light mode + +markdown_extensions: +- admonition # for notes +- pymdownx.arithmatex: # for mathjax + generic: true +- pymdownx.highlight: + anchor_linenums: true +- pymdownx.inlinehilite +- pymdownx.snippets +- pymdownx.superfences + +plugins: +- search +- section-index +- mkdocstrings: + default_handler: python + handlers: + python: + selection: + filters: + - "!^_" # exlude all members starting with _ + - "^__init__$" # but always include __init__ modules and methods + +- mkdocs-jupyter: + theme: light +- markdown-exec + +extra_css: +- extras/css/mkdocstrings.css +- extras/css/colors.css +- extras/css/home.css + +# For mathjax +extra_javascript: + - extras/javascripts/mathjax.js + - https://polyfill.io/v3/polyfill.min.js?features=es6 + - https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js + +watch: + - qadence_embeddings diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..2515bfc --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,107 @@ +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[project] +name = "qadence-embeddings" +description = "a engine-agnostic parameter 
embedding library." +authors = [ + { name = "Dominik Seitz", email = "dominik.seitz@pasqal.com" }, +] +requires-python = ">=3.8,<3.13" +license = {text = "Apache 2.0"} +version = "1.3.0" +classifiers=[ + "License :: OSI Approved :: Apache Software License", + "Programming Language :: Python", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: Implementation :: CPython", + "Programming Language :: Python :: Implementation :: PyPy", +] +dependencies = ["torch", "numpy"] + +[project.optional-dependencies] +dev = ["flaky","black", "pytest", "pytest-xdist", "pytest-cov", "flake8", "mypy", "pre-commit", "ruff", "nbconvert", "matplotlib", "qutip~=4.7.5"] + +[tool.hatch.envs.tests] +features = [ + "dev", +] + +[tool.hatch.envs.tests.scripts] +test = "pytest -n auto {args}" +test-docs = "hatch -e docs run mkdocs build --clean --strict" +test-cov = "pytest -n auto --cov=qadence-embeddings tests/" + +[tool.hatch.envs.docs] +dependencies = [ + "mkdocs", + "mkdocs-material", + "mkdocstrings", + "mkdocstrings-python", + "mkdocs-section-index", + "mkdocs-jupyter", + "mkdocs-exclude", + "markdown-exec", + "mike", + "matplotlib", +] + +[tool.hatch.envs.docs.scripts] +build = "mkdocs build --clean --strict" +serve = "mkdocs serve --dev-addr localhost:8000" + +[tool.ruff] +lint.select = ["E", "F", "I", "Q"] +lint.extend-ignore = ["F841"] +line-length = 100 + +[tool.ruff.lint.isort] +required-imports = ["from __future__ import annotations"] + +[tool.ruff.lint.per-file-ignores] +"__init__.py" = ["F401", "E402"] + +[tool.ruff.lint.mccabe] +max-complexity = 15 + +[tool.ruff.lint.flake8-quotes] +docstring-quotes = "double" + +[lint.black] +line-length = 100 +include = '\.pyi?$' +exclude = ''' +/( + \.git + | \.hg + | \.mypy_cache + | \.tox + | 
\.venv + | _build + | buck-out + | build + | dist +)/ +''' + +[lint.isort] +line_length = 100 +combine_as_imports = true +balanced_wrapping = true +lines_after_imports = 2 +include_trailing_comma = true +multi_line_output = 5 + +[lint.mypy] +python_version = "3.10" +warn_return_any = true +warn_unused_configs = true +disallow_untyped_defs = true +no_implicit_optional = false +ignore_missing_imports = true diff --git a/qadence_embeddings/__init__.py b/qadence_embeddings/__init__.py new file mode 100644 index 0000000..da6685a --- /dev/null +++ b/qadence_embeddings/__init__.py @@ -0,0 +1,4 @@ +from __future__ import annotations + +from .callable import ConcretizedCallable +from .embedding import Embedding diff --git a/qadence_embeddings/callable.py b/qadence_embeddings/callable.py new file mode 100644 index 0000000..7061472 --- /dev/null +++ b/qadence_embeddings/callable.py @@ -0,0 +1,83 @@ +from __future__ import annotations + +from importlib import import_module +from logging import getLogger +from typing import Callable + +from numpy.typing import ArrayLike + +logger = getLogger(__name__) + +ARRAYLIKE_FN_MAP = { + "torch": ("torch", "tensor"), + "jax": ("jax.numpy", "array"), + "numpy": ("numpy", "array"), +} + + +def ConcretizedCallable( + call_name: str, + abstract_args: list[str | float | int], + instruction_mapping: dict[str, Callable] = dict(), + engine_name: str = "torch", +) -> Callable[[dict, dict], ArrayLike]: + """Convert a generic abstract function call and + a list of symbolic or constant parameters + into a concretized Callable in a particular engine. + which can be evaluated using + a vparams and inputs dict. + + Arguments: + call_name: The name of the function + abstract_args: A list of strings (in the case of parameters) and numeric constants + denoting the arguments for `call_name` + instruction_mapping: A dict mapping from an abstract call_name to its name in an engine. + engine_name: The engine to use to create the callable. 
+ + Example: + ``` + In [11]: call = ConcretizedCallable('sin', ['x'], engine_name='numpy') + In [12]: call({'x': 0.5}) + Out[12]: 0.479425538604203 + + In [13]: call = ConcretizedCallable('sin', ['x'], engine_name='torch') + In [14]: call({'x': torch.rand(1)}) + Out[14]: tensor([0.5531]) + + In [15]: call = ConcretizedCallable('sin', ['x'], engine_name='jax') + In [16]: call({'x': 0.5}) + Out[16]: Array(0.47942555, dtype=float32, weak_type=True) + ``` + """ + engine_call = None + engine = None + try: + engine_name, fn_name = ARRAYLIKE_FN_MAP[engine_name] + engine = import_module(engine_name) + arraylike_fn = getattr(engine, fn_name) + except (ModuleNotFoundError, ImportError) as e: + logger.error(f"Unable to import {engine_call} due to {e}.") + + try: + engine_call = getattr(engine, call_name) + except ImportError: + pass + if engine_call is None: + try: + engine_call = instruction_mapping[call_name] + except KeyError as e: + logger.error( + f"Requested function {call_name} can not be imported from {engine_name} and is\ + not in instruction_mapping {instruction_mapping} due to {e}." 
+ ) + + def evaluate(params: dict = dict(), inputs: dict = dict()) -> ArrayLike: + arraylike_args = [] + for symbol_or_numeric in abstract_args: + if isinstance(symbol_or_numeric, (float, int)): + arraylike_args.append(arraylike_fn(symbol_or_numeric)) + elif isinstance(symbol_or_numeric, str): + arraylike_args.append({**params, **inputs}[symbol_or_numeric]) + return engine_call(*arraylike_args) # type: ignore[misc] + + return evaluate diff --git a/qadence_embeddings/embedding.py b/qadence_embeddings/embedding.py new file mode 100644 index 0000000..3b2b5e0 --- /dev/null +++ b/qadence_embeddings/embedding.py @@ -0,0 +1,38 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import Any, Callable, Optional + + +class ParameterType: + pass + + +class DType: + pass + + +class Embedding(ABC): + """ + A generic module class to hold and handle the parameters and expressions + functions coming from the `Model`. It may contain the list of user input + parameters, as well as the trainable variational parameters and the + evaluated functions from the data types being used, i.e. torch, numpy, etc. 
+ """ + + vparams: dict[str, ParameterType] + fparams: dict[str, Optional[ParameterType]] + mapped_vars: dict[str, Callable] + _dtype: DType + + @abstractmethod + def __call__(self, *args: Any, **kwargs: Any) -> dict[str, Callable]: + raise NotImplementedError() + + @abstractmethod + def name_mapping(self) -> dict: + raise NotImplementedError() + + @property + def dtype(self) -> DType: + return self._dtype diff --git a/tests/test_callable.py b/tests/test_callable.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_embedding.py b/tests/test_embedding.py new file mode 100644 index 0000000..e69de29 From 0296b592ff62cf7a6fe19404b474645ee3f13b41 Mon Sep 17 00:00:00 2001 From: seitzdom Date: Thu, 4 Jul 2024 13:13:46 +0200 Subject: [PATCH 02/17] rework to class, add tests --- pyproject.toml | 2 +- qadence_embeddings/callable.py | 94 ++++++++++++++++++++++------------ tests/test_callable.py | 63 +++++++++++++++++++++++ 3 files changed, 124 insertions(+), 35 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 2515bfc..eb2c988 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,7 +23,7 @@ classifiers=[ "Programming Language :: Python :: Implementation :: CPython", "Programming Language :: Python :: Implementation :: PyPy", ] -dependencies = ["torch", "numpy"] +dependencies = ["torch", "numpy", "jax"] [project.optional-dependencies] dev = ["flaky","black", "pytest", "pytest-xdist", "pytest-cov", "flake8", "mypy", "pre-commit", "ruff", "nbconvert", "matplotlib", "qutip~=4.7.5"] diff --git a/qadence_embeddings/callable.py b/qadence_embeddings/callable.py index 7061472..408da7e 100644 --- a/qadence_embeddings/callable.py +++ b/qadence_embeddings/callable.py @@ -15,17 +15,63 @@ } -def ConcretizedCallable( - call_name: str, - abstract_args: list[str | float | int], - instruction_mapping: dict[str, Callable] = dict(), - engine_name: str = "torch", -) -> Callable[[dict, dict], ArrayLike]: +DEFAULT_JAX_MAPPING = {"mul": ("jax.numpy", "multiply")} 
+DEFAULT_TORCH_MAPPING = {} +DEFAULT_NUMPY_MAPPING = {"mul": ("numpy", "multiply")} + +DEFAULT_INSTRUCTION_MAPPING = { + "torch": DEFAULT_TORCH_MAPPING, + "jax": DEFAULT_JAX_MAPPING, + "numpy": DEFAULT_NUMPY_MAPPING, +} + + +class ConcretizedCallable: + + def __init__( + self, + call_name: str, + abstract_args: list[str | float | int], + instruction_mapping: dict[str, Callable] = dict(), + engine_name: str = "torch", + ) -> None: + instruction_mapping = { + **instruction_mapping, + **DEFAULT_INSTRUCTION_MAPPING[engine_name], + } + self.call_name = call_name + self.abstract_args = abstract_args + self.engine_name = engine_name + self.engine_call = None + engine_call = None + engine = None + try: + engine_name, fn_name = ARRAYLIKE_FN_MAP[engine_name] + engine = import_module(engine_name) + self.arraylike_fn = getattr(engine, fn_name) + except (ModuleNotFoundError, ImportError) as e: + logger.error(f"Unable to import {engine_call} due to {e}.") + + try: + # breakpoint() + try: + self.engine_call = getattr(engine, call_name) + except AttributeError: + pass + if self.engine_call is None: + mod, fn = instruction_mapping[call_name] + self.engine_call = getattr(import_module(mod), fn) + except (ImportError, KeyError) as e: + logger.error( + f"Requested function {call_name} can not be imported from {engine_name} and is\ + not in instruction_mapping {instruction_mapping} due to {e}." + ) + """Convert a generic abstract function call and a list of symbolic or constant parameters into a concretized Callable in a particular engine. which can be evaluated using - a vparams and inputs dict. + an inputs dict. 
Arguments: call_name: The name of the function @@ -49,35 +95,15 @@ def ConcretizedCallable( Out[16]: Array(0.47942555, dtype=float32, weak_type=True) ``` """ - engine_call = None - engine = None - try: - engine_name, fn_name = ARRAYLIKE_FN_MAP[engine_name] - engine = import_module(engine_name) - arraylike_fn = getattr(engine, fn_name) - except (ModuleNotFoundError, ImportError) as e: - logger.error(f"Unable to import {engine_call} due to {e}.") - - try: - engine_call = getattr(engine, call_name) - except ImportError: - pass - if engine_call is None: - try: - engine_call = instruction_mapping[call_name] - except KeyError as e: - logger.error( - f"Requested function {call_name} can not be imported from {engine_name} and is\ - not in instruction_mapping {instruction_mapping} due to {e}." - ) - def evaluate(params: dict = dict(), inputs: dict = dict()) -> ArrayLike: + def evaluate(self, inputs: dict[str, ArrayLike] = dict()) -> ArrayLike: arraylike_args = [] - for symbol_or_numeric in abstract_args: + for symbol_or_numeric in self.abstract_args: if isinstance(symbol_or_numeric, (float, int)): - arraylike_args.append(arraylike_fn(symbol_or_numeric)) + arraylike_args.append(self.arraylike_fn(symbol_or_numeric)) elif isinstance(symbol_or_numeric, str): - arraylike_args.append({**params, **inputs}[symbol_or_numeric]) - return engine_call(*arraylike_args) # type: ignore[misc] + arraylike_args.append(inputs[symbol_or_numeric]) + return self.engine_call(*arraylike_args) # type: ignore[misc] - return evaluate + def __call__(self, inputs: dict[str, ArrayLike] = dict()) -> ArrayLike: + return self.evaluate(inputs) diff --git a/tests/test_callable.py b/tests/test_callable.py index e69de29..8ba1842 100644 --- a/tests/test_callable.py +++ b/tests/test_callable.py @@ -0,0 +1,63 @@ +from __future__ import annotations + +import numpy as np +import torch + +from qadence_embeddings.callable import ConcretizedCallable + + +def test_sin() -> None: + results = [] + x = np.random.randn(1) + 
for engine_name in ["jax", "torch", "numpy"]: + native_call = ConcretizedCallable("sin", ["x"], {}, engine_name) + native_result = native_call( + {"x": (torch.tensor(x) if engine_name == "torch" else x)} + ) + results.append(native_result.item()) + assert np.allclose(results[0], results[1]) and np.allclose(results[0], results[2]) + + +def test_log() -> None: + results = [] + x = np.random.uniform(0, 5) + for engine_name in ["jax", "torch", "numpy"]: + native_call = ConcretizedCallable("log", ["x"], {}, engine_name) + native_result = native_call( + {"x": (torch.tensor(x) if engine_name == "torch" else x)} + ) + results.append(native_result.item()) + + assert np.allclose(results[0], results[1]) and np.allclose(results[0], results[2]) + + +def test_mul() -> None: + results = [] + x = np.random.randn(1) + y = np.random.randn(1) + for engine_name in ["jax", "torch", "numpy"]: + native_call = ConcretizedCallable("mul", ["x", "y"], {}, engine_name) + native_result = native_call( + { + "x": torch.tensor(x) if engine_name == "torch" else x, + "y": torch.tensor(y) if engine_name == "torch" else y, + } + ) + results.append(native_result.item()) + assert np.allclose(results[0], results[1]) and np.allclose(results[0], results[2]) + + +def test_add() -> None: + results = [] + x = np.random.randn(1) + y = np.random.randn(1) + for engine_name in ["jax", "torch", "numpy"]: + native_call = ConcretizedCallable("add", ["x", "y"], {}, engine_name) + native_result = native_call( + { + "x": torch.tensor(x) if engine_name == "torch" else x, + "y": torch.tensor(y) if engine_name == "torch" else y, + } + ) + results.append(native_result.item()) + assert np.allclose(results[0], results[1]) and np.allclose(results[0], results[2]) From e6dd59f9ce6b092600fd33d54674fe0f1c4f4561 Mon Sep 17 00:00:00 2001 From: seitzdom Date: Thu, 4 Jul 2024 13:30:03 +0200 Subject: [PATCH 03/17] add arithmetics tests --- qadence_embeddings/callable.py | 12 ++++++++++-- tests/test_callable.py | 32 
++++++++++++++++++++++++++++++++ 2 files changed, 42 insertions(+), 2 deletions(-) diff --git a/qadence_embeddings/callable.py b/qadence_embeddings/callable.py index 408da7e..c3710bb 100644 --- a/qadence_embeddings/callable.py +++ b/qadence_embeddings/callable.py @@ -15,9 +15,17 @@ } -DEFAULT_JAX_MAPPING = {"mul": ("jax.numpy", "multiply")} +DEFAULT_JAX_MAPPING = { + "mul": ("jax.numpy", "multiply"), + "sub": ("jax.numpy", "subtract"), + "div": ("jax.numpy", "divide"), +} DEFAULT_TORCH_MAPPING = {} -DEFAULT_NUMPY_MAPPING = {"mul": ("numpy", "multiply")} +DEFAULT_NUMPY_MAPPING = { + "mul": ("numpy", "multiply"), + "sub": ("numpy", "subtract"), + "div": ("numpy", "divide"), +} DEFAULT_INSTRUCTION_MAPPING = { "torch": DEFAULT_TORCH_MAPPING, diff --git a/tests/test_callable.py b/tests/test_callable.py index 8ba1842..ad66c53 100644 --- a/tests/test_callable.py +++ b/tests/test_callable.py @@ -61,3 +61,35 @@ def test_add() -> None: ) results.append(native_result.item()) assert np.allclose(results[0], results[1]) and np.allclose(results[0], results[2]) + + +def test_subtract() -> None: + results = [] + x = np.random.randn(1) + y = np.random.randn(1) + for engine_name in ["jax", "torch", "numpy"]: + native_call = ConcretizedCallable("sub", ["x", "y"], {}, engine_name) + native_result = native_call( + { + "x": torch.tensor(x) if engine_name == "torch" else x, + "y": torch.tensor(y) if engine_name == "torch" else y, + } + ) + results.append(native_result.item()) + assert np.allclose(results[0], results[1]) and np.allclose(results[0], results[2]) + + +def test_div() -> None: + results = [] + x = np.random.randn(1) + y = np.random.randn(1) + for engine_name in ["jax", "torch", "numpy"]: + native_call = ConcretizedCallable("div", ["x", "y"], {}, engine_name) + native_result = native_call( + { + "x": torch.tensor(x) if engine_name == "torch" else x, + "y": torch.tensor(y) if engine_name == "torch" else y, + } + ) + results.append(native_result.item()) + assert 
np.allclose(results[0], results[1]) and np.allclose(results[0], results[2]) From 0a7fbea8dc60fc758d4cfdc1f296909fc593b789 Mon Sep 17 00:00:00 2001 From: seitzdom Date: Thu, 4 Jul 2024 13:31:13 +0200 Subject: [PATCH 04/17] more tests --- tests/test_callable.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/tests/test_callable.py b/tests/test_callable.py index ad66c53..b87a04b 100644 --- a/tests/test_callable.py +++ b/tests/test_callable.py @@ -18,6 +18,18 @@ def test_sin() -> None: assert np.allclose(results[0], results[1]) and np.allclose(results[0], results[2]) +def test_cos() -> None: + results = [] + x = np.random.randn(1) + for engine_name in ["jax", "torch", "numpy"]: + native_call = ConcretizedCallable("cos", ["x"], {}, engine_name) + native_result = native_call( + {"x": (torch.tensor(x) if engine_name == "torch" else x)} + ) + results.append(native_result.item()) + assert np.allclose(results[0], results[1]) and np.allclose(results[0], results[2]) + + def test_log() -> None: results = [] x = np.random.uniform(0, 5) From 2586c0ea92b387a22954ba7e498a237076a096d1 Mon Sep 17 00:00:00 2001 From: seitzdom Date: Thu, 4 Jul 2024 13:37:36 +0200 Subject: [PATCH 05/17] mark parametrize --- qadence_embeddings/callable.py | 1 - tests/test_callable.py | 86 ++++------------------------------ 2 files changed, 8 insertions(+), 79 deletions(-) diff --git a/qadence_embeddings/callable.py b/qadence_embeddings/callable.py index c3710bb..3afa2a8 100644 --- a/qadence_embeddings/callable.py +++ b/qadence_embeddings/callable.py @@ -61,7 +61,6 @@ def __init__( logger.error(f"Unable to import {engine_call} due to {e}.") try: - # breakpoint() try: self.engine_call = getattr(engine, call_name) except AttributeError: diff --git a/tests/test_callable.py b/tests/test_callable.py index b87a04b..29cf7cf 100644 --- a/tests/test_callable.py +++ b/tests/test_callable.py @@ -1,102 +1,32 @@ from __future__ import annotations import numpy as np +import pytest import torch from 
qadence_embeddings.callable import ConcretizedCallable -def test_sin() -> None: +@pytest.mark.parametrize("fn", ["sin", "cos", "log", "tanh", "sin", "sqrt", "square"]) +def test_univariate(fn: str) -> None: results = [] - x = np.random.randn(1) - for engine_name in ["jax", "torch", "numpy"]: - native_call = ConcretizedCallable("sin", ["x"], {}, engine_name) - native_result = native_call( - {"x": (torch.tensor(x) if engine_name == "torch" else x)} - ) - results.append(native_result.item()) - assert np.allclose(results[0], results[1]) and np.allclose(results[0], results[2]) - - -def test_cos() -> None: - results = [] - x = np.random.randn(1) - for engine_name in ["jax", "torch", "numpy"]: - native_call = ConcretizedCallable("cos", ["x"], {}, engine_name) - native_result = native_call( - {"x": (torch.tensor(x) if engine_name == "torch" else x)} - ) - results.append(native_result.item()) - assert np.allclose(results[0], results[1]) and np.allclose(results[0], results[2]) - - -def test_log() -> None: - results = [] - x = np.random.uniform(0, 5) + x = np.random.uniform(0, 1) for engine_name in ["jax", "torch", "numpy"]: - native_call = ConcretizedCallable("log", ["x"], {}, engine_name) + native_call = ConcretizedCallable(fn, ["x"], {}, engine_name) native_result = native_call( {"x": (torch.tensor(x) if engine_name == "torch" else x)} ) results.append(native_result.item()) - - assert np.allclose(results[0], results[1]) and np.allclose(results[0], results[2]) - - -def test_mul() -> None: - results = [] - x = np.random.randn(1) - y = np.random.randn(1) - for engine_name in ["jax", "torch", "numpy"]: - native_call = ConcretizedCallable("mul", ["x", "y"], {}, engine_name) - native_result = native_call( - { - "x": torch.tensor(x) if engine_name == "torch" else x, - "y": torch.tensor(y) if engine_name == "torch" else y, - } - ) - results.append(native_result.item()) - assert np.allclose(results[0], results[1]) and np.allclose(results[0], results[2]) - - -def test_add() -> None: 
- results = [] - x = np.random.randn(1) - y = np.random.randn(1) - for engine_name in ["jax", "torch", "numpy"]: - native_call = ConcretizedCallable("add", ["x", "y"], {}, engine_name) - native_result = native_call( - { - "x": torch.tensor(x) if engine_name == "torch" else x, - "y": torch.tensor(y) if engine_name == "torch" else y, - } - ) - results.append(native_result.item()) - assert np.allclose(results[0], results[1]) and np.allclose(results[0], results[2]) - - -def test_subtract() -> None: - results = [] - x = np.random.randn(1) - y = np.random.randn(1) - for engine_name in ["jax", "torch", "numpy"]: - native_call = ConcretizedCallable("sub", ["x", "y"], {}, engine_name) - native_result = native_call( - { - "x": torch.tensor(x) if engine_name == "torch" else x, - "y": torch.tensor(y) if engine_name == "torch" else y, - } - ) - results.append(native_result.item()) assert np.allclose(results[0], results[1]) and np.allclose(results[0], results[2]) -def test_div() -> None: +@pytest.mark.parametrize("fn", ["mul", "add", "div", "sub"]) +def test_multivariate(fn: str) -> None: results = [] x = np.random.randn(1) y = np.random.randn(1) for engine_name in ["jax", "torch", "numpy"]: - native_call = ConcretizedCallable("div", ["x", "y"], {}, engine_name) + native_call = ConcretizedCallable(fn, ["x", "y"], {}, engine_name) native_result = native_call( { "x": torch.tensor(x) if engine_name == "torch" else x, From 8457f98f3f4dbdb5e97535c528228e514e1214a7 Mon Sep 17 00:00:00 2001 From: seitzdom Date: Thu, 4 Jul 2024 13:41:02 +0200 Subject: [PATCH 06/17] test more fns --- tests/test_callable.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/test_callable.py b/tests/test_callable.py index 29cf7cf..5faf29b 100644 --- a/tests/test_callable.py +++ b/tests/test_callable.py @@ -7,7 +7,9 @@ from qadence_embeddings.callable import ConcretizedCallable -@pytest.mark.parametrize("fn", ["sin", "cos", "log", "tanh", "sin", "sqrt", "square"]) 
+@pytest.mark.parametrize( + "fn", ["sin", "cos", "log", "tanh", "tan", "acos", "sin", "sqrt", "square"] +) def test_univariate(fn: str) -> None: results = [] x = np.random.uniform(0, 1) From 5ac4327b8dc7aefaa023823d8e564b198bce686a Mon Sep 17 00:00:00 2001 From: seitzdom Date: Thu, 4 Jul 2024 13:43:19 +0200 Subject: [PATCH 07/17] add infra --- .github/workflows/run-tests-and-mypy.yml | 143 +++++++++++++++++++++++ .pre-commit-config.yaml | 26 +++++ 2 files changed, 169 insertions(+) create mode 100644 .github/workflows/run-tests-and-mypy.yml create mode 100644 .pre-commit-config.yaml diff --git a/.github/workflows/run-tests-and-mypy.yml b/.github/workflows/run-tests-and-mypy.yml new file mode 100644 index 0000000..687190b --- /dev/null +++ b/.github/workflows/run-tests-and-mypy.yml @@ -0,0 +1,143 @@ +name: Linting / Tests / Documentation + +on: + push: + branches: + - main + tags: + - 'v[0-9]+.[0-9]+.[0-9]+' + pull_request: + branches: + - main + workflow_dispatch: {} + +concurrency: + group: ${{ github.head_ref || github.run_id }} + cancel-in-progress: true + +jobs: + lint: + name: Linting + runs-on: ubuntu-latest + strategy: + matrix: + python-version: ["3.10"] + steps: + - name: Checkout main code and submodules + uses: actions/checkout@v4 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.python-version }} + - name: Install hatch + run: | + python -m pip install hatch + - name: Install main dependencies + run: | + python -m hatch -v -e tests + - name: Lint + run: | + python -m hatch -e tests run pre-commit run --all-files + + unit_tests: + name: Unit testing + runs-on: ubuntu-latest + strategy: + matrix: + python-version: ["3.8","3.9", "3.10", "3.11", "3.12"] + steps: + - name: Checkout main code and submodules + uses: actions/checkout@v4 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.python-version }} + - name: 
Install hatch + run: | + python -m pip install hatch + - name: Install main dependencies + run: | + python -m hatch -v -e tests + - name: Lint + run: | + python -m hatch -e tests run pre-commit run --all-files + - name: Perform unit tests + run: | + python -m hatch -e tests run pytest -n auto + + test_docs: + name: Documentation + runs-on: ubuntu-latest + strategy: + matrix: + python-version: ["3.10"] + steps: + - name: Checkout main code and submodules + uses: actions/checkout@v4 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.python-version }} + - name: Install hatch + run: | + python -m pip install hatch + - name: Install main dependencies + run: | + python -m hatch -v -e docs + - name: Test docs + run: | + python -m hatch run docs:build + + publish: + name: Publish to PyPI + if: startsWith(github.ref, 'refs/tags/v') + needs: unit_tests + runs-on: ubuntu-latest + steps: + - name: Checkout main code and submodules + uses: actions/checkout@v4 + with: + ref: ${{ github.ref }} + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: "3.10" + - name: Install Python dependencies + run: | + python -m pip install --upgrade pip + pip install hatch + - name: Build and publish package + run: | + hatch build + hatch publish -u __token__ -a ${{ secrets.PYPI_API_TOKEN }} + - name: Confirm deployment + timeout-minutes: 5 + run: | + VERSION=${GITHUB_REF#refs/tags/v} + until pip download qadence-embeddings==$VERSION + do + echo "Failed to download from PyPI, will wait for upload and retry." 
+ sleep 1 + done + + deploy_docs: + name: Deploy documentation + if: startsWith(github.ref, 'refs/tags/v') + needs: unit_tests + runs-on: ubuntu-latest + steps: + - name: Checkout main code and submodules + uses: actions/checkout@v4 + - name: Set up Python 3.10 + uses: actions/setup-python@v4 + with: + python-version: '3.10' + - name: Install Hatch + run: | + pip install hatch + - name: Deploy docs + run: | + git config user.name "GitHub Actions" + git config user.email "actions@github.com" + git fetch origin gh-pages + hatch -v run docs:mike deploy --push --update-aliases ${{ github.ref_name }} latest diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..a5ad010 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,26 @@ +repos: +- repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.5.0 + hooks: + - id: trailing-whitespace + - id: end-of-file-fixer + - id: check-yaml + - id: check-added-large-files + +- repo: https://github.com/ambv/black + rev: 24.4.2 + hooks: + - id: black + +- repo: https://github.com/charliermarsh/ruff-pre-commit + rev: "v0.4.4" + hooks: + - id: ruff + args: [--fix] + + +- repo: https://github.com/pre-commit/mirrors-mypy + rev: v1.10.0 + hooks: + - id: mypy + exclude: examples|docs From 081d737d7d844d609a5b11ccc271dae386b22722 Mon Sep 17 00:00:00 2001 From: seitzdom Date: Thu, 4 Jul 2024 16:57:44 +0200 Subject: [PATCH 08/17] add and test reembedding --- qadence_embeddings/embedding.py | 129 +++++++++++++++++++++++++++----- tests/test_embedding.py | 70 +++++++++++++++++ 2 files changed, 181 insertions(+), 18 deletions(-) diff --git a/qadence_embeddings/embedding.py b/qadence_embeddings/embedding.py index 3b2b5e0..532117c 100644 --- a/qadence_embeddings/embedding.py +++ b/qadence_embeddings/embedding.py @@ -1,18 +1,27 @@ from __future__ import annotations -from abc import ABC, abstractmethod -from typing import Any, Callable, Optional +from importlib import import_module +from logging 
import getLogger +from typing import Any, Optional +from numpy.typing import ArrayLike, DTypeLike -class ParameterType: - pass +from .callable import ConcretizedCallable +logger = getLogger(__name__) -class DType: - pass +def init_param(engine_name: str, trainable: bool = True) -> ArrayLike: + engine = import_module(engine_name) + if engine_name == "jax": + return engine.random.uniform(engine.random.PRNGKey(42), shape=(1,)) + elif engine_name == "torch": + return engine.rand(1, requires_grad=trainable) + elif engine_name == "numpy": + return engine.random.uniform(0, 1) -class Embedding(ABC): + +class Embedding: """ A generic module class to hold and handle the parameters and expressions functions coming from the `Model`. It may contain the list of user input @@ -20,19 +29,103 @@ class Embedding(ABC): evaluated functions from the data types being used, i.e. torch, numpy, etc. """ - vparams: dict[str, ParameterType] - fparams: dict[str, Optional[ParameterType]] - mapped_vars: dict[str, Callable] - _dtype: DType + def __init__( + self, + vparam_names: list[str], + fparam_names: list[str], + tparam_names: Optional[list[str]], + var_to_call: dict[str, ConcretizedCallable], + engine_name: str = "torch", + ) -> None: + self.vparams = { + vp: init_param(engine_name, trainable=True) for vp in vparam_names + } + self.fparams: dict[str, Optional[ArrayLike]] = {fp: None for fp in fparam_names} + self.tparams: dict[str, Optional[ArrayLike]] = ( + None if tparam_names is None else {fp: None for fp in tparam_names} + ) + self.var_to_call: dict[str, ConcretizedCallable] = var_to_call + self._dtype: DTypeLike = None + self.fparams_assigned: bool = False + + def flush_fparams(self) -> None: + """Flush all stored fparams and set them to None.""" + self.fparams = {key: None for key in self.fparams.keys()} + self.fparams_assigned = False + + def assign_fparams( + self, inputs: dict[str, ArrayLike | None], flush_current: bool = True + ) -> None: + """Mutate the `self.fparams` field to 
store inputs from the user.""" + if self.fparams_assigned: + ( + self.flush_fparams() + if flush_current + else logger.error( + "Fparams are still assigned. Please flush them before re-embedding." + ) + ) + if not inputs.keys() == self.fparams.keys(): + logger.error( + f"Please provide all fparams, Expected {self.fparams.keys()}, received {inputs.keys()}." + ) + self.fparams = inputs + self.fparams_assigned = True + + def evaluate_param( + self, param_name: str, inputs: dict[str, ArrayLike] + ) -> ArrayLike: + """Returns the result of evaluation an expression in `var_to_call`.""" + return self.var_to_call[param_name](inputs) - @abstractmethod - def __call__(self, *args: Any, **kwargs: Any) -> dict[str, Callable]: - raise NotImplementedError() + def embed_all( + self, + inputs: dict[str, ArrayLike], + include_root_vars: bool = True, + store_inputs: bool = False, + ) -> dict[str, ArrayLike]: + """The standard embedding of all intermediate and leaf parameters. + Include the root_params, i.e., the vparams and fparams original values + to be reused in computations. + """ + if not include_root_vars: + logger.error( + "Warning: Original parameters are not included, only intermediates and leaves." + ) + if store_inputs: + self.assign_fparams(inputs) + for intermediate_or_leaf_var, engine_callable in self.var_to_call.items(): + # We mutate the original inputs dict and include intermediates and leaves. 
+ inputs[intermediate_or_leaf_var] = engine_callable(inputs) + return inputs - @abstractmethod - def name_mapping(self) -> dict: - raise NotImplementedError() + def reembed(self, inputs: dict[str, ArrayLike]) -> dict[str, ArrayLike]: + assert ( + self.fparams_assigned + ), "To reembed, please store original fparam values by setting\ + `include_root_vars = True` when calling `embed_all`" + + # We filter out intermediates and leaves and leave only the original vparams and fparams + + # the `inputs` dict which contains new pairs + inputs = { + p: v + for p, v in inputs.items() + if p in self.vparams.keys() or p in self.fparams.keys() + } + return self.embed_all({**self.vparams, **self.fparams, **inputs}) + + def __call__(self, inputs: dict[str, ArrayLike]) -> dict[str, ArrayLike]: + """Functional version of legacy embedding: Return a new dictionary with all embedded parameters.""" + return self.embed_all(inputs) + + # @abstractmethod + # def name_mapping(self) -> dict: + # raise NotImplementedError() @property - def dtype(self) -> DType: + def dtype(self) -> DTypeLike: return self._dtype + + def to(self, args: Any, kwargs: Any) -> None: + # TODO move to device and dtype + pass diff --git a/tests/test_embedding.py b/tests/test_embedding.py index e69de29..7be6e42 100644 --- a/tests/test_embedding.py +++ b/tests/test_embedding.py @@ -0,0 +1,70 @@ +from __future__ import annotations + +import numpy as np +import torch + +from qadence_embeddings.callable import ConcretizedCallable +from qadence_embeddings.embedding import Embedding + + +def test_embedding() -> None: + x = np.random.uniform(0, 1) + theta = np.random.uniform(0, 1) + results = [] + for engine_name in ["jax", "torch", "numpy"]: + v_params = ["theta"] + f_params = ["x"] + leaf0, native_call0 = "%0", ConcretizedCallable( + "mul", ["x", "theta"], {}, engine_name + ) + embedding = Embedding( + v_params, + f_params, + None, + var_to_call={leaf0: native_call0}, + engine_name=engine_name, + ) + inputs = { + "x": 
(torch.tensor(x) if engine_name == "torch" else x), + "theta": (torch.tensor(theta) if engine_name == "torch" else theta), + } + eval_0 = embedding.evaluate_param("%0", inputs) + results.append(eval_0.item()) + assert np.allclose(results[0], results[1]) and np.allclose(results[0], results[2]) + + +def test_reembedding() -> None: + x = np.random.uniform(0, 1) + theta = np.random.uniform(0, 1) + x_rembed = np.random.uniform(0, 1) + results = [] + reembedded_results = [] + for engine_name in ["jax", "torch", "numpy"]: + v_params = ["theta"] + f_params = ["x"] + leaf0, native_call0 = "%0", ConcretizedCallable( + "mul", ["x", "theta"], {}, engine_name + ) + embedding = Embedding( + v_params, + f_params, + None, + var_to_call={leaf0: native_call0}, + engine_name=engine_name, + ) + inputs = { + "x": (torch.tensor(x) if engine_name == "torch" else x), + "theta": (torch.tensor(theta) if engine_name == "torch" else theta), + } + all_params = embedding.embed_all( + inputs, include_root_vars=True, store_inputs=True + ) + reembedded_params = embedding.reembed( + {"x": (torch.tensor(x_rembed) if engine_name == "torch" else x_rembed)} + ) + results.append(all_params["%0"].item()) + reembedded_results.append(reembedded_params["%0"].item()) + assert np.allclose(results[0], results[1]) and np.allclose(results[0], results[2]) + assert np.allclose(reembedded_results[0], reembedded_results[1]) and np.allclose( + reembedded_results[0], reembedded_results[2] + ) From 329aed08f20263bac11aaacf092329af479a4ba2 Mon Sep 17 00:00:00 2001 From: seitzdom Date: Thu, 4 Jul 2024 18:13:16 +0200 Subject: [PATCH 09/17] rename --- qadence_embeddings/embedding.py | 2 +- tests/test_embedding.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/qadence_embeddings/embedding.py b/qadence_embeddings/embedding.py index 532117c..2e5f0a4 100644 --- a/qadence_embeddings/embedding.py +++ b/qadence_embeddings/embedding.py @@ -99,7 +99,7 @@ def embed_all( inputs[intermediate_or_leaf_var] = 
engine_callable(inputs) return inputs - def reembed(self, inputs: dict[str, ArrayLike]) -> dict[str, ArrayLike]: + def reembed_all(self, inputs: dict[str, ArrayLike]) -> dict[str, ArrayLike]: assert ( self.fparams_assigned ), "To reembed, please store original fparam values by setting\ diff --git a/tests/test_embedding.py b/tests/test_embedding.py index 7be6e42..84ed881 100644 --- a/tests/test_embedding.py +++ b/tests/test_embedding.py @@ -59,7 +59,7 @@ def test_reembedding() -> None: all_params = embedding.embed_all( inputs, include_root_vars=True, store_inputs=True ) - reembedded_params = embedding.reembed( + reembedded_params = embedding.reembed_all( {"x": (torch.tensor(x_rembed) if engine_name == "torch" else x_rembed)} ) results.append(all_params["%0"].item()) From 99b7ef6aa65e320b9bdb83c96dab183733c73cf8 Mon Sep 17 00:00:00 2001 From: seitzdom Date: Thu, 4 Jul 2024 18:20:38 +0200 Subject: [PATCH 10/17] fix lint --- docs/index.md | 2 ++ qadence_embeddings/callable.py | 6 +++--- qadence_embeddings/embedding.py | 13 +++++++------ 3 files changed, 12 insertions(+), 9 deletions(-) diff --git a/docs/index.md b/docs/index.md index e69de29..67261ba 100644 --- a/docs/index.md +++ b/docs/index.md @@ -0,0 +1,2 @@ +::: qadence_embeddings.callable.ConcretizedCallable +::: qadence_embeddings.embedding.Embedding diff --git a/qadence_embeddings/callable.py b/qadence_embeddings/callable.py index 3afa2a8..f0bf090 100644 --- a/qadence_embeddings/callable.py +++ b/qadence_embeddings/callable.py @@ -2,7 +2,7 @@ from importlib import import_module from logging import getLogger -from typing import Callable +from typing import Tuple from numpy.typing import ArrayLike @@ -20,7 +20,7 @@ "sub": ("jax.numpy", "subtract"), "div": ("jax.numpy", "divide"), } -DEFAULT_TORCH_MAPPING = {} +DEFAULT_TORCH_MAPPING: dict = {} DEFAULT_NUMPY_MAPPING = { "mul": ("numpy", "multiply"), "sub": ("numpy", "subtract"), @@ -40,7 +40,7 @@ def __init__( self, call_name: str, abstract_args: list[str | float | 
int], - instruction_mapping: dict[str, Callable] = dict(), + instruction_mapping: dict[str, Tuple[str, str]] = dict(), engine_name: str = "torch", ) -> None: instruction_mapping = { diff --git a/qadence_embeddings/embedding.py b/qadence_embeddings/embedding.py index 2e5f0a4..cd1c0eb 100644 --- a/qadence_embeddings/embedding.py +++ b/qadence_embeddings/embedding.py @@ -42,7 +42,9 @@ def __init__( } self.fparams: dict[str, Optional[ArrayLike]] = {fp: None for fp in fparam_names} self.tparams: dict[str, Optional[ArrayLike]] = ( - None if tparam_names is None else {fp: None for fp in tparam_names} + None + if tparam_names is None + else {fp: None for fp in tparam_names} # type: ignore[assignment] ) self.var_to_call: dict[str, ConcretizedCallable] = var_to_call self._dtype: DTypeLike = None @@ -67,7 +69,8 @@ def assign_fparams( ) if not inputs.keys() == self.fparams.keys(): logger.error( - f"Please provide all fparams, Expected {self.fparams.keys()}, received {inputs.keys()}." + f"Please provide all fparams, Expected {self.fparams.keys()},\ + received {inputs.keys()}." 
) self.fparams = inputs self.fparams_assigned = True @@ -115,12 +118,10 @@ def reembed_all(self, inputs: dict[str, ArrayLike]) -> dict[str, ArrayLike]: return self.embed_all({**self.vparams, **self.fparams, **inputs}) def __call__(self, inputs: dict[str, ArrayLike]) -> dict[str, ArrayLike]: - """Functional version of legacy embedding: Return a new dictionary with all embedded parameters.""" + """Functional version of legacy embedding: Return a new dictionary\ + with all embedded parameters.""" return self.embed_all(inputs) - # @abstractmethod - # def name_mapping(self) -> dict: - # raise NotImplementedError() @property def dtype(self) -> DTypeLike: From 81ac1c51b7cd6d5fec5478d9f0c3bb46655ea7ef Mon Sep 17 00:00:00 2001 From: seitzdom Date: Thu, 4 Jul 2024 19:07:41 +0200 Subject: [PATCH 11/17] simpler version --- qadence_embeddings/embedding.py | 60 ++++++--------------------------- tests/test_embedding.py | 13 ++++--- 2 files changed, 16 insertions(+), 57 deletions(-) diff --git a/qadence_embeddings/embedding.py b/qadence_embeddings/embedding.py index cd1c0eb..9fa14d7 100644 --- a/qadence_embeddings/embedding.py +++ b/qadence_embeddings/embedding.py @@ -48,81 +48,41 @@ def __init__( ) self.var_to_call: dict[str, ConcretizedCallable] = var_to_call self._dtype: DTypeLike = None - self.fparams_assigned: bool = False - def flush_fparams(self) -> None: - """Flush all stored fparams and set them to None.""" - self.fparams = {key: None for key in self.fparams.keys()} - self.fparams_assigned = False - - def assign_fparams( - self, inputs: dict[str, ArrayLike | None], flush_current: bool = True - ) -> None: - """Mutate the `self.fparams` field to store inputs from the user.""" - if self.fparams_assigned: - ( - self.flush_fparams() - if flush_current - else logger.error( - "Fparams are still assigned. Please flush them before re-embedding." 
- ) - ) - if not inputs.keys() == self.fparams.keys(): - logger.error( - f"Please provide all fparams, Expected {self.fparams.keys()},\ - received {inputs.keys()}." - ) - self.fparams = inputs - self.fparams_assigned = True - - def evaluate_param( - self, param_name: str, inputs: dict[str, ArrayLike] - ) -> ArrayLike: - """Returns the result of evaluation an expression in `var_to_call`.""" - return self.var_to_call[param_name](inputs) + @property + def root_param_names(self) -> list[str]: + return list(self.vparams.keys()) + list(self.fparams.keys()) def embed_all( self, inputs: dict[str, ArrayLike], - include_root_vars: bool = True, - store_inputs: bool = False, ) -> dict[str, ArrayLike]: """The standard embedding of all intermediate and leaf parameters. Include the root_params, i.e., the vparams and fparams original values to be reused in computations. """ - if not include_root_vars: - logger.error( - "Warning: Original parameters are not included, only intermediates and leaves." - ) - if store_inputs: - self.assign_fparams(inputs) for intermediate_or_leaf_var, engine_callable in self.var_to_call.items(): # We mutate the original inputs dict and include intermediates and leaves. 
inputs[intermediate_or_leaf_var] = engine_callable(inputs) return inputs - def reembed_all(self, inputs: dict[str, ArrayLike]) -> dict[str, ArrayLike]: - assert ( - self.fparams_assigned - ), "To reembed, please store original fparam values by setting\ - `include_root_vars = True` when calling `embed_all`" - + def reembed_all( + self, + embedded_params: dict[str, ArrayLike], + new_root_params: dict[str, ArrayLike], + ) -> dict[str, ArrayLike]: # We filter out intermediates and leaves and leave only the original vparams and fparams + # the `inputs` dict which contains new pairs inputs = { - p: v - for p, v in inputs.items() - if p in self.vparams.keys() or p in self.fparams.keys() + p: v for p, v in embedded_params.items() if p in self.root_param_names } - return self.embed_all({**self.vparams, **self.fparams, **inputs}) + return self.embed_all({**self.vparams, **inputs, **new_root_params}) def __call__(self, inputs: dict[str, ArrayLike]) -> dict[str, ArrayLike]: """Functional version of legacy embedding: Return a new dictionary\ with all embedded parameters.""" return self.embed_all(inputs) - @property def dtype(self) -> DTypeLike: return self._dtype diff --git a/tests/test_embedding.py b/tests/test_embedding.py index 84ed881..6379728 100644 --- a/tests/test_embedding.py +++ b/tests/test_embedding.py @@ -28,7 +28,7 @@ def test_embedding() -> None: "x": (torch.tensor(x) if engine_name == "torch" else x), "theta": (torch.tensor(theta) if engine_name == "torch" else theta), } - eval_0 = embedding.evaluate_param("%0", inputs) + eval_0 = embedding.var_to_call["%0"](inputs) results.append(eval_0.item()) assert np.allclose(results[0], results[1]) and np.allclose(results[0], results[2]) @@ -56,12 +56,11 @@ def test_reembedding() -> None: "x": (torch.tensor(x) if engine_name == "torch" else x), "theta": (torch.tensor(theta) if engine_name == "torch" else theta), } - all_params = embedding.embed_all( - inputs, include_root_vars=True, store_inputs=True - ) - reembedded_params 
= embedding.reembed_all( - {"x": (torch.tensor(x_rembed) if engine_name == "torch" else x_rembed)} - ) + all_params = embedding.embed_all(inputs) + new_params = { + "x": (torch.tensor(x_rembed) if engine_name == "torch" else x_rembed) + } + reembedded_params = embedding.reembed_all(all_params, new_params) results.append(all_params["%0"].item()) reembedded_results.append(reembedded_params["%0"].item()) assert np.allclose(results[0], results[1]) and np.allclose(results[0], results[2]) From 98eee1b41634731f6efa51368e332c9f60365704 Mon Sep 17 00:00:00 2001 From: seitzdom Date: Thu, 4 Jul 2024 19:10:22 +0200 Subject: [PATCH 12/17] docstring --- qadence_embeddings/embedding.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/qadence_embeddings/embedding.py b/qadence_embeddings/embedding.py index 9fa14d7..dde6e18 100644 --- a/qadence_embeddings/embedding.py +++ b/qadence_embeddings/embedding.py @@ -71,6 +71,10 @@ def reembed_all( embedded_params: dict[str, ArrayLike], new_root_params: dict[str, ArrayLike], ) -> dict[str, ArrayLike]: + """Receive already embedded params containing intermediate and leaf parameters + and remove them from the `embedded_params` dict to reconstruct the user input, and finally + recalculate the embedding using values for parameters in passes in `new_root_params`. 
+ """ # We filter out intermediates and leaves and leave only the original vparams and fparams + # the `inputs` dict which contains new pairs inputs = { From deb1c2f65baf625c8dcb276ebc72d85b10d72c06 Mon Sep 17 00:00:00 2001 From: seitzdom Date: Fri, 5 Jul 2024 03:26:22 +0200 Subject: [PATCH 13/17] better version --- qadence_embeddings/callable.py | 6 +++- qadence_embeddings/embedding.py | 54 ++++++++++++++++++--------------- tests/test_embedding.py | 24 +++++++++------ 3 files changed, 50 insertions(+), 34 deletions(-) diff --git a/qadence_embeddings/callable.py b/qadence_embeddings/callable.py index f0bf090..f49b899 100644 --- a/qadence_embeddings/callable.py +++ b/qadence_embeddings/callable.py @@ -42,6 +42,7 @@ def __init__( abstract_args: list[str | float | int], instruction_mapping: dict[str, Tuple[str, str]] = dict(), engine_name: str = "torch", + device: str = "cpu", ) -> None: instruction_mapping = { **instruction_mapping, @@ -50,6 +51,7 @@ def __init__( self.call_name = call_name self.abstract_args = abstract_args self.engine_name = engine_name + self.device = device self.engine_call = None engine_call = None engine = None @@ -107,7 +109,9 @@ def evaluate(self, inputs: dict[str, ArrayLike] = dict()) -> ArrayLike: arraylike_args = [] for symbol_or_numeric in self.abstract_args: if isinstance(symbol_or_numeric, (float, int)): - arraylike_args.append(self.arraylike_fn(symbol_or_numeric)) + arraylike_args.append( + self.arraylike_fn(symbol_or_numeric, device=self.device) + ) elif isinstance(symbol_or_numeric, str): arraylike_args.append(inputs[symbol_or_numeric]) return self.engine_call(*arraylike_args) # type: ignore[misc] diff --git a/qadence_embeddings/embedding.py b/qadence_embeddings/embedding.py index dde6e18..e412710 100644 --- a/qadence_embeddings/embedding.py +++ b/qadence_embeddings/embedding.py @@ -11,12 +11,14 @@ logger = getLogger(__name__) -def init_param(engine_name: str, trainable: bool = True) -> ArrayLike: +def init_param( + engine_name: str, 
trainable: bool = True, device: str = "cpu" +) -> ArrayLike: engine = import_module(engine_name) if engine_name == "jax": return engine.random.uniform(engine.random.PRNGKey(42), shape=(1,)) elif engine_name == "torch": - return engine.rand(1, requires_grad=trainable) + return engine.rand(1, requires_grad=trainable, device=device) elif engine_name == "numpy": return engine.random.uniform(0, 1) @@ -31,27 +33,27 @@ class Embedding: def __init__( self, - vparam_names: list[str], - fparam_names: list[str], - tparam_names: Optional[list[str]], - var_to_call: dict[str, ConcretizedCallable], + vparam_names: list[str] = [], + fparam_names: list[str] = [], + tparam_name: Optional[str] = None, + var_to_call: dict[str, ConcretizedCallable] = dict(), engine_name: str = "torch", + device: str = "cpu", ) -> None: + self.vparams = { - vp: init_param(engine_name, trainable=True) for vp in vparam_names + vp: init_param(engine_name, trainable=True, device=device) + for vp in vparam_names } - self.fparams: dict[str, Optional[ArrayLike]] = {fp: None for fp in fparam_names} - self.tparams: dict[str, Optional[ArrayLike]] = ( - None - if tparam_names is None - else {fp: None for fp in tparam_names} # type: ignore[assignment] - ) + self.fparam_names: list[str] = fparam_names + self.tparam_name = tparam_name self.var_to_call: dict[str, ConcretizedCallable] = var_to_call self._dtype: DTypeLike = None + self.time_dependent_vars: list[str] = [] @property def root_param_names(self) -> list[str]: - return list(self.vparams.keys()) + list(self.fparams.keys()) + return list(self.vparams.keys()) + self.fparam_names def embed_all( self, @@ -63,24 +65,28 @@ def embed_all( """ for intermediate_or_leaf_var, engine_callable in self.var_to_call.items(): # We mutate the original inputs dict and include intermediates and leaves. 
+ if self.tparam_name in engine_callable.abstract_args: + self.time_dependent_vars.append(intermediate_or_leaf_var) + # we remember which parameters depend on time inputs[intermediate_or_leaf_var] = engine_callable(inputs) return inputs - def reembed_all( + def reembed_tparam( self, embedded_params: dict[str, ArrayLike], - new_root_params: dict[str, ArrayLike], + tparam_value: ArrayLike, ) -> dict[str, ArrayLike]: """Receive already embedded params containing intermediate and leaf parameters - and remove them from the `embedded_params` dict to reconstruct the user input, and finally - recalculate the embedding using values for parameters in passes in `new_root_params`. + and recalculate those which are dependent on the time parameter using the new value + `tparam_value`. """ - # We filter out intermediates and leaves and leave only the original vparams and fparams + - # the `inputs` dict which contains new pairs - inputs = { - p: v for p, v in embedded_params.items() if p in self.root_param_names - } - return self.embed_all({**self.vparams, **inputs, **new_root_params}) + assert self.tparam_name is not None + embedded_params[self.tparam_name] = tparam_value + for time_dependent_param in self.time_dependent_vars: + embedded_params[time_dependent_param] = self.var_to_call[ + time_dependent_param + ](embedded_params) + return embedded_params def __call__(self, inputs: dict[str, ArrayLike]) -> dict[str, ArrayLike]: """Functional version of legacy embedding: Return a new dictionary\ diff --git a/tests/test_embedding.py b/tests/test_embedding.py index 6379728..e5c8ea0 100644 --- a/tests/test_embedding.py +++ b/tests/test_embedding.py @@ -36,33 +36,39 @@ def test_embedding() -> None: def test_reembedding() -> None: x = np.random.uniform(0, 1) theta = np.random.uniform(0, 1) - x_rembed = np.random.uniform(0, 1) + t = np.random.uniform(0, 1) + t_reembed = np.random.uniform(0, 1) results = [] reembedded_results = [] for engine_name in ["jax", "torch", "numpy"]: v_params
= ["theta"] f_params = ["x"] + tparam = "t" leaf0, native_call0 = "%0", ConcretizedCallable( "mul", ["x", "theta"], {}, engine_name ) + leaf1, native_call1 = "%1", ConcretizedCallable( + "mul", ["t", "%0"], {}, engine_name + ) embedding = Embedding( v_params, f_params, - None, - var_to_call={leaf0: native_call0}, + tparam, + var_to_call={leaf0: native_call0, leaf1: native_call1}, engine_name=engine_name, ) inputs = { "x": (torch.tensor(x) if engine_name == "torch" else x), "theta": (torch.tensor(theta) if engine_name == "torch" else theta), + "t": (torch.tensor(t) if engine_name == "torch" else t), } all_params = embedding.embed_all(inputs) - new_params = { - "x": (torch.tensor(x_rembed) if engine_name == "torch" else x_rembed) - } - reembedded_params = embedding.reembed_all(all_params, new_params) - results.append(all_params["%0"].item()) - reembedded_results.append(reembedded_params["%0"].item()) + new_tparam_val = ( + torch.tensor(t_reembed) if engine_name == "torch" else t_reembed + ) + reembedded_params = embedding.reembed_tparam(all_params, new_tparam_val) + results.append(all_params["%1"].item()) + reembedded_results.append(reembedded_params["%1"].item()) assert np.allclose(results[0], results[1]) and np.allclose(results[0], results[2]) assert np.allclose(reembedded_results[0], reembedded_results[1]) and np.allclose( reembedded_results[0], reembedded_results[2] From 2e00a5e3867ad3e1e5b20ae4806ee2c8e9aace3f Mon Sep 17 00:00:00 2001 From: seitzdom Date: Fri, 5 Jul 2024 03:33:53 +0200 Subject: [PATCH 14/17] doc --- qadence_embeddings/callable.py | 3 +++ qadence_embeddings/embedding.py | 16 ++++++++++------ tests/test_embedding.py | 2 +- 3 files changed, 14 insertions(+), 7 deletions(-) diff --git a/qadence_embeddings/callable.py b/qadence_embeddings/callable.py index f49b899..f272b40 100644 --- a/qadence_embeddings/callable.py +++ b/qadence_embeddings/callable.py @@ -35,6 +35,9 @@ class ConcretizedCallable: + """Transform an abstract function name and arguments 
into + a callable in a linear algebra engine which can be evaluated + using user input.""" def __init__( self, diff --git a/qadence_embeddings/embedding.py b/qadence_embeddings/embedding.py index e412710..b9c1032 100644 --- a/qadence_embeddings/embedding.py +++ b/qadence_embeddings/embedding.py @@ -25,10 +25,9 @@ def init_param( class Embedding: """ - A generic module class to hold and handle the parameters and expressions - functions coming from the `Model`. It may contain the list of user input - parameters, as well as the trainable variational parameters and the - evaluated functions from the data types being used, i.e. torch, numpy, etc. + A class carrying information about user-facing parameters, a.k.a root parameters + as well as a mapping from intermediate and leaf variables used in expressions + or directly as parameters for linear algebra operations. """ def __init__( @@ -50,6 +49,7 @@ def __init__( self.var_to_call: dict[str, ConcretizedCallable] = var_to_call self._dtype: DTypeLike = None self.time_dependent_vars: list[str] = [] + self._device = device @property def root_param_names(self) -> list[str]: @@ -65,13 +65,13 @@ def embed_all( """ for intermediate_or_leaf_var, engine_callable in self.var_to_call.items(): # We mutate the original inputs dict and include intermediates and leaves.
- if self.tparam_name in engine_callable.abstract_args: + if self.tparam_name and self.tparam_name in engine_callable.abstract_args: self.time_dependent_vars.append(intermediate_or_leaf_var) # we remember which parameters depend on time inputs[intermediate_or_leaf_var] = engine_callable(inputs) return inputs - def reembed_tparam( + def reembed_time( self, embedded_params: dict[str, ArrayLike], tparam_value: ArrayLike, @@ -97,6 +97,10 @@ def __call__(self, inputs: dict[str, ArrayLike]) -> dict[str, ArrayLike]: def dtype(self) -> DTypeLike: return self._dtype + @property + def device(self) -> str: + return self._device + def to(self, args: Any, kwargs: Any) -> None: # TODO move to device and dtype pass diff --git a/tests/test_embedding.py b/tests/test_embedding.py index e5c8ea0..8210a81 100644 --- a/tests/test_embedding.py +++ b/tests/test_embedding.py @@ -66,7 +66,7 @@ def test_reembedding() -> None: new_tparam_val = ( torch.tensor(t_reembed) if engine_name == "torch" else t_reembed ) - reembedded_params = embedding.reembed_tparam(all_params, new_tparam_val) + reembedded_params = embedding.reembed_time(all_params, new_tparam_val) results.append(all_params["%1"].item()) reembedded_results.append(reembedded_params["%1"].item()) assert np.allclose(results[0], results[1]) and np.allclose(results[0], results[2]) From 360d7f7a929dd2d85bbbeaaef4bd368cd86146fd Mon Sep 17 00:00:00 2001 From: seitzdom Date: Fri, 5 Jul 2024 03:59:50 +0200 Subject: [PATCH 15/17] store also intermediate td params --- qadence_embeddings/embedding.py | 8 +++++++- tests/test_embedding.py | 8 +++++--- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/qadence_embeddings/embedding.py b/qadence_embeddings/embedding.py index b9c1032..8a7fd0c 100644 --- a/qadence_embeddings/embedding.py +++ b/qadence_embeddings/embedding.py @@ -65,7 +65,13 @@ def embed_all( """ for intermediate_or_leaf_var, engine_callable in self.var_to_call.items(): # We mutate the original inputs dict and include 
intermediates and leaves. - if self.tparam_name and self.tparam_name in engine_callable.abstract_args: + if self.tparam_name and any( + [ + p in [self.tparam_name] + self.time_dependent_vars + for p in engine_callable.abstract_args + ] # we check if any parameter in the callables args is time + # or depends on an intermediate variable which itself depends on time + ): self.time_dependent_vars.append(intermediate_or_leaf_var) # we remember which parameters depend on time inputs[intermediate_or_leaf_var] = engine_callable(inputs) diff --git a/tests/test_embedding.py b/tests/test_embedding.py index 8210a81..a59bb79 100644 --- a/tests/test_embedding.py +++ b/tests/test_embedding.py @@ -50,11 +50,13 @@ def test_reembedding() -> None: leaf1, native_call1 = "%1", ConcretizedCallable( "mul", ["t", "%0"], {}, engine_name ) + + leaf2, native_call2 = "%2", ConcretizedCallable("sin", ["%1"], {}, engine_name) embedding = Embedding( v_params, f_params, tparam, - var_to_call={leaf0: native_call0, leaf1: native_call1}, + var_to_call={leaf0: native_call0, leaf1: native_call1, leaf2: native_call2}, engine_name=engine_name, ) inputs = { @@ -67,8 +69,8 @@ def test_reembedding() -> None: torch.tensor(t_reembed) if engine_name == "torch" else t_reembed ) reembedded_params = embedding.reembed_time(all_params, new_tparam_val) - results.append(all_params["%1"].item()) - reembedded_results.append(reembedded_params["%1"].item()) + results.append(all_params["%2"].item()) + reembedded_results.append(reembedded_params["%2"].item()) assert np.allclose(results[0], results[1]) and np.allclose(results[0], results[2]) assert np.allclose(reembedded_results[0], reembedded_results[1]) and np.allclose( reembedded_results[0], reembedded_results[2] From 4a30d06a0fd2836f70453acee4465307c52c486d Mon Sep 17 00:00:00 2001 From: seitzdom Date: Fri, 5 Jul 2024 04:02:59 +0200 Subject: [PATCH 16/17] ultimate test --- tests/test_embedding.py | 2 ++ 1 file changed, 2 insertions(+) diff --git 
a/tests/test_embedding.py b/tests/test_embedding.py index a59bb79..5b1660d 100644 --- a/tests/test_embedding.py +++ b/tests/test_embedding.py @@ -71,6 +71,8 @@ def test_reembedding() -> None: reembedded_params = embedding.reembed_time(all_params, new_tparam_val) results.append(all_params["%2"].item()) reembedded_results.append(reembedded_params["%2"].item()) + assert all([p in ["%1", "%2"] for p in embedding.time_dependent_vars]) + assert "%0" not in embedding.time_dependent_vars assert np.allclose(results[0], results[1]) and np.allclose(results[0], results[2]) assert np.allclose(reembedded_results[0], reembedded_results[1]) and np.allclose( reembedded_results[0], reembedded_results[2] From ff826b9d474efccc37bb63b2e338dbc5d99b33d6 Mon Sep 17 00:00:00 2001 From: seitzdom Date: Fri, 5 Jul 2024 08:52:06 +0200 Subject: [PATCH 17/17] only alloc once --- qadence_embeddings/embedding.py | 22 +++++++++++++--------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/qadence_embeddings/embedding.py b/qadence_embeddings/embedding.py index 8a7fd0c..f043dcf 100644 --- a/qadence_embeddings/embedding.py +++ b/qadence_embeddings/embedding.py @@ -50,6 +50,7 @@ def __init__( self._dtype: DTypeLike = None self.time_dependent_vars: list[str] = [] self._device = device + self._time_vars_identified = False @property def root_param_names(self) -> list[str]: @@ -65,16 +66,19 @@ def embed_all( """ for intermediate_or_leaf_var, engine_callable in self.var_to_call.items(): # We mutate the original inputs dict and include intermediates and leaves. 
- if self.tparam_name and any( - [ - p in [self.tparam_name] + self.time_dependent_vars - for p in engine_callable.abstract_args - ] # we check if any parameter in the callables args is time - # or depends on an intermediate variable which itself depends on time - ): - self.time_dependent_vars.append(intermediate_or_leaf_var) - # we remember which parameters depend on time + if not self._time_vars_identified: + # we do this only on the first embedding call + if self.tparam_name and any( + [ + p in [self.tparam_name] + self.time_dependent_vars + for p in engine_callable.abstract_args + ] # we check if any parameter in the callable's args is time + # or depends on an intermediate variable which itself depends on time + ): + self.time_dependent_vars.append(intermediate_or_leaf_var) + # we remember which parameters depend on time inputs[intermediate_or_leaf_var] = engine_callable(inputs) + self._time_vars_identified = True return inputs def reembed_time(