Skip to content

Commit

Permalink
Add utility functions
Browse files Browse the repository at this point in the history
  • Loading branch information
JR-1991 committed Oct 12, 2023
1 parent 7725fc1 commit 60cbf4f
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 2 deletions.
41 changes: 41 additions & 0 deletions catalax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,44 @@
from .tools.visualization import visualize

__version__ = "0.2.0"


def set_host_count(n: int = 1):
"""
Sets the number of hosts to be used by JAX for parallel execution.
Args:
n (int): The number of hosts to use. Defaults to 1.
"""
import numpyro

numpyro.set_host_count(n)


def set_platform(platform: str = "cpu"):
"""
Sets the platform for JAX.
Args:
platform (str): The platform to use. Must be one of 'cpu' or 'gpu'. Defaults to 'cpu'.
Raises:
AssertionError: If the platform is not 'cpu' or 'gpu'.
"""
import numpyro

assert platform in ["cpu", "gpu"], "platform must be one of 'cpu' or 'gpu'"

numpyro.set_platform(platform)


def enable_x64(use_x64: bool = True):
"""
Enables the use of 64-bit precision in JAX.
Args:
use_x64 (bool, optional): Whether to enable 64-bit precision. Defaults to True.
"""
import numpyro

numpyro.enable_x64(use_x64=use_x64)
4 changes: 2 additions & 2 deletions examples/SurrogateHMC.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,14 @@
"outputs": [],
"source": [
"import json\n",
"import numpyro\n",
"\n",
"import catalax as ctx\n",
"import catalax.neural as ctn\n",
"import catalax.mcmc as cmc\n",
"import jax.numpy as jnp\n",
"\n",
"numpyro.set_host_device_count(5)"
"# Set the number of workers used for parallelizations\n",
"ctx.set_host_count(5)"
]
},
{
Expand Down

0 comments on commit 60cbf4f

Please sign in to comment.