JAX-vLLM offloading #19
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
name: "JAX-vLLM offloading"

# Triggers: manual dispatch, pull requests that touch the offloading sources
# or this workflow's own files, and a daily schedule.
on:
  workflow_dispatch:
    inputs:
      PUBLISH:
        description: "Publish dated images and update the 'latest' tag?"
        type: boolean
        required: false
        default: false
  pull_request:
    paths:
      - "jax-inference-offloading/**"
      - ".github/gke-workflow/jax-vllm-offloading/**"
      - ".github/workflows/jax-vllm-offloading*.yml"
    types: [opened, reopened, ready_for_review, synchronize]
  schedule:
    # 09:30 UTC every day, i.e. 01:30 AM Pacific Standard Time.
    - cron: "30 9 * * *"
# Runs are grouped by workflow name plus github.head_ref; when head_ref is
# empty the run id is used instead, which gives that run a group of its own.
# A newer run cancels an in-progress run of the same group unless the ref is
# refs/heads/main.
concurrency:
  cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}
  group: "${{ github.workflow }}-${{ github.head_ref || github.run_id }}"
# Token scopes granted to the jobs of this workflow.
permissions:
  actions: write   # cancel previous workflow runs
  contents: read   # fetch the repository code
  packages: write  # upload container images
jobs:
  # Computes values shared by the downstream jobs: the build date (evaluated
  # in the US/Los_Angeles time zone) and whether this run publishes results.
  metadata:
    runs-on: ubuntu-22.04
    outputs:
      BUILD_DATE: ${{ steps.date.outputs.BUILD_DATE }}
      PUBLISH: ${{ steps.if-publish.outputs.PUBLISH }}
    steps:
      - name: Set build date
        id: date
        shell: bash -x -e {0}
        run: |
          BUILD_DATE=$(TZ='US/Los_Angeles' date '+%Y-%m-%d')
          echo "BUILD_DATE=${BUILD_DATE}" >> "$GITHUB_OUTPUT"

      - name: Determine whether results will be 'published'
        id: if-publish
        shell: bash -x -e {0}
        # Scheduled runs always publish; manual runs publish only when the
        # PUBLISH input is set.
        run: |
          echo "PUBLISH=${{ github.event_name == 'schedule' || inputs.PUBLISH }}" >> "$GITHUB_OUTPUT"

  # Builds the container for amd64 through the repository-local
  # build-container action and exposes the resulting image tags.
  amd64:
    needs: metadata
    runs-on: [self-hosted, amd64, small]
    outputs:
      DOCKER_TAG_MEALKIT: ${{ steps.build-container.outputs.DOCKER_TAG_MEALKIT }}
      DOCKER_TAG_FINAL: ${{ steps.build-container.outputs.DOCKER_TAG_FINAL }}
    steps:
      - name: Checkout repository
        uses: actions/checkout@v4

      - name: Build container
        id: build-container
        uses: ./.github/actions/build-container
        with:
          ARCHITECTURE: amd64
          ARTIFACT_NAME: artifact-jio-build
          BADGE_FILENAME: badge-jio-build
          BASE_IMAGE: nvcr.io/nvidia/cuda-dl-base:25.06-cuda12.9-devel-ubuntu24.04
          BUILD_DATE: ${{ needs.metadata.outputs.BUILD_DATE }}
          CONTAINER_NAME: jio
          DOCKERFILE: jax-inference-offloading/dockerfile/oss.dockerfile
          RUNNER_SIZE: small
          ssh-private-key: ${{ secrets.SSH_PRIVATE_KEY }}
          ssh-known-hosts: ${{ vars.SSH_KNOWN_HOSTS }}
          github-token: ${{ secrets.GITHUB_TOKEN }}
          EXTRA_BUILD_ARGS: |
            REF_JIO=${{ github.ref }}

  # Same build as the amd64 job, on an arm64 runner.
  arm64:
    needs: metadata
    runs-on: [self-hosted, arm64, small]
    outputs:
      DOCKER_TAG_MEALKIT: ${{ steps.build-container.outputs.DOCKER_TAG_MEALKIT }}
      DOCKER_TAG_FINAL: ${{ steps.build-container.outputs.DOCKER_TAG_FINAL }}
    steps:
      - name: Checkout repository
        uses: actions/checkout@v4

      - name: Build container
        id: build-container
        uses: ./.github/actions/build-container
        with:
          ARCHITECTURE: arm64
          ARTIFACT_NAME: artifact-jio-build
          BADGE_FILENAME: badge-jio-build
          BASE_IMAGE: nvcr.io/nvidia/cuda-dl-base:25.06-cuda12.9-devel-ubuntu24.04
          BUILD_DATE: ${{ needs.metadata.outputs.BUILD_DATE }}
          CONTAINER_NAME: jio
          DOCKERFILE: jax-inference-offloading/dockerfile/oss.dockerfile
          RUNNER_SIZE: small
          ssh-private-key: ${{ secrets.SSH_PRIVATE_KEY }}
          ssh-known-hosts: ${{ vars.SSH_KNOWN_HOSTS }}
          github-token: ${{ secrets.GITHUB_TOKEN }}
          EXTRA_BUILD_ARGS: |
            REF_JIO=${{ github.ref }}

  # Gathers the image tags of both build jobs into one compact JSON array.
  # Runs even when a build failed (only cancellation stops it); the tag of a
  # failed build is then an empty string and is filtered out downstream.
  collect-docker-tags:
    runs-on: ubuntu-22.04
    if: ${{ !cancelled() }}
    needs:
      - amd64
      - arm64
    outputs:
      TAGS: ${{ steps.collect-tags.outputs.TAGS }}
    steps:
      - name: Save docker tags as a JSON object
        id: collect-tags
        # The trailing {} lets every real entry end with a comma; it has no
        # "flavor" key, so it never matches a flavor filter later on.
        run: |
          TAGS=$(cat <<EOF | jq -c .
          [\
            {"flavor": "jax-inference-offloading", "stage": "final", "priority": 900, "tag": "${{ needs.amd64.outputs.DOCKER_TAG_FINAL }}"},\
            {"flavor": "jax-inference-offloading-mealkit", "stage": "mealkit", "priority": 500, "tag": "${{ needs.amd64.outputs.DOCKER_TAG_MEALKIT }}"},\
            {"flavor": "jax-inference-offloading", "stage": "final", "priority": 900, "tag": "${{ needs.arm64.outputs.DOCKER_TAG_FINAL }}"},\
            {"flavor": "jax-inference-offloading-mealkit", "stage": "mealkit", "priority": 500, "tag": "${{ needs.arm64.outputs.DOCKER_TAG_MEALKIT }}"},\
            {}\
          ]
          EOF
          )
          echo "TAGS=${TAGS}" >> "$GITHUB_OUTPUT"

  # Turns the collected tags into the matrix consumed by publish-containers:
  # one entry per flavor that has at least one non-empty image tag.
  # The output has the shape {"config":[...]} and is exactly {"config":[]}
  # when nothing can be published.
  make-publish-configs:
    runs-on: ubuntu-22.04
    if: ${{ !cancelled() }}
    env:
      # Publishing runs target the "jax" repository, all others "mock-jax".
      IMAGE_REPO: ${{ needs.metadata.outputs.PUBLISH == 'true' && 'jax' || 'mock-jax' }}
    needs:
      - metadata
      - collect-docker-tags
    outputs:
      PUBLISH_CONFIGS: ${{ steps.generate-configs.outputs.PUBLISH_CONFIGS }}
    steps:
      - id: generate-configs
        shell: bash -eu -o pipefail {0}
        run: |
          declare -a FLAVORS=(
            jax-inference-offloading
            jax-inference-offloading-mealkit
          )

          ## create JSON specs for a 1D matrix of container publication jobs
          ALL_TAGS=$(
            echo '${{ needs.collect-docker-tags.outputs.TAGS }}' | jq -s 'add'
          )
          PUBLISH_CONFIGS='[]'
          for flavor in "${FLAVORS[@]}"; do
            # collect images for different platforms, e.g. amd64 and arm64
            matching_tags=$(
              echo "$ALL_TAGS" |\
              jq -c ".[] | select(.flavor == \"${flavor}\" and .tag != \"\")"
            )
            # source_image is a list of all platform-specific tags
            source_image=$(echo "${matching_tags}" | jq -c "[.tag]" | jq -s 'add')
            # if the build job failed without producing any images, skip this flavor
            n_source_images=$(echo "$source_image" | jq 'length')
            if [[ $n_source_images -gt 0 ]]; then
              # Determine stage from flavor name
              if [[ "${flavor}" == *"-mealkit" ]]; then
                stage="mealkit"
              else
                stage="final"
              fi
              echo "PUBLISH image $flavor with $n_source_images containers"
              # tag priority is the highest priority of all platform-specific tags
              priority=$(echo "${matching_tags}" | jq -r ".priority" | jq -s 'max')
              # All images go to the same repository (jax or mock-jax)
              target_image=${IMAGE_REPO}
              PUBLISH_CONFIGS=$(
                echo "${PUBLISH_CONFIGS}" | jq -c ". + [{
                  \"flavor\": \"${flavor}\",
                  \"target_image\": \"${target_image}\",
                  \"priority\": \"${priority}\",
                  \"source_image\": ${source_image},
                  \"stage\": \"${stage}\"
                }]"
              )
            else
              echo "SKIPPED image $flavor with 0 containers"
            fi
          done
          PUBLISH_CONFIGS=$(echo "$PUBLISH_CONFIGS" | jq -c '{"config": .}')
          echo "${PUBLISH_CONFIGS}" | jq .
          echo "PUBLISH_CONFIGS=${PUBLISH_CONFIGS}" >> "$GITHUB_OUTPUT"

  # Publishes one multi-platform image per matrix entry through the reusable
  # _publish_container workflow.
  publish-containers:
    needs:
      - metadata
      - make-publish-configs
    # PUBLISH_CONFIGS is a JSON string, not an object, so it is compared as a
    # whole against the compact form jq emits for an empty list. Dereferencing
    # ".config" on the string would never equal the literal, and the job would
    # then run with an empty matrix instead of being skipped.
    if: ${{ !cancelled() && needs.make-publish-configs.outputs.PUBLISH_CONFIGS != '{"config":[]}' }}
    strategy:
      fail-fast: false
      matrix: ${{ fromJson(needs.make-publish-configs.outputs.PUBLISH_CONFIGS) }}
    uses: ./.github/workflows/_publish_container.yaml
    with:
      ARTIFACT_NAME: ${{ matrix.config.stage }}-${{ matrix.config.flavor }}
      ARTIFACT_TAG: ${{ matrix.config.flavor }}-${{ needs.metadata.outputs.BUILD_DATE }}
      SOURCE_IMAGE: ${{ join(matrix.config.source_image, ' ') }}
      TARGET_IMAGE: ${{ matrix.config.target_image }}
      # Two tags per flavor: a moving one (flavor name) and a dated one.
      TARGET_TAGS: |
        type=raw,value=${{ matrix.config.flavor }},priority=${{ matrix.config.priority }}
        type=raw,value=${{ matrix.config.flavor }}-${{ needs.metadata.outputs.BUILD_DATE }},priority=${{ matrix.config.priority }}

  # Calls the reusable _finalize workflow after the build and publish jobs,
  # regardless of their outcome (only cancellation prevents it).
  finalize:
    needs: [metadata, amd64, arm64, publish-containers]
    if: "!cancelled()"
    uses: ./.github/workflows/_finalize.yaml
    with:
      BUILD_DATE: ${{ needs.metadata.outputs.BUILD_DATE }}
      PUBLISH_BADGE: ${{ needs.metadata.outputs.PUBLISH == 'true' }}
    secrets: inherit

  # Runs the GKE transfer workflow against the amd64 final image.
  transfer-gke-xpk:
    uses: ./.github/workflows/jax-vllm-offloading-gke-transfer.yml
    needs: amd64
    with:
      JAX_VLLM_OFFLOADING_IMAGE: ${{ needs.amd64.outputs.DOCKER_TAG_FINAL }}

  # Runs the GKE GRPO workflow against the amd64 final image.
  grpo-gke-xpk:
    uses: ./.github/workflows/jax-vllm-offloading-gke-grpo.yml
    needs: amd64
    with:
      JAX_VLLM_OFFLOADING_IMAGE: ${{ needs.amd64.outputs.DOCKER_TAG_FINAL }}

  # Patches the EKS JobSet template with this run's job name, image, model
  # and image pull secret, then submits it through the repository-local
  # submit-delete-k8s-jobset action.
  jax-vllm-offloading-transfer-eks:
    needs: amd64
    runs-on: eks
    env:
      JOB_NAME: jax-vllm-offloading-${{ github.run_id }}
      JAX_VLLM_OFFLOADING_IMAGE: ${{ needs.amd64.outputs.DOCKER_TAG_FINAL }}
      MODEL: "meta-llama/Llama-3.1-8B-Instruct"
    steps:
      - uses: actions/checkout@v4

      - name: Login to GitHub Container Registry
        uses: docker/login-action@v3
        with:
          registry: ghcr.io
          username: ${{ github.repository_owner }}
          password: ${{ secrets.GITHUB_TOKEN }}

      - name: K8s GHCR store and delete token
        id: store-token
        uses: ./.github/actions/store-delete-k8s-ghcr

      - name: Configure jax-vllm job
        env:
          # Handed over through the environment so the yq program reads it
          # with strenv(), like the other substituted values, instead of
          # having the expression pasted into the script text.
          IMAGE_PULL_SECRET_NAME: ${{ steps.store-token.outputs.token-name }}
        run: |
          yq -i '
            # 1. Update the JobSet Name
            .metadata.name = strenv(JOB_NAME)

            # 2. Update Image (Applies to Gateway, vLLM, and JAX nodes)
            | .spec.replicatedJobs[].template.spec.template.spec.containers[].image = strenv(JAX_VLLM_OFFLOADING_IMAGE)

            # 3. Update Model Name (Finds the env var named "MODEL_NAME" and updates its value)
            | (.spec.replicatedJobs[].template.spec.template.spec.containers[].env[] | select(.name == "MODEL_NAME").value) = strenv(MODEL)

            # 4. Add imagePullSecrets to all replicatedJobs (Gateway, vLLM, and JAX nodes)
            | .spec.replicatedJobs[].template.spec.template.spec.imagePullSecrets[].name = strenv(IMAGE_PULL_SECRET_NAME)
          ' .github/eks-workflow-files/jio-eks/jio-template.yaml

          # Print the resulting changes to the job log.
          git diff .github/eks-workflow-files/jio-eks/jio-template.yaml

      - name: Apply jax-vllm offloading job to EKS cluster
        uses: ./.github/actions/submit-delete-k8s-jobset
        with:
          jobset-config-file: ".github/eks-workflow-files/jio-eks/jio-template.yaml"
          jobset-name: ${{ env.JOB_NAME }}