Skip to content

JAX-vLLM offloading #19

JAX-vLLM offloading

JAX-vLLM offloading #19

name: JAX-vLLM offloading
# Triggers: a daily scheduled build, pull requests that touch the offloading
# sources or this workflow's files, and manual dispatch.
on:
  schedule:
    # 09:30 UTC every day, i.e. 01:30 AM PST / 02:30 AM PDT.
    - cron: "30 9 * * *"
  pull_request:
    types: [opened, reopened, ready_for_review, synchronize]
    paths:
      - "jax-inference-offloading/**"
      - ".github/gke-workflow/jax-vllm-offloading/**"
      - ".github/workflows/jax-vllm-offloading*.yml"
  workflow_dispatch:
    inputs:
      PUBLISH:
        description: Publish dated images and update the 'latest' tag?
        type: boolean
        required: false
        default: false
concurrency:
  # Pull-request runs share a group per head branch; every other event falls
  # back to its unique run_id and therefore never shares a group.
  group: ${{ format('{0}-{1}', github.workflow, github.head_ref || github.run_id) }}
  # A newer run in the same group cancels the older one, except on main.
  cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}
permissions:
  actions: write  # cancel superseded workflow runs
  contents: read  # check out the repository
  packages: write  # push container images
jobs:
  # Values shared by downstream jobs: the date stamped into image tags and
  # whether this run publishes to the real repository.
  metadata:
    runs-on: ubuntu-22.04
    outputs:
      BUILD_DATE: ${{ steps.date.outputs.BUILD_DATE }}
      PUBLISH: ${{ steps.if-publish.outputs.PUBLISH }}
    steps:
      - name: Set build date
        id: date
        shell: bash -x -e {0}
        run: |
          # 'America/Los_Angeles' is the IANA name for US Pacific time.
          # ('US/Los_Angeles' is not a tz database name; date(1) would silently
          # fall back to UTC.)
          BUILD_DATE=$(TZ='America/Los_Angeles' date '+%Y-%m-%d')
          echo "BUILD_DATE=${BUILD_DATE}" >> "$GITHUB_OUTPUT"
      - name: Determine whether results will be 'published'
        id: if-publish
        shell: bash -x -e {0}
        run: |
          # Scheduled runs always publish; manual runs publish only when the
          # PUBLISH input is set. `== true` makes the output a literal
          # 'true'/'false' even on pull_request, where inputs.PUBLISH is unset.
          echo "PUBLISH=${{ github.event_name == 'schedule' || inputs.PUBLISH == true }}" >> "$GITHUB_OUTPUT"

  # Build the jio container on amd64; exposes the mealkit and final image tags.
  amd64:
    needs: metadata
    runs-on: [self-hosted, amd64, small]
    outputs:
      DOCKER_TAG_MEALKIT: ${{ steps.build-container.outputs.DOCKER_TAG_MEALKIT }}
      DOCKER_TAG_FINAL: ${{ steps.build-container.outputs.DOCKER_TAG_FINAL }}
    steps:
      - name: Checkout repository
        uses: actions/checkout@v4
      - name: Build container
        id: build-container
        uses: ./.github/actions/build-container
        with:
          ARCHITECTURE: amd64
          # NOTE(review): same ARTIFACT_NAME/BADGE_FILENAME as arm64; presumably
          # the action suffixes ARCHITECTURE — confirm.
          ARTIFACT_NAME: artifact-jio-build
          BADGE_FILENAME: badge-jio-build
          BASE_IMAGE: nvcr.io/nvidia/cuda-dl-base:25.06-cuda12.9-devel-ubuntu24.04
          BUILD_DATE: ${{ needs.metadata.outputs.BUILD_DATE }}
          CONTAINER_NAME: jio
          DOCKERFILE: jax-inference-offloading/dockerfile/oss.dockerfile
          RUNNER_SIZE: small
          ssh-private-key: ${{ secrets.SSH_PRIVATE_KEY }}
          ssh-known-hosts: ${{ vars.SSH_KNOWN_HOSTS }}
          github-token: ${{ secrets.GITHUB_TOKEN }}
          EXTRA_BUILD_ARGS: |
            REF_JIO=${{ github.ref }}

  # Build the jio container on arm64; exposes the mealkit and final image tags.
  arm64:
    needs: metadata
    runs-on: [self-hosted, arm64, small]
    outputs:
      DOCKER_TAG_MEALKIT: ${{ steps.build-container.outputs.DOCKER_TAG_MEALKIT }}
      DOCKER_TAG_FINAL: ${{ steps.build-container.outputs.DOCKER_TAG_FINAL }}
    steps:
      - name: Checkout repository
        uses: actions/checkout@v4
      - name: Build container
        id: build-container
        uses: ./.github/actions/build-container
        with:
          ARCHITECTURE: arm64
          ARTIFACT_NAME: artifact-jio-build
          BADGE_FILENAME: badge-jio-build
          BASE_IMAGE: nvcr.io/nvidia/cuda-dl-base:25.06-cuda12.9-devel-ubuntu24.04
          BUILD_DATE: ${{ needs.metadata.outputs.BUILD_DATE }}
          CONTAINER_NAME: jio
          DOCKERFILE: jax-inference-offloading/dockerfile/oss.dockerfile
          RUNNER_SIZE: small
          ssh-private-key: ${{ secrets.SSH_PRIVATE_KEY }}
          ssh-known-hosts: ${{ vars.SSH_KNOWN_HOSTS }}
          github-token: ${{ secrets.GITHUB_TOKEN }}
          EXTRA_BUILD_ARGS: |
            REF_JIO=${{ github.ref }}

  # Gather the per-architecture tags into one JSON array. Runs even when a
  # build failed; that build's outputs are then empty strings, which
  # make-publish-configs filters out.
  collect-docker-tags:
    runs-on: ubuntu-22.04
    if: ${{ !cancelled() }}
    needs:
      - amd64
      - arm64
    outputs:
      TAGS: ${{ steps.collect-tags.outputs.TAGS }}
    steps:
      - name: Save docker tags as a JSON object
        id: collect-tags
        run: |
          # The trailing {} lets every real entry end with a comma; it has no
          # "flavor" key, so it is never selected downstream.
          TAGS=$(cat <<EOF | jq -c
          [\
          {"flavor": "jax-inference-offloading", "stage": "final", "priority": 900, "tag": "${{ needs.amd64.outputs.DOCKER_TAG_FINAL }}"},\
          {"flavor": "jax-inference-offloading-mealkit", "stage": "mealkit", "priority": 500, "tag": "${{ needs.amd64.outputs.DOCKER_TAG_MEALKIT }}"},\
          {"flavor": "jax-inference-offloading", "stage": "final", "priority": 900, "tag": "${{ needs.arm64.outputs.DOCKER_TAG_FINAL }}"},\
          {"flavor": "jax-inference-offloading-mealkit", "stage": "mealkit", "priority": 500, "tag": "${{ needs.arm64.outputs.DOCKER_TAG_MEALKIT }}"},\
          {}\
          ]
          EOF
          )
          echo "TAGS=${TAGS}" >> "$GITHUB_OUTPUT"

  # Turn the collected tags into a strategy matrix: one entry per flavor,
  # listing every platform-specific image that was actually built.
  make-publish-configs:
    runs-on: ubuntu-22.04
    if: ${{ !cancelled() }}
    env:
      # Real repository only when metadata says to publish; otherwise the same
      # path runs against 'mock-jax'.
      IMAGE_REPO: ${{ needs.metadata.outputs.PUBLISH == 'true' && 'jax' || 'mock-jax' }}
    needs:
      - metadata
      - collect-docker-tags
    outputs:
      PUBLISH_CONFIGS: ${{ steps.generate-configs.outputs.PUBLISH_CONFIGS }}
    steps:
      - id: generate-configs
        shell: bash -eu -o pipefail {0}
        env:
          # Passed via the environment rather than spliced into the script, so
          # the JSON is never re-interpreted by the shell.
          TAGS: ${{ needs.collect-docker-tags.outputs.TAGS }}
        run: |
          declare -a FLAVORS=(
            jax-inference-offloading
            jax-inference-offloading-mealkit
          )
          ## create JSON specs for a 1D matrix of container publication jobs
          # `// []` keeps the loop working when TAGS is empty (e.g. the
          # collect-docker-tags job failed).
          ALL_TAGS=$(jq -sc 'add // []' <<< "${TAGS}")
          PUBLISH_CONFIGS='[]'
          for flavor in "${FLAVORS[@]}"; do
            # collect images for different platforms, e.g. amd64 and arm64
            matching_tags=$(
              jq -c --arg flavor "${flavor}" \
                '.[] | select(.flavor == $flavor and .tag != "")' <<< "${ALL_TAGS}"
            )
            # source_image is a list of all platform-specific tags
            source_image=$(jq -sc 'map(.tag)' <<< "${matching_tags}")
            # if the build job failed without producing any images, skip this flavor
            n_source_images=$(jq 'length' <<< "${source_image}")
            if [[ ${n_source_images} -eq 0 ]]; then
              echo "SKIPPED image ${flavor} with 0 containers"
              continue
            fi
            # Determine stage from flavor name
            if [[ "${flavor}" == *-mealkit ]]; then
              stage="mealkit"
            else
              stage="final"
            fi
            echo "PUBLISH image ${flavor} with ${n_source_images} containers"
            # tag priority is the highest priority of all platform-specific tags
            priority=$(jq -s 'map(.priority) | max' <<< "${matching_tags}")
            # All images go to the same repository (jax or mock-jax). jq builds
            # the object from --arg/--argjson values, so quoting is always
            # correct; priority stays a JSON string, as before.
            PUBLISH_CONFIGS=$(
              jq -c \
                --arg flavor "${flavor}" \
                --arg target_image "${IMAGE_REPO}" \
                --arg priority "${priority}" \
                --argjson source_image "${source_image}" \
                --arg stage "${stage}" \
                '. + [{
                  flavor: $flavor,
                  target_image: $target_image,
                  priority: $priority,
                  source_image: $source_image,
                  stage: $stage
                }]' <<< "${PUBLISH_CONFIGS}"
            )
          done
          PUBLISH_CONFIGS=$(jq -c '{config: .}' <<< "${PUBLISH_CONFIGS}")
          jq . <<< "${PUBLISH_CONFIGS}"
          echo "PUBLISH_CONFIGS=${PUBLISH_CONFIGS}" >> "$GITHUB_OUTPUT"

  # Publish one image per flavor from its platform-specific source images.
  publish-containers:
    needs:
      - metadata
      - make-publish-configs
    # Job outputs are plain strings, so compare the whole string (a `.config`
    # lookup on it yields null and never matches). Skipping '{"config":[]}'
    # avoids an empty-matrix error; requiring success avoids fromJson('').
    if: ${{ !cancelled() && needs.make-publish-configs.result == 'success' && needs.make-publish-configs.outputs.PUBLISH_CONFIGS != '{"config":[]}' }}
    strategy:
      fail-fast: false
      matrix: ${{ fromJson(needs.make-publish-configs.outputs.PUBLISH_CONFIGS) }}
    uses: ./.github/workflows/_publish_container.yaml
    with:
      ARTIFACT_NAME: ${{ matrix.config.stage }}-${{ matrix.config.flavor }}
      ARTIFACT_TAG: ${{ matrix.config.flavor }}-${{ needs.metadata.outputs.BUILD_DATE }}
      SOURCE_IMAGE: ${{ join(matrix.config.source_image, ' ') }}
      TARGET_IMAGE: ${{ matrix.config.target_image }}
      TARGET_TAGS: |
        type=raw,value=${{ matrix.config.flavor }},priority=${{ matrix.config.priority }}
        type=raw,value=${{ matrix.config.flavor }}-${{ needs.metadata.outputs.BUILD_DATE }},priority=${{ matrix.config.priority }}

  finalize:
    needs: [metadata, amd64, arm64, publish-containers]
    if: "!cancelled()"
    uses: ./.github/workflows/_finalize.yaml
    with:
      BUILD_DATE: ${{ needs.metadata.outputs.BUILD_DATE }}
      PUBLISH_BADGE: ${{ needs.metadata.outputs.PUBLISH == 'true' }}
    secrets: inherit

  # GKE (XPK) test workflows, run against the amd64 final image.
  transfer-gke-xpk:
    uses: ./.github/workflows/jax-vllm-offloading-gke-transfer.yml
    needs: amd64
    with:
      JAX_VLLM_OFFLOADING_IMAGE: ${{ needs.amd64.outputs.DOCKER_TAG_FINAL }}
  grpo-gke-xpk:
    uses: ./.github/workflows/jax-vllm-offloading-gke-grpo.yml
    needs: amd64
    with:
      JAX_VLLM_OFFLOADING_IMAGE: ${{ needs.amd64.outputs.DOCKER_TAG_FINAL }}

  # Render the jio JobSet template with this run's image, model and pull
  # secret, then submit it to the EKS cluster.
  jax-vllm-offloading-transfer-eks:
    needs: amd64
    runs-on: eks
    env:
      # run_attempt keeps a re-run from colliding with a JobSet left over from
      # an earlier attempt of the same run.
      JOB_NAME: jax-vllm-offloading-${{ github.run_id }}-${{ github.run_attempt }}
      JAX_VLLM_OFFLOADING_IMAGE: ${{ needs.amd64.outputs.DOCKER_TAG_FINAL }}
      MODEL: "meta-llama/Llama-3.1-8B-Instruct"
    steps:
      - uses: actions/checkout@v4
      - name: Login to GitHub Container Registry
        uses: docker/login-action@v3
        with:
          registry: ghcr.io
          username: ${{ github.repository_owner }}
          password: ${{ secrets.GITHUB_TOKEN }}
      - name: K8s GHCR store and delete token
        id: store-token
        uses: ./.github/actions/store-delete-k8s-ghcr
      - name: Configure jax-vllm job
        env:
          IMAGE_PULL_SECRET: ${{ steps.store-token.outputs.token-name }}
        run: |
          yq -i '
            # 1. Update the JobSet Name
            .metadata.name = strenv(JOB_NAME)
            # 2. Update Image (Applies to Gateway, vLLM, and JAX nodes)
            | .spec.replicatedJobs[].template.spec.template.spec.containers[].image = strenv(JAX_VLLM_OFFLOADING_IMAGE)
            # 3. Update Model Name (Finds the env var named "MODEL_NAME" and updates its value)
            | (.spec.replicatedJobs[].template.spec.template.spec.containers[].env[] | select(.name == "MODEL_NAME").value) = strenv(MODEL)
            # 4. Set imagePullSecrets on every replicatedJob (Gateway, vLLM, and JAX nodes).
            #    Assigning the list works whether or not the template already has one.
            | .spec.replicatedJobs[].template.spec.template.spec.imagePullSecrets = [{"name": strenv(IMAGE_PULL_SECRET)}]
          ' .github/eks-workflow-files/jio-eks/jio-template.yaml
          git diff .github/eks-workflow-files/jio-eks/jio-template.yaml
      - name: Apply jax-vllm offloading job to EKS cluster
        uses: ./.github/actions/submit-delete-k8s-jobset
        with:
          jobset-config-file: ".github/eks-workflow-files/jio-eks/jio-template.yaml"
          jobset-name: ${{ env.JOB_NAME }}