diff --git a/catalax/__init__.py b/catalax/__init__.py index bc98ace..2246d65 100644 --- a/catalax/__init__.py +++ b/catalax/__init__.py @@ -3,3 +3,44 @@ from .tools.visualization import visualize __version__ = "0.2.0" + + +def set_host_count(n: int = 1): + """ + Sets the number of hosts to be used by JAX for parallel execution. + + Args: + n (int): The number of hosts to use. Defaults to 1. + """ + import numpyro + + numpyro.set_host_count(n) + + +def set_platform(platform: str = "cpu"): + """ + Sets the platform for JAX. + + Args: + platform (str): The platform to use. Must be one of 'cpu' or 'gpu'. Defaults to 'cpu'. + + Raises: + AssertionError: If the platform is not 'cpu' or 'gpu'. + """ + import numpyro + + assert platform in ["cpu", "gpu"], "platform must be one of 'cpu' or 'gpu'" + + numpyro.set_platform(platform) + + +def enable_x64(use_x64: bool = True): + """ + Enables the use of 64-bit precision in JAX. + + Args: + use_x64 (bool, optional): Whether to enable 64-bit precision. Defaults to True. + """ + import numpyro + + numpyro.enable_x64(use_x64=use_x64) diff --git a/examples/SurrogateHMC.ipynb b/examples/SurrogateHMC.ipynb index 626e52a..53e03d1 100644 --- a/examples/SurrogateHMC.ipynb +++ b/examples/SurrogateHMC.ipynb @@ -24,14 +24,14 @@ "outputs": [], "source": [ "import json\n", - "import numpyro\n", "\n", "import catalax as ctx\n", "import catalax.neural as ctn\n", "import catalax.mcmc as cmc\n", "import jax.numpy as jnp\n", "\n", - "numpyro.set_host_device_count(5)" + "# Set the number of workers used for parallelizations\n", + "ctx.set_host_count(5)" ] }, {