diff --git a/.github/workflows/run-tests-and-mypy.yml b/.github/workflows/run-tests-and-mypy.yml
new file mode 100644
index 0000000..687190b
--- /dev/null
+++ b/.github/workflows/run-tests-and-mypy.yml
@@ -0,0 +1,147 @@
+name: Linting / Tests / Documentation
+
+on:
+  push:
+    branches:
+      - main
+    tags:
+      - 'v[0-9]+.[0-9]+.[0-9]+'
+  pull_request:
+    branches:
+      - main
+  workflow_dispatch: {}
+
+concurrency:
+  group: ${{ github.head_ref || github.run_id }}
+  cancel-in-progress: true
+
+jobs:
+  lint:
+    name: Linting
+    runs-on: ubuntu-latest
+    strategy:
+      matrix:
+        python-version: ["3.10"]
+    steps:
+      - name: Checkout main code and submodules
+        uses: actions/checkout@v4
+      - name: Set up Python ${{ matrix.python-version }}
+        uses: actions/setup-python@v4
+        with:
+          python-version: ${{ matrix.python-version }}
+      - name: Install hatch
+        run: |
+          python -m pip install hatch
+      - name: Install main dependencies
+        run: |
+          python -m hatch -v -e tests
+      - name: Lint
+        run: |
+          python -m hatch -e tests run pre-commit run --all-files
+
+  unit_tests:
+    name: Unit testing
+    runs-on: ubuntu-latest
+    strategy:
+      matrix:
+        python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
+    steps:
+      - name: Checkout main code and submodules
+        uses: actions/checkout@v4
+      - name: Set up Python ${{ matrix.python-version }}
+        uses: actions/setup-python@v4
+        with:
+          python-version: ${{ matrix.python-version }}
+      - name: Install hatch
+        run: |
+          python -m pip install hatch
+      - name: Install main dependencies
+        run: |
+          python -m hatch -v -e tests
+      # Linting runs once in the dedicated `lint` job, not once per Python version.
+      - name: Perform unit tests
+        run: |
+          python -m hatch -e tests run pytest -n auto
+
+  test_docs:
+    name: Documentation
+    runs-on: ubuntu-latest
+    strategy:
+      matrix:
+        python-version: ["3.10"]
+    steps:
+      - name: Checkout main code and submodules
+        uses: actions/checkout@v4
+      - name: Set up Python ${{ matrix.python-version }}
+        uses: actions/setup-python@v4
+        with:
+          python-version: ${{ matrix.python-version }}
+      - name: Install hatch
+        run: |
+          python -m pip install hatch
+      - name: Install main dependencies
+        run: |
+          python -m hatch -v -e docs
+      - name: Test docs
+        run: |
+          python -m hatch run docs:build
+
+  publish:
+    name: Publish to PyPI
+    if: startsWith(github.ref, 'refs/tags/v')
+    # Gate the release on linting as well as on the unit tests.
+    needs: [lint, unit_tests]
+    runs-on: ubuntu-latest
+    steps:
+      - name: Checkout main code and submodules
+        uses: actions/checkout@v4
+        with:
+          ref: ${{ github.ref }}
+      - name: Set up Python
+        uses: actions/setup-python@v4
+        with:
+          python-version: "3.10"
+      - name: Install Python dependencies
+        run: |
+          python -m pip install --upgrade pip
+          pip install hatch
+      - name: Build and publish package
+        # Credentials go through the environment so the token is never a command-line argument.
+        env:
+          HATCH_INDEX_USER: __token__
+          HATCH_INDEX_AUTH: ${{ secrets.PYPI_API_TOKEN }}
+        run: |
+          hatch build
+          hatch publish
+      - name: Confirm deployment
+        timeout-minutes: 5
+        run: |
+          VERSION=${GITHUB_REF#refs/tags/v}
+          # --no-deps: only the freshly uploaded package is checked, not its dependencies.
+          until pip download --no-deps "qadence-embeddings==$VERSION"
+          do
+            echo "Failed to download from PyPI, will wait for upload and retry."
+            sleep 1
+          done
+
+  deploy_docs:
+    name: Deploy documentation
+    if: startsWith(github.ref, 'refs/tags/v')
+    needs: [lint, unit_tests]
+    runs-on: ubuntu-latest
+    steps:
+      - name: Checkout main code and submodules
+        uses: actions/checkout@v4
+      - name: Set up Python 3.10
+        uses: actions/setup-python@v4
+        with:
+          python-version: '3.10'
+      - name: Install Hatch
+        run: |
+          pip install hatch
+      - name: Deploy docs
+        run: |
+          git config user.name "GitHub Actions"
+          git config user.email "actions@github.com"
+          git fetch origin gh-pages
+          hatch -v run docs:mike deploy --push --update-aliases ${{ github.ref_name }} latest
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
new file mode 100644
index 0000000..a5ad010
--- /dev/null
+++ b/.pre-commit-config.yaml
@@ -0,0 +1,26 @@
+repos:
+- repo: https://github.com/pre-commit/pre-commit-hooks
+  rev: v4.5.0
+  hooks:
+  - id: trailing-whitespace
+  - id: end-of-file-fixer
+  - id: check-yaml
+  - id: check-added-large-files
+
+- repo: https://github.com/psf/black
+  rev: 24.4.2
+  hooks:
+  - id: black
+
+- repo: https://github.com/astral-sh/ruff-pre-commit
+  rev: "v0.4.4"
+  hooks:
+  - id: ruff
+    args: [--fix]
+
+
+- repo: https://github.com/pre-commit/mirrors-mypy
+  rev: v1.10.0
+  hooks:
+  - id: mypy
+    exclude: examples|docs
diff --git a/README.md b/README.md
index 528dffa..c13e049 100644
--- a/README.md
+++ b/README.md
@@ -1 +1,3 @@
-# qadence-embeddings
\ No newline at end of file
+# qadence-embeddings
+
+**qadence-embeddings** is an engine-agnostic parameter embedding library.
diff --git a/docs/index.md b/docs/index.md
new file mode 100644
index 0000000..67261ba
--- /dev/null
+++ b/docs/index.md
@@ -0,0 +1,2 @@
+::: qadence_embeddings.callable.ConcretizedCallable
+::: qadence_embeddings.embedding.Embedding
diff --git a/mkdocs.yml b/mkdocs.yml
new file mode 100644
index 0000000..4c5a14c
--- /dev/null
+++ b/mkdocs.yml
@@ -0,0 +1,71 @@
+site_name: qadence-embeddings
+repo_url: "https://github.com/pasqal-io/qadence-embeddings"
+repo_name: "qadence-embeddings"
+
+nav:
+
+  - Embeddings in a nutshell: index.md
+
+
+theme:
+  name: material
+  features:
+  - content.code.annotate
+  - navigation.indexes
+  - navigation.sections
+
+  palette:
+  - media: "(prefers-color-scheme: light)"
+    scheme: default
+    primary: light green
+    accent: purple
+    toggle:
+      icon: material/weather-sunny
+      name: Switch to dark mode
+  - media: "(prefers-color-scheme: dark)"
+    scheme: slate
+    primary: black
+    accent: light green
+    toggle:
+      icon: material/weather-night
+      name: Switch to light mode
+
+markdown_extensions:
+- admonition  # for notes
+- pymdownx.arithmatex:  # for mathjax
+    generic: true
+- pymdownx.highlight:
+    anchor_linenums: true
+- pymdownx.inlinehilite
+- pymdownx.snippets
+- pymdownx.superfences
+
+plugins:
+- search
+- section-index
+- mkdocstrings:
+    default_handler: python
+    handlers:
+      python:
+        options:
+          filters:
+            - "!^_"  # exclude all members starting with _
+            - "^__init__$"  # but always include __init__ modules and methods
+
+- mkdocs-jupyter:
+    theme: light
+- markdown-exec
+
+extra_css:
+- extras/css/mkdocstrings.css
+- extras/css/colors.css
+- extras/css/home.css
+
+# For mathjax
+extra_javascript:
+  - extras/javascripts/mathjax.js
+  # The polyfill.io script was dropped: that third-party domain served malicious code in 2024.
+  - https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js
+
+watch:
+  - qadence_embeddings
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 0000000..eb2c988
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,107 @@
+[build-system]
+requires = ["hatchling"]
+build-backend = "hatchling.build"
+
+[project]
+name = "qadence-embeddings"
+description = "an engine-agnostic parameter embedding library."
+authors = [
+    { name = "Dominik Seitz", email = "dominik.seitz@pasqal.com" },
+]
+requires-python = ">=3.8,<3.13"
+license = {text = "Apache 2.0"}
+version = "1.3.0"
+classifiers=[
+    "License :: OSI Approved :: Apache Software License",
+    "Programming Language :: Python",
+    "Programming Language :: Python :: 3",
+    "Programming Language :: Python :: 3.8",
+    "Programming Language :: Python :: 3.9",
+    "Programming Language :: Python :: 3.10",
+    "Programming Language :: Python :: 3.11",
+    "Programming Language :: Python :: 3.12",
+    "Programming Language :: Python :: Implementation :: CPython",
+    "Programming Language :: Python :: Implementation :: PyPy",
+]
+dependencies = ["torch", "numpy", "jax"]
+
+[project.optional-dependencies]
+dev = ["flaky","black", "pytest", "pytest-xdist", "pytest-cov", "flake8", "mypy", "pre-commit", "ruff", "nbconvert", "matplotlib", "qutip~=4.7.5"]
+
+[tool.hatch.envs.tests]
+features = [
+    "dev",
+]
+
+[tool.hatch.envs.tests.scripts]
+test = "pytest -n auto {args}"
+test-docs = "hatch -e docs run mkdocs build --clean --strict"
+test-cov = "pytest -n auto --cov=qadence_embeddings tests/"
+
+[tool.hatch.envs.docs]
+dependencies = [
+    "mkdocs",
+    "mkdocs-material",
+    "mkdocstrings",
+    "mkdocstrings-python",
+    "mkdocs-section-index",
+    "mkdocs-jupyter",
+    "mkdocs-exclude",
+    "markdown-exec",
+    "mike",
+    "matplotlib",
+]
+
+[tool.hatch.envs.docs.scripts]
+build = "mkdocs build --clean --strict"
+serve = "mkdocs serve --dev-addr localhost:8000"
+
+[tool.ruff]
+lint.select = ["E", "F", "I", "Q"]
+lint.extend-ignore = ["F841"]
+line-length = 100
+
+[tool.ruff.lint.isort]
+required-imports = ["from __future__ import annotations"]
+
+[tool.ruff.lint.per-file-ignores]
+"__init__.py" = ["F401", "E402"]
+
+[tool.ruff.lint.mccabe]
+max-complexity = 15
+
+[tool.ruff.lint.flake8-quotes]
+docstring-quotes = "double"
+
+[tool.black]
+line-length = 100
+include = '\.pyi?$'
+exclude = '''
+/(
+    \.git
+  | \.hg
+  | \.mypy_cache
+  | \.tox
+  | \.venv
+  | _build
+  | buck-out
+  | build
+  | dist
+)/
+'''
+
+[tool.isort]
+line_length = 100
+combine_as_imports = true
+balanced_wrapping = true
+lines_after_imports = 2
+include_trailing_comma = true
+multi_line_output = 5
+
+[tool.mypy]
+python_version = "3.10"
+warn_return_any = true
+warn_unused_configs = true
+disallow_untyped_defs = true
+no_implicit_optional = false
+ignore_missing_imports = true
diff --git a/qadence_embeddings/__init__.py b/qadence_embeddings/__init__.py
new file mode 100644
index 0000000..da6685a
--- /dev/null
+++ b/qadence_embeddings/__init__.py
@@ -0,0 +1,4 @@
+from __future__ import annotations
+
+from .callable import ConcretizedCallable
+from .embedding import Embedding
diff --git a/qadence_embeddings/callable.py b/qadence_embeddings/callable.py
new file mode 100644
index 0000000..f272b40
--- /dev/null
+++ b/qadence_embeddings/callable.py
@@ -0,0 +1,136 @@
+from __future__ import annotations
+
+from importlib import import_module
+from logging import getLogger
+from typing import Tuple
+
+from numpy.typing import ArrayLike
+
+logger = getLogger(__name__)
+
+ARRAYLIKE_FN_MAP = {
+    "torch": ("torch", "tensor"),
+    "jax": ("jax.numpy", "array"),
+    "numpy": ("numpy", "array"),
+}
+
+
+DEFAULT_JAX_MAPPING = {
+    "mul": ("jax.numpy", "multiply"),
+    "sub": ("jax.numpy", "subtract"),
+    "div": ("jax.numpy", "divide"),
+}
+DEFAULT_TORCH_MAPPING: dict = {}
+DEFAULT_NUMPY_MAPPING = {
+    "mul": ("numpy", "multiply"),
+    "sub": ("numpy", "subtract"),
+    "div": ("numpy", "divide"),
+}
+
+DEFAULT_INSTRUCTION_MAPPING = {
+    "torch": DEFAULT_TORCH_MAPPING,
+    "jax": DEFAULT_JAX_MAPPING,
+    "numpy": DEFAULT_NUMPY_MAPPING,
+}
+
+
+class ConcretizedCallable:
+    """Transform an abstract function name and arguments into
+    a callable in a linear algebra engine which can be evaluated
+    using user input."""
+
+    def __init__(
+        self,
+        call_name: str,
+        abstract_args: list[str | float | int],
+        instruction_mapping: dict[str, Tuple[str, str]] = dict(),
+        engine_name: str = "torch",
+        device: str = "cpu",
+    ) -> None:
+        """Convert a generic abstract function call and
+        a list of symbolic or constant parameters
+        into a concretized Callable in a particular engine
+        which can be evaluated using
+        an inputs dict.
+
+        Arguments:
+            call_name: The name of the function
+            abstract_args: A list of strings (in the case of parameters) and numeric constants
+                denoting the arguments for `call_name`
+            instruction_mapping: A dict mapping from an abstract call_name to its
+                (module, function) name in an engine. Entries given here take
+                precedence over the built-in defaults for `engine_name`.
+            engine_name: The engine to use to create the callable.
+            device: The device on which numeric constants are created (torch only).
+
+        Example:
+        ```
+        In [11]: call = ConcretizedCallable('sin', ['x'], engine_name='numpy')
+        In [12]: call({'x': 0.5})
+        Out[12]: 0.479425538604203
+
+        In [13]: call = ConcretizedCallable('sin', ['x'], engine_name='torch')
+        In [14]: call({'x': torch.rand(1)})
+        Out[14]: tensor([0.5531])
+
+        In [15]: call = ConcretizedCallable('sin', ['x'], engine_name='jax')
+        In [16]: call({'x': 0.5})
+        Out[16]: Array(0.47942555, dtype=float32, weak_type=True)
+        ```
+        """
+        # User-supplied entries are unpacked last so they can override the defaults.
+        instruction_mapping = {
+            **DEFAULT_INSTRUCTION_MAPPING[engine_name],
+            **instruction_mapping,
+        }
+        self.call_name = call_name
+        self.abstract_args = abstract_args
+        self.engine_name = engine_name
+        self.device = device
+        self.engine_call = None
+        engine = None
+        # `ARRAYLIKE_FN_MAP` yields the importable module path (e.g. "jax.numpy"),
+        # which is kept separate from the user-facing `engine_name`.
+        module_name, fn_name = ARRAYLIKE_FN_MAP[engine_name]
+        try:
+            engine = import_module(module_name)
+            self.arraylike_fn = getattr(engine, fn_name)
+        except (ModuleNotFoundError, ImportError) as e:
+            logger.error(f"Unable to import {module_name} due to {e}.")
+
+        try:
+            try:
+                self.engine_call = getattr(engine, call_name)
+            except AttributeError:
+                # Not an attribute of the engine module: fall back to the mapping below.
+                pass
+            if self.engine_call is None:
+                mod, fn = instruction_mapping[call_name]
+                self.engine_call = getattr(import_module(mod), fn)
+        except (ImportError, KeyError) as e:
+            logger.error(
+                f"Requested function {call_name} can not be imported from {module_name} and is"
+                f" not in instruction_mapping {instruction_mapping} due to {e}."
+            )
+
+    def evaluate(self, inputs: dict[str, ArrayLike] = dict()) -> ArrayLike:
+        """Evaluate the engine call on `abstract_args`.
+
+        Numeric constants are converted with the engine's array constructor and
+        symbols (strings) are looked up in `inputs`.
+        """
+        # `device` is only forwarded to torch; `numpy.array` has no such keyword.
+        device_kwargs = {"device": self.device} if self.engine_name == "torch" else {}
+        arraylike_args = []
+        for symbol_or_numeric in self.abstract_args:
+            if isinstance(symbol_or_numeric, (float, int)):
+                arraylike_args.append(
+                    self.arraylike_fn(symbol_or_numeric, **device_kwargs)
+                )
+            elif isinstance(symbol_or_numeric, str):
+                arraylike_args.append(inputs[symbol_or_numeric])
+        return self.engine_call(*arraylike_args)  # type: ignore[misc]
+
+    def __call__(self, inputs: dict[str, ArrayLike] = dict()) -> ArrayLike:
+        """Shorthand for `evaluate(inputs)`."""
+        return self.evaluate(inputs)
diff --git a/qadence_embeddings/embedding.py b/qadence_embeddings/embedding.py
new file mode 100644
index 0000000..f043dcf
--- /dev/null
+++ b/qadence_embeddings/embedding.py
@@ -0,0 +1,122 @@
+from __future__ import annotations
+
+from importlib import import_module
+from logging import getLogger
+from typing import Any, Optional
+
+from numpy.typing import ArrayLike, DTypeLike
+
+from .callable import ConcretizedCallable
+
+logger = getLogger(__name__)
+
+
+def init_param(
+    engine_name: str, trainable: bool = True, device: str = "cpu"
+) -> ArrayLike:
+    """Create one randomly initialised parameter in the requested engine.
+
+    `trainable` and `device` are only forwarded to torch.
+    """
+    engine = import_module(engine_name)
+    if engine_name == "jax":
+        return engine.random.uniform(engine.random.PRNGKey(42), shape=(1,))
+    elif engine_name == "torch":
+        return engine.rand(1, requires_grad=trainable, device=device)
+    elif engine_name == "numpy":
+        return engine.random.uniform(0, 1)
+    # Fail loudly instead of implicitly returning None for an unknown engine.
+    raise ValueError(f"Unsupported engine '{engine_name}'.")
+
+
+class Embedding:
+    """
+    A class carrying information about user-facing parameters, a.k.a root parameters
+    as well as a
mapping from intermediate and leaf variables used in expressions
+    or directly as parameters for linear algebra operations.
+    """
+
+    def __init__(
+        self,
+        vparam_names: list[str] = [],
+        fparam_names: list[str] = [],
+        tparam_name: Optional[str] = None,
+        var_to_call: dict[str, ConcretizedCallable] = dict(),
+        engine_name: str = "torch",
+        device: str = "cpu",
+    ) -> None:
+        # Copy the containers so the shared mutable defaults are never aliased.
+        self.vparams = {
+            vp: init_param(engine_name, trainable=True, device=device)
+            for vp in vparam_names
+        }
+        self.fparam_names: list[str] = list(fparam_names)
+        self.tparam_name = tparam_name
+        self.var_to_call: dict[str, ConcretizedCallable] = dict(var_to_call)
+        self._dtype: DTypeLike = None
+        self.time_dependent_vars: list[str] = []
+        self._device = device
+        self._time_vars_identified = False
+
+    @property
+    def root_param_names(self) -> list[str]:
+        return list(self.vparams.keys()) + self.fparam_names
+
+    def embed_all(
+        self,
+        inputs: dict[str, ArrayLike],
+    ) -> dict[str, ArrayLike]:
+        """The standard embedding of all intermediate and leaf parameters.
+        Include the root_params, i.e., the vparams and fparams original values
+        to be reused in computations.
+        """
+        for intermediate_or_leaf_var, engine_callable in self.var_to_call.items():
+            # We mutate the original inputs dict and include intermediates and leaves.
+            if not self._time_vars_identified:
+                # we do this only on the first embedding call
+                if self.tparam_name and any(
+                    [
+                        p in [self.tparam_name] + self.time_dependent_vars
+                        for p in engine_callable.abstract_args
+                    ]  # we check if any parameter in the callables args is time
+                    # or depends on an intermediate variable which itself depends on time
+                ):
+                    self.time_dependent_vars.append(intermediate_or_leaf_var)
+                    # we remember which parameters depend on time
+            inputs[intermediate_or_leaf_var] = engine_callable(inputs)
+        self._time_vars_identified = True
+        return inputs
+
+    def reembed_time(
+        self,
+        embedded_params: dict[str, ArrayLike],
+        tparam_value: ArrayLike,
+    ) -> dict[str, ArrayLike]:
+        """Receive already embedded params containing intermediate and leaf parameters
+        and recalculate those which are dependent on the time parameter using the new value
+        `tparam_value`.
+        """
+        assert self.tparam_name is not None
+        embedded_params[self.tparam_name] = tparam_value
+        for time_dependent_param in self.time_dependent_vars:
+            embedded_params[time_dependent_param] = self.var_to_call[
+                time_dependent_param
+            ](embedded_params)
+        return embedded_params
+
+    def __call__(self, inputs: dict[str, ArrayLike]) -> dict[str, ArrayLike]:
+        """Embed all parameters by delegating to `embed_all`: `inputs` is updated
+        in place and the same dictionary is returned."""
+        return self.embed_all(inputs)
+
+    @property
+    def dtype(self) -> DTypeLike:
+        return self._dtype
+
+    @property
+    def device(self) -> str:
+        return self._device
+
+    def to(self, *args: Any, **kwargs: Any) -> None:
+        # TODO move to device and dtype
+        pass
diff --git a/tests/test_callable.py b/tests/test_callable.py
new file mode 100644
index 0000000..5faf29b
--- /dev/null
+++ b/tests/test_callable.py
@@ -0,0 +1,39 @@
+from __future__ import annotations
+
+import numpy as np
+import pytest
+import torch
+
+from qadence_embeddings.callable import ConcretizedCallable
+
+
+@pytest.mark.parametrize(
+    "fn", ["sin", "cos", "log", "tanh", "tan", "acos", "sqrt", "square"]
+)
+def test_univariate(fn: str) -> None:
+    results = []
+    x = np.random.uniform(0, 1)
+    for engine_name in ["jax", "torch", "numpy"]:
+        native_call = ConcretizedCallable(fn, ["x"], {}, engine_name)
+        native_result = native_call(
+            {"x": (torch.tensor(x) if engine_name == "torch" else x)}
+        )
+        results.append(native_result.item())
+    assert np.allclose(results[0], results[1]) and np.allclose(results[0], results[2])
+
+
+@pytest.mark.parametrize("fn", ["mul", "add", "div", "sub"])
+def test_multivariate(fn: str) -> None:
+    results = []
+    x = np.random.randn(1)
+    y = np.random.randn(1)
+    for engine_name in ["jax", "torch", "numpy"]:
+        native_call = ConcretizedCallable(fn, ["x", "y"], {}, engine_name)
+        native_result = native_call(
+            {
+                "x": torch.tensor(x) if engine_name == "torch" else x,
+                "y": torch.tensor(y) if engine_name == "torch" else y,
+            }
+        )
+        results.append(native_result.item())
+    assert np.allclose(results[0], results[1]) and np.allclose(results[0], results[2])
diff --git a/tests/test_embedding.py b/tests/test_embedding.py
new file mode 100644
index 0000000..5b1660d
--- /dev/null
+++ b/tests/test_embedding.py
@@ -0,0 +1,80 @@
+from __future__ import annotations
+
+import numpy as np
+import torch
+
+from qadence_embeddings.callable import ConcretizedCallable
+from qadence_embeddings.embedding import Embedding
+
+
+def test_embedding() -> None:
+    x = np.random.uniform(0, 1)
+    theta = np.random.uniform(0, 1)
+    results = []
+    for engine_name in ["jax", "torch", "numpy"]:
+        v_params = ["theta"]
+        f_params = ["x"]
+        leaf0, native_call0 = "%0", ConcretizedCallable(
+            "mul", ["x", "theta"], {}, engine_name
+        )
+        embedding = Embedding(
+            v_params,
+            f_params,
+            None,
+            var_to_call={leaf0: native_call0},
+            engine_name=engine_name,
+        )
+        inputs = {
+            "x": (torch.tensor(x) if engine_name == "torch" else x),
+            "theta": (torch.tensor(theta) if engine_name == "torch" else theta),
+        }
+        eval_0 = embedding.var_to_call["%0"](inputs)
+        results.append(eval_0.item())
+    assert np.allclose(results[0], results[1]) and np.allclose(results[0], results[2])
+
+
+def test_reembedding() -> None:
+    x = np.random.uniform(0, 1)
+    theta = np.random.uniform(0, 1)
+    t = np.random.uniform(0, 1)
+    t_reembed = np.random.uniform(0, 1)
+    results = []
+    reembedded_results = []
+    for engine_name in ["jax", "torch", "numpy"]:
+        v_params = ["theta"]
+        f_params = ["x"]
+        tparam = "t"
+        leaf0, native_call0 = "%0", ConcretizedCallable(
+            "mul", ["x", "theta"], {}, engine_name
+        )
+        leaf1, native_call1 = "%1", ConcretizedCallable(
+            "mul", ["t", "%0"], {}, engine_name
+        )
+
+        leaf2, native_call2 = "%2", ConcretizedCallable("sin", ["%1"], {}, engine_name)
+        embedding = Embedding(
+            v_params,
+            f_params,
+            tparam,
+            var_to_call={leaf0: native_call0, leaf1: native_call1, leaf2: native_call2},
+            engine_name=engine_name,
+        )
+        inputs = {
+            "x": (torch.tensor(x) if engine_name == "torch" else x),
+            "theta": (torch.tensor(theta) if engine_name == "torch" else theta),
+            "t": (torch.tensor(t) if engine_name == "torch" else t),
+        }
+        all_params = embedding.embed_all(inputs)
+        # embed_all/reembed_time update one dict in place: read "%2" before re-embedding.
+        results.append(all_params["%2"].item())
+        new_tparam_val = (
+            torch.tensor(t_reembed) if engine_name == "torch" else t_reembed
+        )
+        reembedded_params = embedding.reembed_time(all_params, new_tparam_val)
+        reembedded_results.append(reembedded_params["%2"].item())
+        assert set(embedding.time_dependent_vars) == {"%1", "%2"}
+        assert "%0" not in embedding.time_dependent_vars
+    assert np.allclose(results[0], results[1]) and np.allclose(results[0], results[2])
+    assert np.allclose(reembedded_results[0], reembedded_results[1]) and np.allclose(
+        reembedded_results[0], reembedded_results[2]
+    )