Skip to content

Commit

Permalink
Merge pull request #423 from JaxGaussianProcesses/update_keyarray_typing
Browse files Browse the repository at this point in the history
Update typing.py
  • Loading branch information
daniel-dodd authored Nov 29, 2023
2 parents 48f2db9 + 33e3ce1 commit ac47576
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion gpjax/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,18 @@
Callable,
Union,
)
from jax.random import KeyArray as JAXKeyArray
from jaxtyping import (
Array as JAXArray,
Bool,
Float,
Int,
Key,
UInt32,
)
from numpy import ndarray as NumpyArray

OldKeyArray = UInt32[JAXArray, "2"]
JAXKeyArray = Key[JAXArray, ""]
KeyArray = Union[
OldKeyArray, JAXKeyArray
] # for compatibility regardless of enable_custom_prng setting
Expand Down

0 comments on commit ac47576

Please sign in to comment.