Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
first pass
  • Loading branch information
vpandiarajan20 committed Nov 21, 2025
commit fdff64b5cc05f76d64c91d017bc19c14f91c157c
33 changes: 33 additions & 0 deletions cli/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -3398,6 +3398,39 @@ This won't work unless you have an existing installation of our GitHub app on yo
},
Action: createCommandWithT[mlTrainingUpdateArgs](MLTrainingUpdateAction),
},
{
Name: "test-local",
Usage: "test training script locally using Docker",
UsageText: createUsageText("training-script test-local", []string{trainFlagDatasetFile, trainFlagTrainingScriptDirectory,
trainFlagContainerVersion, trainFlagModelOutputDirectory}, true, false),
Flags: []cli.Flag{
&cli.StringFlag{
Name: trainFlagDatasetFile,
Usage: "path to the dataset file (JSONL format)",
Required: true,
},
&cli.StringFlag{
Name: trainFlagTrainingScriptDirectory,
Usage: "path to the training script directory (must contain setup.py and model/training.py), the container will be mounted to this directory",
Required: true,
},
&cli.StringFlag{
Name: trainFlagContainerVersion,
Usage: "container version to use (e.g., 'tf:2.16'). Defaults to tf:2.16",
Value: "tf:2.16",
},
&cli.StringFlag{
Name: trainFlagModelOutputDirectory,
Usage: "directory where the trained model will be saved. Defaults to current directory",
Value: ".",
},
&cli.StringSliceFlag{
Name: trainFlagCustomArgs,
Usage: "custom arguments to pass to the training script (format: key=value)",
},
},
Action: createCommandWithT[mlTrainingScriptTestLocalArgs](MLTrainingScriptTestLocalAction),
},
},
},
{
Expand Down
248 changes: 248 additions & 0 deletions cli/ml_training.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,14 @@ package cli
import (
"context"
"fmt"
"os"
"os/exec"
"os/signal"
"path/filepath"
"regexp"
"slices"
"strings"
"syscall"
"time"

"github.com/pkg/errors"
Expand All @@ -27,8 +33,17 @@ const (
trainFlagModelLabels = "model-labels"

trainingStatusPrefix = "TRAINING_STATUS_"

// Flags for test-local command
trainFlagContainerVersion = "container-version"
trainFlagDatasetFile = "dataset-file"
trainFlagModelOutputDirectory = "model-output-directory"
trainFlagCustomArgs = "custom-args"
trainFlagTrainingScriptDirectory = "training-script-directory"
)

var dockerVertexImageRegex = regexp.MustCompile(`^us-docker\.pkg\.dev/vertex-ai/training/[^:@\s]+(:[^@\s]+)?(@sha256:[A-Fa-f0-9]{64})?$`)

type mlSubmitCustomTrainingJobArgs struct {
DatasetID string
OrgID string
Expand Down Expand Up @@ -597,3 +612,236 @@ func convertVisibilityToProto(visibility string) (*v1.Visibility, error) {

return &visibilityProto, nil
}

type mlTrainingScriptTestLocalArgs struct {
ContainerVersion string
DatasetFile string
ModelOutputDirectory string
TrainingScriptDirectory string
CustomArgs []string
}

// MLTrainingScriptTestLocalAction runs training locally in a Docker container.
func MLTrainingScriptTestLocalAction(c *cli.Context, args mlTrainingScriptTestLocalArgs) error {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This function is probably gonna take awhile. I think maybe we should add info logging all over the place so people know that it's working.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The only part of this function that takes a significant amount of time is actually running the docker command to run training at the bottom and the logs from the docker container show up in the terminal. I also added timeouts for checkDockerAvailable so I think that should be fine? I'd suggest running this command, sorry I know it's annoying.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The issue here is that the container is really really really big. In the office downloading this can take 30 minutes. I know that it logs the docker stuff but giving people a heads up somewhere would be helpful.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oop that is a really good point, thanks!

client, err := newViamClient(c)
if err != nil {
return err
}

// Check if Docker is available
if err := checkDockerAvailable(); err != nil {
return err
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For example, just adding a "Docker is available. Using image BLAH"


// Validate required paths exist
if err := validateLocalTrainingPaths(args.TrainingScriptDirectory, args.DatasetFile); err != nil {
return err
}
// Get absolute paths for volume mounting
scriptDirAbs, err := filepath.Abs(args.TrainingScriptDirectory)
if err != nil {
return errors.Wrapf(err, "failed to get absolute path for training script directory")
}

datasetFileAbs, err := filepath.Abs(args.DatasetFile)
if err != nil {
return errors.Wrapf(err, "failed to get absolute path for dataset file")
}

// Determine output directory (default to current directory if not specified)
outputDir := args.ModelOutputDirectory
if outputDir == "" {
outputDir = "."
}
outputDirAbs, err := filepath.Abs(outputDir)
if err != nil {
return errors.Wrapf(err, "failed to get absolute path for model output directory")
}

// Ensure output directory exists
if err := os.MkdirAll(outputDirAbs, 0o755); err != nil {
return errors.Wrapf(err, "failed to create model output directory")
}

// Create temporary script to run inside container
scriptContent, err := createTrainingScript(args.CustomArgs, datasetFileAbs)
if err != nil {
return err
}

tmpScript, err := os.CreateTemp("", "viam-training-*.sh")
if err != nil {
return errors.Wrap(err, "failed to create temporary script file")
}
defer os.Remove(tmpScript.Name())

if _, err := tmpScript.WriteString(scriptContent); err != nil {
return errors.Wrap(err, "failed to write to temporary script file")
}
if err := tmpScript.Close(); err != nil {
return errors.Wrap(err, "failed to close temporary script file")
}

// Make script executable
if err := os.Chmod(tmpScript.Name(), 0o755); err != nil {
return errors.Wrap(err, "failed to make script executable")
}

// Get container image name
containerImageURI, err := getContainerImageURI(client, args.ContainerVersion)
if err != nil {
return err
}

// Build docker run command
dockerArgs := []string{
"run",
"-i", // Interactive mode to ensure signals are properly handled
"--entrypoint", "/bin/bash",
"--platform", "linux/x86_64",
"--rm",
"-v", fmt.Sprintf("%s:/training_script", scriptDirAbs),
"-v", fmt.Sprintf("%s:/dataset.jsonl", datasetFileAbs),
"-v", fmt.Sprintf("%s:/model_output", outputDirAbs),
"-v", fmt.Sprintf("%s:/run_training.sh", tmpScript.Name()),
"-w", "/training_script",
containerImageURI,
"/run_training.sh",
}

// Setup context with signal handling for Ctrl+C
ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
defer stop()

//nolint:gosec
cmd := exec.CommandContext(ctx, "docker", dockerArgs...)
cmd.Stdout = c.App.Writer
cmd.Stderr = c.App.ErrWriter

if err := cmd.Run(); err != nil {
// Check if the command was interrupted
if ctx.Err() == context.Canceled {
printf(c.App.Writer, "\nTraining interrupted by user")
return errors.New("training interrupted")
}
return errors.Wrap(err, "failed to run training in Docker container")
}

return nil
}

// checkDockerAvailable checks if Docker is installed and running.
func checkDockerAvailable() error {
// Create a context with timeout for Docker commands
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()

cmd := exec.CommandContext(ctx, "docker", "--version")
if err := cmd.Run(); err != nil {
if ctx.Err() == context.DeadlineExceeded {
return errors.New("Docker command timed out. Please check if Docker is responding")
}
return errors.New("Docker is not available. Please install Docker and ensure it is running. " +
"Visit https://docs.docker.com/get-docker/ for installation instructions")
}

// Check if Docker daemon is running
ctx, cancel = context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()

cmd = exec.CommandContext(ctx, "docker", "ps")
if err := cmd.Run(); err != nil {
if ctx.Err() == context.DeadlineExceeded {
return errors.New("Docker daemon is not responding. It may be starting up - please wait and try again")
}
return errors.New("Docker daemon is not running. Please start Docker and try again")
}

return nil
}

// validateLocalTrainingPaths validates that required paths exist.
func validateLocalTrainingPaths(scriptDir, datasetFile string) error {
// Check training script directory exists
if _, err := os.Stat(scriptDir); os.IsNotExist(err) {
return errors.Errorf("training script directory does not exist: %s", scriptDir)
}

// Check for required files in training script directory
setupPyPath := filepath.Join(scriptDir, "setup.py")
if _, err := os.Stat(setupPyPath); os.IsNotExist(err) {
return errors.Errorf("setup.py not found in training script directory: %s", scriptDir)
}

trainingPyPath := filepath.Join(scriptDir, "model", "training.py")
if _, err := os.Stat(trainingPyPath); os.IsNotExist(err) {
return errors.Errorf("model/training.py not found in training script directory: %s", scriptDir)
}

// Check dataset file exists
if _, err := os.Stat(datasetFile); os.IsNotExist(err) {
return errors.Errorf("dataset file does not exist: %s", datasetFile)
}

return nil
}

// createTrainingScript creates the shell script content to run inside the container.
func createTrainingScript(customArgs []string, datasetPath string) (string, error) {
script := `#!/bin/bash
set -e

echo "Installing training script package..."
pip3 install --no-cache-dir .

echo "Running training..."
`

// Build the python command with arguments
pythonCmd := "python3 -m model.training"

// Add dataset file argument
pythonCmd += " --dataset_file=/dataset.jsonl"

// Add model output directory argument
pythonCmd += " --model_output_directory=/model_output"

// Add custom arguments
for _, arg := range customArgs {
// Validate argument format (should be key=value)
if !strings.Contains(arg, "=") {
return "", errors.Errorf("invalid custom argument format: %s (expected key=value)", arg)
}
pythonCmd += " --" + arg
}

script += pythonCmd + "\n"
script += `
echo "Training completed successfully!"
`

return script, nil
}

// getContainerImageURI returns the full container image URI based on the version.
func getContainerImageURI(c *viamClient, version string) (string, error) {
res, err := c.mlTrainingClient.ListSupportedContainers(context.Background(), &mltrainingpb.ListSupportedContainersRequest{})
if err != nil {
return "", errors.Wrapf(err, "failed to list supported containers")
}

containerKeyList := []string{}
for key := range res.ContainerMap {
containerKeyList = append(containerKeyList, key)
}
slices.Sort(containerKeyList)

container, ok := res.ContainerMap[version]
if !ok {
if dockerVertexImageRegex.MatchString(version) {
return version, nil
}
return "", errors.Errorf("container version %s not found. Supported versions: %s", version, strings.Join(containerKeyList, ", "))
}
return container.Uri, nil
}
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ require (
go.uber.org/atomic v1.11.0
go.uber.org/multierr v1.11.0
go.uber.org/zap v1.27.0
go.viam.com/api v0.1.487
go.viam.com/api v0.1.494
go.viam.com/test v1.2.4
go.viam.com/utils v0.1.176
goji.io v2.0.2+incompatible
Expand Down
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -1277,8 +1277,8 @@ go.uber.org/zap v1.18.1/go.mod h1:xg/QME4nWcxGxrpdeYfq7UvYrLh66cuVKdrbD1XF/NI=
go.uber.org/zap v1.23.0/go.mod h1:D+nX8jyLsMHMYrln8A0rJjFt/T/9/bGgIhAqxv5URuY=
go.uber.org/zap v1.27.0 h1:aJMhYGrd5QSmlpLMr2MftRKl7t8J8PTZPA732ud/XR8=
go.uber.org/zap v1.27.0/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E=
go.viam.com/api v0.1.487 h1:KdX0vQzZ6j/3YgUX8B98a3QgQ1e3AxD/RkN+GFvmJhg=
go.viam.com/api v0.1.487/go.mod h1:p/am76zx8SZ74V/F4rEAYQIpHaaLUwJgY2q3Uw3FIWk=
go.viam.com/api v0.1.494 h1:GcMyqmN9WuwPL59RPmAFqJ9yed0R6mAmIQblD3eWGgk=
go.viam.com/api v0.1.494/go.mod h1:p/am76zx8SZ74V/F4rEAYQIpHaaLUwJgY2q3Uw3FIWk=
go.viam.com/test v1.2.4 h1:JYgZhsuGAQ8sL9jWkziAXN9VJJiKbjoi9BsO33TW3ug=
go.viam.com/test v1.2.4/go.mod h1:zI2xzosHdqXAJ/kFqcN+OIF78kQuTV2nIhGZ8EzvaJI=
go.viam.com/utils v0.1.176 h1:I5TvnuBZtE9i3e6j1VOBzWQ6W3nlUiOR/L4WpTMFhxg=
Expand Down
Loading