Merged
Changes from 3 commits
49 changes: 49 additions & 0 deletions cli/app.go
@@ -3398,6 +3398,55 @@ This won't work unless you have an existing installation of our GitHub app on yo
},
Action: createCommandWithT[mlTrainingUpdateArgs](MLTrainingUpdateAction),
},
{
Name: "test-local",
Usage: "test training script locally using Docker",
UsageText: createUsageText("training-script test-local", []string{
trainFlagDatasetFile, trainFlagTrainingScriptDirectory,
trainFlagContainerVersion, trainFlagModelOutputDirectory,
}, true, false),
Description: `Test your training script locally before submitting to the cloud. This runs your training script
in a Docker container using the same environment as cloud training.

REQUIREMENTS:
- Docker must be installed and running
- Training script directory must contain setup.py and model/training.py
- Dataset must be in JSONL format

NOTES:
Member Author

Maybe this goes in the documentation? I also want to add that if the container is really slow, it could be because the dataset root or training script directory has a bunch of extra files. Apparently mounting volumes can be expensive.

- Training containers only support linux/x86_64 (amd64) architecture
- Ensure Docker Desktop has sufficient resources allocated (memory, CPU)
- Model output will be saved to the specified output directory on your host machine
`,
Flags: []cli.Flag{
&cli.StringFlag{
Name: trainFlagDatasetFile,
Usage: "path to the dataset file (JSONL format)",
Required: true,
},
&cli.StringFlag{
Name: trainFlagTrainingScriptDirectory,
Usage: "path to the training script directory (must contain setup.py and model/training.py)," +
" this directory will be mounted into the container",
Required: true,
},
&cli.StringFlag{
Name: trainFlagContainerVersion,
Usage: "container version to use (e.g., 'tf:2.16'). Defaults to tf:2.16",
Value: "tf:2.16",
},
&cli.StringFlag{
Name: trainFlagModelOutputDirectory,
Usage: "directory where the trained model will be saved. Defaults to current directory",
Value: ".",
},
&cli.StringSliceFlag{
Name: trainFlagCustomArgs,
Usage: "custom arguments to pass to the training script (format: key=value)",
},
},
Action: createCommandWithT[mlTrainingScriptTestLocalArgs](MLTrainingScriptTestLocalAction),
},
},
},
{
296 changes: 296 additions & 0 deletions cli/ml_training.go
@@ -3,8 +3,15 @@ package cli
import (
"context"
"fmt"
"os"
"os/exec"
"os/signal"
"path/filepath"
"regexp"
"slices"
"strconv"
"strings"
"syscall"
"time"

"github.com/pkg/errors"
@@ -27,8 +34,17 @@ const (
trainFlagModelLabels = "model-labels"

trainingStatusPrefix = "TRAINING_STATUS_"

// Flags for test-local command
trainFlagContainerVersion = "container-version"
trainFlagDatasetFile = "dataset-file"
trainFlagModelOutputDirectory = "model-output-directory"
trainFlagCustomArgs = "custom-args"
trainFlagTrainingScriptDirectory = "training-script-directory"
)

var dockerVertexImageRegex = regexp.MustCompile(`^us-docker\.pkg\.dev/vertex-ai/training/[^:@\s]+(:[^@\s]+)?(@sha256:[A-Fa-f0-9]{64})?$`)

type mlSubmitCustomTrainingJobArgs struct {
DatasetID string
OrgID string
@@ -597,3 +613,283 @@ func convertVisibilityToProto(visibility string) (*v1.Visibility, error) {

return &visibilityProto, nil
}

type mlTrainingScriptTestLocalArgs struct {
ContainerVersion string
DatasetFile string
ModelOutputDirectory string
TrainingScriptDirectory string
CustomArgs []string
}

// MLTrainingScriptTestLocalAction runs training locally in a Docker container.
func MLTrainingScriptTestLocalAction(c *cli.Context, args mlTrainingScriptTestLocalArgs) error {
Member

This function is probably going to take a while. I think we should add info logging throughout so people know that it's working.

Member Author

The only part of this function that takes a significant amount of time is running the docker command that starts training at the bottom, and the logs from the Docker container show up in the terminal. I also added timeouts to checkDockerAvailable, so I think that should be fine? I'd suggest running this command; sorry, I know it's annoying.

Member

The issue here is that the container is really really really big. In the office downloading this can take 30 minutes. I know that it logs the docker stuff but giving people a heads up somewhere would be helpful.

Member Author

oop that is a really good point, thanks!
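
A possible shape for that heads-up, as a rough sketch only: check whether the image is already cached with `docker image inspect` and print a notice before `docker run` starts pulling. `imagePresentLocally`, `pullNotice`, the message wording, and the placeholder image name are all assumptions for illustration; none of them exist in this PR.

```go
package main

import (
	"context"
	"fmt"
	"os/exec"
	"time"
)

// imagePresentLocally reports whether Docker already has the image.
// `docker image inspect` exits non-zero when the image is not known locally;
// this sketch also treats "Docker unavailable" as "not present".
func imagePresentLocally(ctx context.Context, image string) bool {
	ctx, cancel := context.WithTimeout(ctx, 10*time.Second)
	defer cancel()
	return exec.CommandContext(ctx, "docker", "image", "inspect", image).Run() == nil
}

// pullNotice builds the heads-up message. The wording is hypothetical.
func pullNotice(image string, present bool) string {
	if present {
		return fmt.Sprintf("Docker is available. Using cached image %s", image)
	}
	return fmt.Sprintf("Docker is available. Image %s is not cached locally; "+
		"the first run has to download it, which can take a long time", image)
}

func main() {
	// Placeholder image name for illustration only.
	image := "example.invalid/training/image:tag"
	fmt.Println(pullNotice(image, imagePresentLocally(context.Background(), image)))
}
```

In the CLI itself the message would presumably go through the existing `printf(c.App.Writer, ...)` helper rather than `fmt.Println`.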

client, err := newViamClient(c)
if err != nil {
return err
}

// Check if Docker is available
if err := checkDockerAvailable(); err != nil {
return err
}
Member

For example, just adding a log line like "Docker is available. Using image BLAH".


// Validate required paths exist
if err := validateLocalTrainingPaths(args.TrainingScriptDirectory, args.DatasetFile); err != nil {
return err
}
// Get absolute paths for volume mounting
scriptDirAbs, err := filepath.Abs(args.TrainingScriptDirectory)
if err != nil {
return errors.Wrapf(err, "failed to get absolute path for training script directory")
}

datasetFileAbs, err := filepath.Abs(args.DatasetFile)
if err != nil {
return errors.Wrapf(err, "failed to get absolute path for dataset file")
}

// Determine output directory (default to current directory if not specified)
outputDir := args.ModelOutputDirectory
if outputDir == "" {
outputDir = "."
}
outputDirAbs, err := filepath.Abs(outputDir)
if err != nil {
return errors.Wrapf(err, "failed to get absolute path for model output directory")
}

// Ensure output directory exists
if err := os.MkdirAll(outputDirAbs, 0o750); err != nil {
return errors.Wrapf(err, "failed to create model output directory")
}

// Create temporary script to run inside container
scriptContent, err := createTrainingScript(args.CustomArgs)
if err != nil {
return err
}

tmpScript, err := os.CreateTemp("", "viam-training-*.sh")
if err != nil {
return errors.Wrap(err, "failed to create temporary script file")
}
//nolint:errcheck
defer os.Remove(tmpScript.Name())

if _, err := tmpScript.WriteString(scriptContent); err != nil {
return errors.Wrap(err, "failed to write to temporary script file")
}
if err := tmpScript.Close(); err != nil {
return errors.Wrap(err, "failed to close temporary script file")
}

//nolint:gosec
if err := os.Chmod(tmpScript.Name(), 0o700); err != nil {
return errors.Wrap(err, "failed to make script executable")
}

// Get container image name
containerImageURI, err := getContainerImageURI(client, args.ContainerVersion)
if err != nil {
return err
}

// Build docker run command
// NOTE: Google Vertex AI training containers are only available for linux/x86_64 (amd64).
// On ARM systems (e.g., Apple Silicon Macs), Docker runs them under emulation (for example
// Rosetta 2 on macOS), which may be slower but keeps the environment identical to cloud training.
dockerArgs := []string{
"run",
"-i", // Interactive mode to ensure signals are properly handled
"--entrypoint", "/bin/bash",
"--platform", "linux/x86_64",
"--rm",
"-v", fmt.Sprintf("%s:/training_script", scriptDirAbs),
"-v", fmt.Sprintf("%s:/dataset.jsonl", datasetFileAbs),
"-v", fmt.Sprintf("%s:/model_output", outputDirAbs),
"-v", fmt.Sprintf("%s:/run_training.sh", tmpScript.Name()),
"-w", "/training_script",
containerImageURI,
"/run_training.sh",
}

// Setup context with signal handling for Ctrl+C
ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
defer stop()

//nolint:gosec
cmd := exec.CommandContext(ctx, "docker", dockerArgs...)
cmd.Stdout = c.App.Writer
cmd.Stderr = c.App.ErrWriter

if err := cmd.Run(); err != nil {
// Check if the command was interrupted
if ctx.Err() == context.Canceled {
printf(c.App.Writer, "\nTraining interrupted by user")
return errors.New("training interrupted")
}

// Provide additional context for platform-related errors
errMsg := err.Error()
if strings.Contains(errMsg, "platform") || strings.Contains(errMsg, "architecture") {
Member Author

This is kind of ugly, but I thought it was worth including. It's also in the command description, so it's potentially removable.
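
An alternative to matching on the error text might be to check the host architecture up front and print the note before the container starts. This is a minimal sketch under assumptions: `emulationNote` and its wording are hypothetical, and `runtime.GOARCH` reflects the architecture the CLI binary was built for, so an amd64 build running under emulation on an ARM host would not trigger it.

```go
package main

import (
	"fmt"
	"runtime"
)

// emulationNote returns a warning for non-amd64 builds and an empty string otherwise.
// Hypothetical helper; the message text is illustrative.
func emulationNote(goarch string) string {
	if goarch == "amd64" {
		return ""
	}
	return fmt.Sprintf("Note: host architecture is %s; the linux/x86_64 training "+
		"container will run under emulation and may be slow", goarch)
}

func main() {
	if note := emulationNote(runtime.GOARCH); note != "" {
		fmt.Println(note)
	}
}
```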

return errors.Wrap(err, "failed to run training in Docker container. "+
"Note: Training containers only support linux/x86_64 (amd64). "+
"On ARM systems, ensure Docker Desktop is configured to enable Rosetta 2 emulation for x86_64 containers")
}

return errors.Wrap(err, "failed to run training in Docker container")
}

return nil
}

// checkDockerAvailable checks if Docker is installed and running.
func checkDockerAvailable() error {
// Create a context with timeout for Docker commands
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()

cmd := exec.CommandContext(ctx, "docker", "--version")
if err := cmd.Run(); err != nil {
if ctx.Err() == context.DeadlineExceeded {
return errors.New("Docker command timed out. Please check if Docker is responding")
}
return errors.New("Docker is not available. Please install Docker and ensure it is running. " +
"Visit https://docs.docker.com/get-docker/ for installation instructions")
}

// Check if Docker daemon is running
ctx, cancel = context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()

cmd = exec.CommandContext(ctx, "docker", "ps")
if err := cmd.Run(); err != nil {
if ctx.Err() == context.DeadlineExceeded {
return errors.New("Docker daemon is not responding. It may be starting up - please wait and try again")
}
return errors.New("Docker daemon is not running. Please start Docker and try again")
}

return nil
}

// validateLocalTrainingPaths validates that required paths exist.
func validateLocalTrainingPaths(scriptDir, datasetFile string) error {
// Check training script directory exists
if _, err := os.Stat(scriptDir); os.IsNotExist(err) {
return errors.Errorf("training script directory does not exist: %s", scriptDir)
}

// Check for required files in training script directory
setupPyPath := filepath.Join(scriptDir, "setup.py")
if _, err := os.Stat(setupPyPath); os.IsNotExist(err) {
return errors.Errorf("setup.py not found in training script directory: %s", scriptDir)
}

trainingPyPath := filepath.Join(scriptDir, "model", "training.py")
if _, err := os.Stat(trainingPyPath); os.IsNotExist(err) {
return errors.Errorf("model/training.py not found in training script directory: %s", scriptDir)
}

// Check dataset file exists
if _, err := os.Stat(datasetFile); os.IsNotExist(err) {
return errors.Errorf("dataset file does not exist: %s", datasetFile)
}

return nil
}

// createTrainingScript creates the shell script content to run inside the container.
func createTrainingScript(customArgs []string) (string, error) {
// Validate custom arguments format (key=value) before building script
for _, arg := range customArgs {
if !strings.Contains(arg, "=") {
return "", errors.Errorf("invalid custom argument format: %s (expected key=value)", arg)
}

// Validate that the key portion only contains safe characters
parts := strings.SplitN(arg, "=", 2)
key := parts[0]
if !isValidArgumentKey(key) {
Member Author

I tried to clean the input, let me know if there's a better way to do this.
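
On the value side, one commonly used option is POSIX single-quote escaping, since nothing is expanded inside single quotes. By contrast, a Go-style double-quoted string still lets bash expand `$` and backticks. The sketch below assumes a hypothetical `shellQuote` helper that is not in this PR.

```go
package main

import (
	"fmt"
	"strings"
)

// shellQuote wraps s in single quotes for POSIX shells. Inside single quotes
// nothing is expanded, so the only character needing care is the single quote
// itself, which is written as '\'' (close quote, escaped quote, reopen quote).
func shellQuote(s string) string {
	return "'" + strings.ReplaceAll(s, "'", `'\''`) + "'"
}

func main() {
	// Print how a few tricky values would appear on the generated command line.
	for _, v := range []string{"plain", "two words", "$HOME", "it's", "`id`"} {
		fmt.Printf("--key=%s\n", shellQuote(v))
	}
}
```

Passing the custom arguments as separate `docker run` arguments instead of writing them into a script would avoid shell quoting altogether, at the cost of restructuring how the script is invoked.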

return "", errors.Errorf("invalid argument key: %s (only alphanumeric characters, underscores, and hyphens are allowed)", key)
}
}

var script strings.Builder

// Script header and setup
script.WriteString("#!/bin/bash\n")
script.WriteString("set -e\n\n")
script.WriteString("echo \"Installing training script package...\"\n")
script.WriteString("pip3 install --no-cache-dir .\n\n")
script.WriteString("echo \"Running training...\"\n")

// Build Python training command
script.WriteString("python3 -m model.training")
script.WriteString(" --dataset_file=/dataset.jsonl")
script.WriteString(" --model_output_directory=/model_output")

// Add custom arguments
for _, arg := range customArgs {
parts := strings.SplitN(arg, "=", 2)
key := parts[0]
value := parts[1]

script.WriteString(" --")
script.WriteString(key)
script.WriteString("=")
// strconv.Quote returns a Go-style double-quoted string. Note that this is not full shell
// quoting: inside double quotes bash still expands $ and backticks in the value.
script.WriteString(strconv.Quote(value))
}
script.WriteString("\n\n")

// Script footer
script.WriteString("echo \"Training completed successfully!\"\n")

return script.String(), nil
}

// isValidArgumentKey validates that an argument key only contains safe characters.
// Allowed characters: letters (a-z, A-Z), digits (0-9), underscores (_), and hyphens (-).
func isValidArgumentKey(key string) bool {
if key == "" {
return false
}

for _, char := range key {
if !((char >= 'a' && char <= 'z') ||
(char >= 'A' && char <= 'Z') ||
(char >= '0' && char <= '9') ||
char == '_' ||
char == '-') {
return false
}
}

return true
}

// getContainerImageURI returns the full container image URI based on the version.
func getContainerImageURI(c *viamClient, version string) (string, error) {
res, err := c.mlTrainingClient.ListSupportedContainers(context.Background(), &mltrainingpb.ListSupportedContainersRequest{})
if err != nil {
return "", errors.Wrapf(err, "failed to list supported containers")
}

containerKeyList := []string{}
for key := range res.ContainerMap {
containerKeyList = append(containerKeyList, key)
}
slices.Sort(containerKeyList)

container, ok := res.ContainerMap[version]
if !ok {
if dockerVertexImageRegex.MatchString(version) {
return version, nil
}
return "", errors.Errorf("container version %s not found. Supported versions: %s", version, strings.Join(containerKeyList, ", "))
}
return container.Uri, nil
}
2 changes: 1 addition & 1 deletion go.mod
@@ -84,7 +84,7 @@ require (
go.uber.org/atomic v1.11.0
go.uber.org/multierr v1.11.0
go.uber.org/zap v1.27.0
- go.viam.com/api v0.1.487
+ go.viam.com/api v0.1.494
go.viam.com/test v1.2.4
go.viam.com/utils v0.1.176
goji.io v2.0.2+incompatible
4 changes: 2 additions & 2 deletions go.sum
@@ -1277,8 +1277,8 @@ go.uber.org/zap v1.18.1/go.mod h1:xg/QME4nWcxGxrpdeYfq7UvYrLh66cuVKdrbD1XF/NI=
go.uber.org/zap v1.23.0/go.mod h1:D+nX8jyLsMHMYrln8A0rJjFt/T/9/bGgIhAqxv5URuY=
go.uber.org/zap v1.27.0 h1:aJMhYGrd5QSmlpLMr2MftRKl7t8J8PTZPA732ud/XR8=
go.uber.org/zap v1.27.0/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E=
- go.viam.com/api v0.1.487 h1:KdX0vQzZ6j/3YgUX8B98a3QgQ1e3AxD/RkN+GFvmJhg=
- go.viam.com/api v0.1.487/go.mod h1:p/am76zx8SZ74V/F4rEAYQIpHaaLUwJgY2q3Uw3FIWk=
+ go.viam.com/api v0.1.494 h1:GcMyqmN9WuwPL59RPmAFqJ9yed0R6mAmIQblD3eWGgk=
+ go.viam.com/api v0.1.494/go.mod h1:p/am76zx8SZ74V/F4rEAYQIpHaaLUwJgY2q3Uw3FIWk=
go.viam.com/test v1.2.4 h1:JYgZhsuGAQ8sL9jWkziAXN9VJJiKbjoi9BsO33TW3ug=
go.viam.com/test v1.2.4/go.mod h1:zI2xzosHdqXAJ/kFqcN+OIF78kQuTV2nIhGZ8EzvaJI=
go.viam.com/utils v0.1.176 h1:I5TvnuBZtE9i3e6j1VOBzWQ6W3nlUiOR/L4WpTMFhxg=