diff --git a/content/en/docs/components/training/images/training-operator-overview.drawio.png b/content/en/docs/components/training/images/training-operator-overview.drawio.png
deleted file mode 100644
index 041c68fa42..0000000000
Binary files a/content/en/docs/components/training/images/training-operator-overview.drawio.png and /dev/null differ
diff --git a/content/en/docs/components/training/images/training-operator-overview.drawio.svg b/content/en/docs/components/training/images/training-operator-overview.drawio.svg
new file mode 100644
index 0000000000..c4bbf51682
--- /dev/null
+++ b/content/en/docs/components/training/images/training-operator-overview.drawio.svg
@@ -0,0 +1,4 @@
+[drawio SVG: Training Operator overview diagram. Text labels: Features; Training Operator; Python SDK; API; Distributed Training (e.g. PyTorchJob); Fine Tuning; All-Reduce Style Training with MPI; High Performance Computing (HPC) with MPI; Job Scheduling with Volcano, Kueue; Elastic Training; JAX]
\ No newline at end of file
diff --git a/content/en/docs/components/training/overview.md b/content/en/docs/components/training/overview.md
index edee739b97..4ae6982ade 100644
--- a/content/en/docs/components/training/overview.md
+++ b/content/en/docs/components/training/overview.md
@@ -10,7 +10,7 @@ weight = 10
 The Training Operator is a Kubernetes-native project for fine-tuning and scalable
 distributed training of machine learning (ML) models created with different ML frameworks such as
-PyTorch, TensorFlow, XGBoost, and others.
+PyTorch, TensorFlow, XGBoost, JAX, and others.
 
 You can integrate other ML libraries such as [HuggingFace](https://huggingface.co),
 [DeepSpeed](https://github.com/microsoft/DeepSpeed), or [Megatron-LM](https://github.com/NVIDIA/Megatron-LM)
@@ -26,7 +26,7 @@ supports running Message Passing Interface (MPI) on Kubernetes which is heavily
 The Training Operator implements the V1 API version of MPI Operator. For the MPI Operator V2 version,
 please follow [this guide](/docs/components/training/user-guides/mpi/) to install MPI Operator V2.
 
-<img src="/docs/components/training/images/training-operator-overview.drawio.png" alt="Training Operator Overview">
+<img src="/docs/components/training/images/training-operator-overview.drawio.svg" alt="Training Operator Overview">
@@ -70,6 +70,7 @@ for each ML framework:
 | XGBoost      | [XGBoostJob](/docs/components/training/user-guides/xgboost/) |
 | MPI          | [MPIJob](/docs/components/training/user-guides/mpi/)         |
 | PaddlePaddle | [PaddleJob](/docs/components/training/user-guides/paddle/)   |
+| JAX          | [JAXJob](/docs/components/training/user-guides/jax/)         |
 
 ## Next steps
diff --git a/content/en/docs/components/training/user-guides/jax.md b/content/en/docs/components/training/user-guides/jax.md
new file mode 100644
index 0000000000..1639a2df8b
--- /dev/null
+++ b/content/en/docs/components/training/user-guides/jax.md
@@ -0,0 +1,114 @@
++++
+title = "JAX Training (JAXJob)"
+description = "Using JAXJob to train a model with JAX"
+weight = 60
++++
+
+This page describes `JAXJob` for training a machine learning model with [JAX](https://jax.readthedocs.io/en/latest/).
+
+The `JAXJob` is a Kubernetes
+[custom resource](https://kubernetes.io/docs/concepts/extend-kubernetes/api-extension/custom-resources/)
+used to run JAX training jobs on Kubernetes. The Kubeflow implementation of
+the `JAXJob` is in the [`training-operator`](https://github.com/kubeflow/training-operator).
+
+The current custom resource for JAX has been tested to run multiple processes on CPUs,
+using [gloo](https://github.com/facebookincubator/gloo) for communication between them.
+The Worker with replica index 0 is treated as the JAX coordinator: process 0 starts a JAX
+coordinator service, exposed via that process's IP address and an available port, to which
+the other processes in the cluster connect. We are looking for user feedback on running
+`JAXJob` on GPUs and TPUs.
+
+## Creating a JAX training job
+
+You can create a training job by defining a `JAXJob` config file. See the manifests for the
+[simple JAXJob example](https://github.com/kubeflow/training-operator/blob/master/examples/jax/cpu-demo/demo.yaml).
+You may change the Job config file based on your requirements.
+
+Deploy the `JAXJob` resource to start training:
+
+```
+kubectl create -f https://raw.githubusercontent.com/kubeflow/training-operator/refs/heads/master/examples/jax/cpu-demo/demo.yaml
+```
+
+You should now be able to see the created pods matching the specified number of replicas:
+
+```
+kubectl get pods -n kubeflow -l training.kubeflow.org/job-name=jaxjob-simple
+```
+
+The distributed computation takes several minutes to complete on a CPU cluster. You can inspect the logs to follow its progress.
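+
+To make the sample log output below easier to interpret, here is a rough sketch of what a
+multi-process JAX training script can look like. This sketch is an illustrative assumption,
+not necessarily the `train.py` baked into the example image: the environment variable names
+are hypothetical, and the fallback coordinator address simply mirrors the
+`jaxjob-simple-worker-0:6666` address that appears in the logs below.
+
+```python
+import os
+
+import jax
+import jax.numpy as jnp
+
+# Join the JAX distributed runtime. Worker 0 starts the coordinator service;
+# every process, including worker 0 itself, then connects to it.
+# COORDINATOR_ADDRESS, NUM_PROCESSES, and PROCESS_ID are illustrative names,
+# not variables the JAXJob controller is guaranteed to set.
+jax.distributed.initialize(
+    coordinator_address=os.environ.get(
+        "COORDINATOR_ADDRESS", "jaxjob-simple-worker-0:6666"
+    ),
+    num_processes=int(os.environ.get("NUM_PROCESSES", "2")),
+    process_id=int(os.environ.get("PROCESS_ID", "0")),
+)
+
+print(f"JAX process {jax.process_index()}/{jax.process_count()} initialized")
+print(f"JAX global devices:{jax.devices()}")
+print(f"JAX local devices:{jax.local_devices()}")
+print(f"JAX device count:{jax.device_count()}")
+print(f"JAX local device count:{jax.local_device_count()}")
+
+# All-reduce a vector of ones across every device in the job: with two
+# single-CPU workers, each process prints [2.], as in the logs below.
+xs = jnp.ones(jax.local_device_count())
+print(jax.pmap(lambda x: jax.lax.psum(x, "i"), axis_name="i")(xs))
+```
+
+To fetch the logs from worker 0: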
+
+```
+PODNAME=$(kubectl get pods -l training.kubeflow.org/job-name=jaxjob-simple,training.kubeflow.org/replica-type=worker,training.kubeflow.org/replica-index=0 -o name -n kubeflow)
+kubectl logs -f ${PODNAME} -n kubeflow
+```
+
+```
+I1016 14:30:28.956959 139643066051456 distributed.py:106] Starting JAX distributed service on [::]:6666
+I1016 14:30:28.959352 139643066051456 distributed.py:119] Connecting to JAX distributed service on jaxjob-simple-worker-0:6666
+I1016 14:30:30.633651 139643066051456 xla_bridge.py:895] Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
+I1016 14:30:30.638316 139643066051456 xla_bridge.py:895] Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory
+JAX process 0/1 initialized on jaxjob-simple-worker-0
+JAX global devices:[CpuDevice(id=0), CpuDevice(id=131072)]
+JAX local devices:[CpuDevice(id=0)]
+JAX device count:2
+JAX local device count:1
+[2.]
+```
+
+## Monitoring a JAXJob
+
+```
+kubectl get -o yaml jaxjobs jaxjob-simple -n kubeflow
+```
+
+Check the `status` section to monitor the job's progress. Here is sample output for a successfully completed job:
+
+```yaml
+apiVersion: kubeflow.org/v1
+kind: JAXJob
+metadata:
+  annotations:
+    kubectl.kubernetes.io/last-applied-configuration: |
+      {"apiVersion":"kubeflow.org/v1","kind":"JAXJob","metadata":{"annotations":{},"name":"jaxjob-simple","namespace":"kubeflow"},"spec":{"jaxReplicaSpecs":{"Worker":{"replicas":2,"restartPolicy":"OnFailure","template":{"spec":{"containers":[{"command":["python3","train.py"],"image":"docker.io/kubeflow/jaxjob-simple:latest","imagePullPolicy":"Always","name":"jax"}]}}}}}}
+  creationTimestamp: "2024-09-22T20:07:59Z"
+  generation: 1
+  name: jaxjob-simple
+  namespace: kubeflow
+  resourceVersion: "1972"
+  uid: eb20c874-44fc-459b-b9a8-09f5c3ff46d3
+spec:
+  jaxReplicaSpecs:
+    Worker:
+      replicas: 2
+      restartPolicy: OnFailure
+      template:
+        spec:
+          containers:
+          - command:
+            - python3
+            - train.py
+            image: docker.io/kubeflow/jaxjob-simple:latest
+            imagePullPolicy: Always
+            name: jax
+status:
+  completionTime: "2024-09-22T20:11:34Z"
+  conditions:
+  - lastTransitionTime: "2024-09-22T20:07:59Z"
+    lastUpdateTime: "2024-09-22T20:07:59Z"
+    message: JAXJob jaxjob-simple is created.
+    reason: JAXJobCreated
+    status: "True"
+    type: Created
+  - lastTransitionTime: "2024-09-22T20:11:28Z"
+    lastUpdateTime: "2024-09-22T20:11:28Z"
+    message: JAXJob kubeflow/jaxjob-simple is running.
+    reason: JAXJobRunning
+    status: "False"
+    type: Running
+  - lastTransitionTime: "2024-09-22T20:11:34Z"
+    lastUpdateTime: "2024-09-22T20:11:34Z"
+    message: JAXJob kubeflow/jaxjob-simple successfully completed.
+    reason: JAXJobSucceeded
+    status: "True"
+    type: Succeeded
+  replicaStatuses:
+    Worker:
+      selector: training.kubeflow.org/job-name=jaxjob-simple,training.kubeflow.org/operator-name=jaxjob-controller,training.kubeflow.org/replica-type=worker
+      succeeded: 2
+  startTime: "2024-09-22T20:07:59Z"
+```
diff --git a/content/en/docs/components/training/user-guides/job-scheduling.md b/content/en/docs/components/training/user-guides/job-scheduling.md
index 10564812f1..84515ae683 100644
--- a/content/en/docs/components/training/user-guides/job-scheduling.md
+++ b/content/en/docs/components/training/user-guides/job-scheduling.md
@@ -1,7 +1,7 @@
 +++
 title = "Job Scheduling"
 description = "How to schedule a job with gang-scheduling"
-weight = 60
+weight = 70
 +++
 
 This guide describes how to use [Kueue](https://kueue.sigs.k8s.io/),
diff --git a/content/en/docs/components/training/user-guides/mpi.md b/content/en/docs/components/training/user-guides/mpi.md
index be1910f8cd..ba45067e38 100644
--- a/content/en/docs/components/training/user-guides/mpi.md
+++ b/content/en/docs/components/training/user-guides/mpi.md
@@ -1,7 +1,7 @@
 +++
 title = "MPI Training (MPIJob)"
 description = "Instructions for using MPI for training"
-weight = 60
+weight = 70
 +++
 
 {{% beta-status