diff --git a/content/en/docs/components/training/images/training-operator-overview.drawio.png b/content/en/docs/components/training/images/training-operator-overview.drawio.png
deleted file mode 100644
index 041c68fa42..0000000000
Binary files a/content/en/docs/components/training/images/training-operator-overview.drawio.png and /dev/null differ
diff --git a/content/en/docs/components/training/images/training-operator-overview.drawio.svg b/content/en/docs/components/training/images/training-operator-overview.drawio.svg
new file mode 100644
index 0000000000..c4bbf51682
--- /dev/null
+++ b/content/en/docs/components/training/images/training-operator-overview.drawio.svg
@@ -0,0 +1,4 @@
+
+
+
+
\ No newline at end of file
diff --git a/content/en/docs/components/training/overview.md b/content/en/docs/components/training/overview.md
index edee739b97..4ae6982ade 100644
--- a/content/en/docs/components/training/overview.md
+++ b/content/en/docs/components/training/overview.md
@@ -10,7 +10,7 @@ weight = 10
The Training Operator is a Kubernetes-native project for fine-tuning and scalable
distributed training of machine learning (ML) models created with different ML frameworks such as
-PyTorch, TensorFlow, XGBoost, and others.
+PyTorch, TensorFlow, XGBoost, JAX, and others.
You can integrate other ML libraries such as [HuggingFace](https://huggingface.co),
[DeepSpeed](https://github.com/microsoft/DeepSpeed), or [Megatron-LM](https://github.com/NVIDIA/Megatron-LM)
@@ -26,7 +26,7 @@ supports running Message Passing Interface (MPI) on Kubernetes which is heavily
The Training Operator implements the V1 API version of MPI Operator. For the MPI Operator V2 version,
please follow [this guide](/docs/components/training/user-guides/mpi/) to install MPI Operator V2.
-
@@ -70,6 +70,7 @@ for each ML framework:
| XGBoost | [XGBoostJob](/docs/components/training/user-guides/xgboost/) |
| MPI | [MPIJob](/docs/components/training/user-guides/mpi/) |
| PaddlePaddle | [PaddleJob](/docs/components/training/user-guides/paddle/) |
+| JAX | [JAXJob](/docs/components/training/user-guides/jax/) |
## Next steps
diff --git a/content/en/docs/components/training/user-guides/jax.md b/content/en/docs/components/training/user-guides/jax.md
new file mode 100644
index 0000000000..1639a2df8b
--- /dev/null
+++ b/content/en/docs/components/training/user-guides/jax.md
@@ -0,0 +1,114 @@
++++
+title = "JAX Training (JAXJob)"
+description = "Using JAXJob to train a model with JAX"
+weight = 60
++++
+
+This page describes `JAXJob` for training a machine learning model with [JAX](https://jax.readthedocs.io/en/latest/).
+
+The `JAXJob` is a Kubernetes
+[custom resource](https://kubernetes.io/docs/concepts/extend-kubernetes/api-extension/custom-resources/)
+to run JAX training jobs on Kubernetes. The Kubeflow implementation of
+the `JAXJob` is in the [`training-operator`](https://github.com/kubeflow/training-operator).
+
+The current custom resource for JAX has been tested to run multiple processes on CPUs,
+using [Gloo](https://github.com/facebookincubator/gloo) for communication between processes.
+The Worker replica with index 0 is recognized as the JAX coordinator: process 0 starts a JAX
+coordinator service, exposed via that process's IP address and an available port, and the
+other processes in the cluster connect to it. We are looking for user feedback on running
+JAXJob on GPUs and TPUs.
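+
+The training script typically initializes JAX's distributed runtime at startup using this
+coordinator address. Below is a minimal sketch of what that initialization might look like;
+the environment variable names are illustrative assumptions, not values the operator is
+guaranteed to inject, so match them to whatever your manifest actually sets.
+
+```python
+import os
+
+import jax
+
+# Illustrative variable names -- configure these in your JAXJob manifest.
+jax.distributed.initialize(
+    coordinator_address=os.environ["COORDINATOR_ADDRESS"],  # e.g. "jaxjob-simple-worker-0:6666"
+    num_processes=int(os.environ["NUM_PROCESSES"]),         # total number of Worker replicas
+    process_id=int(os.environ["PROCESS_ID"]),               # this replica's index
+)
+
+print(f"JAX process {jax.process_index()}/{jax.process_count()} initialized")
+print(f"JAX global devices: {jax.devices()}")
+print(f"JAX local devices: {jax.local_devices()}")
+```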
+
+## Creating a JAX training job
+
+You can create a training job by defining a `JAXJob` config file. See the manifest for the [simple JAXJob example](https://github.com/kubeflow/training-operator/blob/master/examples/jax/cpu-demo/demo.yaml).
+You may change this config file based on your requirements.
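+
+For reference, the spec of the example Job looks roughly like the following. Treat this as
+a sketch rather than the authoritative manifest; it mirrors the spec shown in the monitoring
+output further down:
+
+```yaml
+apiVersion: kubeflow.org/v1
+kind: JAXJob
+metadata:
+  name: jaxjob-simple
+  namespace: kubeflow
+spec:
+  jaxReplicaSpecs:
+    Worker:
+      replicas: 2
+      restartPolicy: OnFailure
+      template:
+        spec:
+          containers:
+            - name: jax
+              image: docker.io/kubeflow/jaxjob-simple:latest
+              command: ["python3", "train.py"]
+```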
+
+Deploy the `JAXJob` resource to start training:
+
+```shell
+kubectl create -f https://raw.githubusercontent.com/kubeflow/training-operator/refs/heads/master/examples/jax/cpu-demo/demo.yaml
+```
+
+You should now be able to see the created pods, matching the number of replicas specified in the manifest.
+
+```shell
+kubectl get pods -n kubeflow -l training.kubeflow.org/job-name=jaxjob-simple
+```
+
+The distributed computation takes several minutes to complete on a CPU cluster. You can inspect the worker logs to follow its progress:
+
+```shell
+PODNAME=$(kubectl get pods -l training.kubeflow.org/job-name=jaxjob-simple,training.kubeflow.org/replica-type=worker,training.kubeflow.org/replica-index=0 -o name -n kubeflow)
+kubectl logs -f ${PODNAME} -n kubeflow
+```
+
+```
+I1016 14:30:28.956959 139643066051456 distributed.py:106] Starting JAX distributed service on [::]:6666
+I1016 14:30:28.959352 139643066051456 distributed.py:119] Connecting to JAX distributed service on jaxjob-simple-worker-0:6666
+I1016 14:30:30.633651 139643066051456 xla_bridge.py:895] Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
+I1016 14:30:30.638316 139643066051456 xla_bridge.py:895] Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory
+JAX process 0/1 initialized on jaxjob-simple-worker-0
+JAX global devices:[CpuDevice(id=0), CpuDevice(id=131072)]
+JAX local devices:[CpuDevice(id=0)]
+JAX device count:2
+JAX local device count:1
+[2.]
+```
+
+## Monitoring a JAXJob
+
+```shell
+kubectl get -o yaml jaxjobs jaxjob-simple -n kubeflow
+```
+
+Check the `status` section to monitor the job status. Here is sample output when the job has completed successfully:
+
+```yaml
+apiVersion: kubeflow.org/v1
+kind: JAXJob
+metadata:
+ annotations:
+ kubectl.kubernetes.io/last-applied-configuration: |
+ {"apiVersion":"kubeflow.org/v1","kind":"JAXJob","metadata":{"annotations":{},"name":"jaxjob-simple","namespace":"kubeflow"},"spec":{"jaxReplicaSpecs":{"Worker":{"replicas":2,"restartPolicy":"OnFailure","template":{"spec":{"containers":[{"command":["python3","train.py"],"image":"docker.io/kubeflow/jaxjob-simple:latest","imagePullPolicy":"Always","name":"jax"}]}}}}}}
+ creationTimestamp: "2024-09-22T20:07:59Z"
+ generation: 1
+ name: jaxjob-simple
+ namespace: kubeflow
+ resourceVersion: "1972"
+ uid: eb20c874-44fc-459b-b9a8-09f5c3ff46d3
+spec:
+ jaxReplicaSpecs:
+ Worker:
+ replicas: 2
+ restartPolicy: OnFailure
+ template:
+ spec:
+ containers:
+ - command:
+ - python3
+ - train.py
+ image: docker.io/kubeflow/jaxjob-simple:latest
+ imagePullPolicy: Always
+ name: jax
+status:
+ completionTime: "2024-09-22T20:11:34Z"
+ conditions:
+ - lastTransitionTime: "2024-09-22T20:07:59Z"
+ lastUpdateTime: "2024-09-22T20:07:59Z"
+ message: JAXJob jaxjob-simple is created.
+ reason: JAXJobCreated
+ status: "True"
+ type: Created
+ - lastTransitionTime: "2024-09-22T20:11:28Z"
+ lastUpdateTime: "2024-09-22T20:11:28Z"
+ message: JAXJob kubeflow/jaxjob-simple is running.
+ reason: JAXJobRunning
+ status: "False"
+ type: Running
+ - lastTransitionTime: "2024-09-22T20:11:34Z"
+ lastUpdateTime: "2024-09-22T20:11:34Z"
+ message: JAXJob kubeflow/jaxjob-simple successfully completed.
+ reason: JAXJobSucceeded
+ status: "True"
+ type: Succeeded
+ replicaStatuses:
+ Worker:
+ selector: training.kubeflow.org/job-name=jaxjob-simple,training.kubeflow.org/operator-name=jaxjob-controller,training.kubeflow.org/replica-type=worker
+ succeeded: 2
+ startTime: "2024-09-22T20:07:59Z"
+```
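+
+Since the job reports a `Succeeded` condition, as shown above, you can also block until
+training finishes with `kubectl wait` (the timeout value here is an arbitrary example):
+
+```shell
+kubectl wait --for=condition=Succeeded jaxjobs/jaxjob-simple -n kubeflow --timeout=600s
+```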
diff --git a/content/en/docs/components/training/user-guides/job-scheduling.md b/content/en/docs/components/training/user-guides/job-scheduling.md
index 10564812f1..84515ae683 100644
--- a/content/en/docs/components/training/user-guides/job-scheduling.md
+++ b/content/en/docs/components/training/user-guides/job-scheduling.md
@@ -1,7 +1,7 @@
+++
title = "Job Scheduling"
description = "How to schedule a job with gang-scheduling"
-weight = 60
+weight = 70
+++
This guide describes how to use [Kueue](https://kueue.sigs.k8s.io/),
diff --git a/content/en/docs/components/training/user-guides/mpi.md b/content/en/docs/components/training/user-guides/mpi.md
index be1910f8cd..ba45067e38 100644
--- a/content/en/docs/components/training/user-guides/mpi.md
+++ b/content/en/docs/components/training/user-guides/mpi.md
@@ -1,7 +1,7 @@
+++
title = "MPI Training (MPIJob)"
description = "Instructions for using MPI for training"
-weight = 60
+weight = 70
+++
{{% beta-status