# syntax=docker/dockerfile:1-labs
#
# Builds an AXLearn image on top of a JAX "mealkit" base image in two stages:
#   mealkit - clones AXLearn, strips conflicting pins from its pyproject.toml,
#             and records the Python requirements to be installed later.
#   final   - runs pip-finalize.sh on top of the mealkit stage.
#
# Build arguments (all overridable with --build-arg):
#   BASE_IMAGE         base image the mealkit stage starts from
#   URLREF_AXLEARN     git URL (with #ref) of the AXLearn source
#   SRC_PATH_AXLEARN   directory the AXLearn source is cloned into
#   DEST_MANIFEST_DIR  directory holding manifest.yaml and create-distribution.sh
#   GIT_USER_NAME      git identity configured inside the image
#   GIT_USER_EMAIL     git identity configured inside the image
ARG BASE_IMAGE=ghcr.io/nvidia/jax-mealkit:jax
ARG URLREF_AXLEARN=https://github.com/apple/axlearn.git#main
ARG SRC_PATH_AXLEARN=/opt/axlearn
ARG DEST_MANIFEST_DIR=/opt/manifest.d
# The value contains a space, so it must be quoted to stay a single ARG value.
ARG GIT_USER_NAME="JAX Toolbox"
ARG GIT_USER_EMAIL=jax@nvidia.com

###############################################################################
## Download source and configure dependencies
###############################################################################

FROM ${BASE_IMAGE} AS mealkit
# Re-declare the global ARGs so they are visible inside this stage.
ARG DEST_MANIFEST_DIR
ARG SRC_PATH_AXLEARN
ARG GIT_USER_NAME
ARG GIT_USER_EMAIL
ARG URLREF_AXLEARN

# Clone AXLearn and build its distribution entry from the manifest.
# NOTE(review): git-clone.sh and create-distribution.sh come from the base
# image; presumably the git identity is needed because they create commits
# (e.g. when applying patches) - confirm against those scripts.
RUN <<"EOF" bash -exu
git config --global user.email "${GIT_USER_EMAIL}"
git config --global user.name "${GIT_USER_NAME}"
git-clone.sh "${URLREF_AXLEARN}" "${SRC_PATH_AXLEARN}"
"${DEST_MANIFEST_DIR}/create-distribution.sh" \
  --manifest "${DEST_MANIFEST_DIR}/manifest.yaml" \
  --package axlearn
cd "${SRC_PATH_AXLEARN}"
git remote remove origin
EOF

# these packages are needed to run axlearn tests
# https://github.com/apple/axlearn/blob/main/pyproject.toml as reference
# Use the build argument rather than a hard-coded path so that overriding
# SRC_PATH_AXLEARN keeps the working directory in sync with the clone.
WORKDIR ${SRC_PATH_AXLEARN}
RUN <<"EOF" bash -exu
# Drop the pinned jax, jaxlib, pyarrow and tensorflow-family requirements from
# AXLearn's pyproject.toml; unpinned/explicit replacements are added below.
sed -i '/"jax==/d' "${SRC_PATH_AXLEARN}/pyproject.toml"
sed -i '/"jaxlib==/d' "${SRC_PATH_AXLEARN}/pyproject.toml"
sed -i '/"pyarrow=/d' "${SRC_PATH_AXLEARN}/pyproject.toml"
sed -i '/"tensorflow==/d' "${SRC_PATH_AXLEARN}/pyproject.toml"
sed -i '/"tensorflow-io==/d' "${SRC_PATH_AXLEARN}/pyproject.toml"
sed -i '/"tensorflow_text==/d' "${SRC_PATH_AXLEARN}/pyproject.toml"
sed -i '/"tensorflow-metadata==/d' "${SRC_PATH_AXLEARN}/pyproject.toml"

# Request an editable install of AXLearn with its "core" extra.
echo "-e ${SRC_PATH_AXLEARN}[core]" > /opt/pip-tools.d/requirements-axlearn.in
# Append the packages removed above plus the test tooling that is not part of
# the "core" extra.
cat <<REQUIREMENTS >> /opt/pip-tools.d/requirements-axlearn.in
tensorflow==2.20.0
tensorflow-text==2.20.0
pyarrow
tensorflow-metadata
tensorstore
cloudpickle
pytest
pytest-xdist
pytest-reportlog
REQUIREMENTS
EOF

# Print the generated requirements file into the build log for inspection.
RUN cat /opt/pip-tools.d/requirements-axlearn.in

###############################################################################
## Add test script to the path
###############################################################################

# COPY is sufficient for plain local files (no URL fetch or archive extraction).
COPY test-axlearn.sh fuji-train-perf.py /usr/local/bin/

###############################################################################
## Install accumulated packages from the base image and the previous stage
###############################################################################

FROM mealkit AS final

RUN pip-finalize.sh
WORKDIR ${SRC_PATH_AXLEARN}