diff --git a/src/transonic/backends/jax.py b/src/transonic/backends/jax.py index 98041c8..4669b60 100644 --- a/src/transonic/backends/jax.py +++ b/src/transonic/backends/jax.py @@ -41,9 +41,7 @@ def add_jax_comments(code): # Add JIT decorator if isinstance(node, gast.FunctionDef): - new_body.append( - CommentLine("# __protected__ @jit") - ) + new_body.append(CommentLine("# __protected__ @jit")) new_body.append(node) mod.body = new_body diff --git a/src/transonic/util.py b/src/transonic/util.py index da7b4ee..05cdc03 100644 --- a/src/transonic/util.py +++ b/src/transonic/util.py @@ -131,7 +131,7 @@ def can_import_accelerator(backend: str = backend_default): import numba except ImportError: return False - elif backend =="jax": + elif backend == "jax": try: import jax except ImportError: diff --git a/tests/test_init_transonified.py b/tests/test_init_transonified.py index 8850419..cc3c870 100644 --- a/tests/test_init_transonified.py +++ b/tests/test_init_transonified.py @@ -99,6 +99,9 @@ def test_transonified(self): for_test_init.func1(1.1, 2.2) for_test_init.check_class() + @unittest.skipIf( + backend.name == "jax", "Not yet supported by our JAX backend" + ) @unittest.skipIf( sys.platform.startswith("win") or not can_import_accelerator(), f"{backend.name} is required for TRANSONIC_COMPILE_AT_IMPORT",