-
Notifications
You must be signed in to change notification settings - Fork 126
APP-10775: viam training-script test-local #5524
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 3 commits
fdff64b
fb04bd4
214911d
9c22fec
c76a430
a8f8f8d
3f4da57
e28a0d1
753882e
390ce70
704f696
1fad4ba
001c886
994b80e
c9d36d7
b6efdeb
916fa04
4bc5312
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -3,8 +3,15 @@ package cli | |
| import ( | ||
| "context" | ||
| "fmt" | ||
| "os" | ||
| "os/exec" | ||
| "os/signal" | ||
| "path/filepath" | ||
| "regexp" | ||
| "slices" | ||
| "strconv" | ||
| "strings" | ||
| "syscall" | ||
| "time" | ||
|
|
||
| "github.com/pkg/errors" | ||
|
|
@@ -27,8 +34,17 @@ const ( | |
| trainFlagModelLabels = "model-labels" | ||
|
|
||
| trainingStatusPrefix = "TRAINING_STATUS_" | ||
|
|
||
| // Flags for test-local command | ||
| trainFlagContainerVersion = "container-version" | ||
| trainFlagDatasetFile = "dataset-file" | ||
| trainFlagModelOutputDirectory = "model-output-directory" | ||
| trainFlagCustomArgs = "custom-args" | ||
| trainFlagTrainingScriptDirectory = "training-script-directory" | ||
| ) | ||
|
|
||
| var dockerVertexImageRegex = regexp.MustCompile(`^us-docker\.pkg\.dev/vertex-ai/training/[^:@\s]+(:[^@\s]+)?(@sha256:[A-Fa-f0-9]{64})?$`) | ||
|
|
||
| type mlSubmitCustomTrainingJobArgs struct { | ||
| DatasetID string | ||
| OrgID string | ||
|
|
@@ -597,3 +613,283 @@ func convertVisibilityToProto(visibility string) (*v1.Visibility, error) { | |
|
|
||
| return &visibilityProto, nil | ||
| } | ||
|
|
||
| type mlTrainingScriptTestLocalArgs struct { | ||
| ContainerVersion string | ||
| DatasetFile string | ||
| ModelOutputDirectory string | ||
| TrainingScriptDirectory string | ||
| CustomArgs []string | ||
| } | ||
|
|
||
| // MLTrainingScriptTestLocalAction runs training locally in a Docker container. | ||
| func MLTrainingScriptTestLocalAction(c *cli.Context, args mlTrainingScriptTestLocalArgs) error { | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This function is probably gonna take awhile. I think maybe we should add info logging all over the place so people know that it's working.
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The only part of this function that takes a significant amount of time is actually running the docker command to run training at the bottom and the logs from the docker container show up in the terminal. I also added timeouts for checkDockerAvailable so I think that should be fine? I'd suggest running this command, sorry I know it's annoying.
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The issue here is that the container is really really really big. In the office downloading this can take 30 minutes. I know that it logs the docker stuff but giving people a heads up somewhere would be helpful.
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. oop that is a really good point, thanks! |
||
| client, err := newViamClient(c) | ||
| if err != nil { | ||
| return err | ||
| } | ||
|
|
||
| // Check if Docker is available | ||
| if err := checkDockerAvailable(); err != nil { | ||
| return err | ||
| } | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. For example, just adding a "Docker is available. Using image BLAH" |
||
|
|
||
| // Validate required paths exist | ||
| if err := validateLocalTrainingPaths(args.TrainingScriptDirectory, args.DatasetFile); err != nil { | ||
| return err | ||
| } | ||
| // Get absolute paths for volume mounting | ||
| scriptDirAbs, err := filepath.Abs(args.TrainingScriptDirectory) | ||
| if err != nil { | ||
| return errors.Wrapf(err, "failed to get absolute path for training script directory") | ||
| } | ||
|
|
||
| datasetFileAbs, err := filepath.Abs(args.DatasetFile) | ||
| if err != nil { | ||
| return errors.Wrapf(err, "failed to get absolute path for dataset file") | ||
| } | ||
|
|
||
| // Determine output directory (default to current directory if not specified) | ||
| outputDir := args.ModelOutputDirectory | ||
| if outputDir == "" { | ||
| outputDir = "." | ||
| } | ||
| outputDirAbs, err := filepath.Abs(outputDir) | ||
| if err != nil { | ||
| return errors.Wrapf(err, "failed to get absolute path for model output directory") | ||
| } | ||
|
|
||
| // Ensure output directory exists | ||
| if err := os.MkdirAll(outputDirAbs, 0o750); err != nil { | ||
| return errors.Wrapf(err, "failed to create model output directory") | ||
| } | ||
|
|
||
| // Create temporary script to run inside container | ||
| scriptContent, err := createTrainingScript(args.CustomArgs) | ||
| if err != nil { | ||
| return err | ||
| } | ||
|
|
||
| tmpScript, err := os.CreateTemp("", "viam-training-*.sh") | ||
| if err != nil { | ||
| return errors.Wrap(err, "failed to create temporary script file") | ||
| } | ||
| //nolint:errcheck | ||
| defer os.Remove(tmpScript.Name()) | ||
|
|
||
| if _, err := tmpScript.WriteString(scriptContent); err != nil { | ||
| return errors.Wrap(err, "failed to write to temporary script file") | ||
| } | ||
| if err := tmpScript.Close(); err != nil { | ||
| return errors.Wrap(err, "failed to close temporary script file") | ||
| } | ||
|
|
||
| //nolint:gosec | ||
| if err := os.Chmod(tmpScript.Name(), 0o700); err != nil { | ||
| return errors.Wrap(err, "failed to make script executable") | ||
| } | ||
|
|
||
| // Get container image name | ||
| containerImageURI, err := getContainerImageURI(client, args.ContainerVersion) | ||
| if err != nil { | ||
| return err | ||
| } | ||
|
|
||
| // Build docker run command | ||
| // NOTE: Google Vertex AI training containers are only available for linux/x86_64 (amd64). | ||
| // On ARM systems (e.g., Apple Silicon Macs), Docker will use Rosetta 2 emulation which | ||
| // may be slower but ensures compatibility with the same containers used in cloud training. | ||
| dockerArgs := []string{ | ||
| "run", | ||
| "-i", // Interactive mode to ensure signals are properly handled | ||
| "--entrypoint", "/bin/bash", | ||
| "--platform", "linux/x86_64", | ||
| "--rm", | ||
| "-v", fmt.Sprintf("%s:/training_script", scriptDirAbs), | ||
| "-v", fmt.Sprintf("%s:/dataset.jsonl", datasetFileAbs), | ||
| "-v", fmt.Sprintf("%s:/model_output", outputDirAbs), | ||
| "-v", fmt.Sprintf("%s:/run_training.sh", tmpScript.Name()), | ||
| "-w", "/training_script", | ||
| containerImageURI, | ||
| "/run_training.sh", | ||
| } | ||
|
|
||
| // Setup context with signal handling for Ctrl+C | ||
| ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) | ||
| defer stop() | ||
|
|
||
| //nolint:gosec | ||
| cmd := exec.CommandContext(ctx, "docker", dockerArgs...) | ||
| cmd.Stdout = c.App.Writer | ||
| cmd.Stderr = c.App.ErrWriter | ||
|
|
||
| if err := cmd.Run(); err != nil { | ||
| // Check if the command was interrupted | ||
| if ctx.Err() == context.Canceled { | ||
| printf(c.App.Writer, "\nTraining interrupted by user") | ||
| return errors.New("training interrupted") | ||
| } | ||
|
|
||
| // Provide additional context for platform-related errors | ||
| errMsg := err.Error() | ||
| if strings.Contains(errMsg, "platform") || strings.Contains(errMsg, "architecture") { | ||
|
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this is kind of ugly, but I thought it was worth including. it's also in the command description so potentially removable. |
||
| return errors.Wrap(err, "failed to run training in Docker container. "+ | ||
| "Note: Training containers only support linux/x86_64 (amd64). "+ | ||
| "On ARM systems, ensure Docker Desktop is configured to enable Rosetta 2 emulation for x86_64 containers") | ||
| } | ||
|
|
||
| return errors.Wrap(err, "failed to run training in Docker container") | ||
| } | ||
|
|
||
etai-shuchatowitz marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| return nil | ||
| } | ||
|
|
||
| // checkDockerAvailable checks if Docker is installed and running. | ||
| func checkDockerAvailable() error { | ||
| // Create a context with timeout for Docker commands | ||
| ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) | ||
| defer cancel() | ||
|
|
||
| cmd := exec.CommandContext(ctx, "docker", "--version") | ||
| if err := cmd.Run(); err != nil { | ||
| if ctx.Err() == context.DeadlineExceeded { | ||
| return errors.New("Docker command timed out. Please check if Docker is responding") | ||
| } | ||
| return errors.New("Docker is not available. Please install Docker and ensure it is running. " + | ||
| "Visit https://docs.docker.com/get-docker/ for installation instructions") | ||
| } | ||
|
|
||
| // Check if Docker daemon is running | ||
| ctx, cancel = context.WithTimeout(context.Background(), 10*time.Second) | ||
| defer cancel() | ||
|
|
||
| cmd = exec.CommandContext(ctx, "docker", "ps") | ||
| if err := cmd.Run(); err != nil { | ||
| if ctx.Err() == context.DeadlineExceeded { | ||
| return errors.New("Docker daemon is not responding. It may be starting up - please wait and try again") | ||
| } | ||
| return errors.New("Docker daemon is not running. Please start Docker and try again") | ||
| } | ||
|
|
||
| return nil | ||
| } | ||
|
|
||
| // validateLocalTrainingPaths validates that required paths exist. | ||
| func validateLocalTrainingPaths(scriptDir, datasetFile string) error { | ||
| // Check training script directory exists | ||
| if _, err := os.Stat(scriptDir); os.IsNotExist(err) { | ||
| return errors.Errorf("training script directory does not exist: %s", scriptDir) | ||
| } | ||
|
|
||
| // Check for required files in training script directory | ||
| setupPyPath := filepath.Join(scriptDir, "setup.py") | ||
| if _, err := os.Stat(setupPyPath); os.IsNotExist(err) { | ||
| return errors.Errorf("setup.py not found in training script directory: %s", scriptDir) | ||
| } | ||
|
|
||
| trainingPyPath := filepath.Join(scriptDir, "model", "training.py") | ||
| if _, err := os.Stat(trainingPyPath); os.IsNotExist(err) { | ||
| return errors.Errorf("model/training.py not found in training script directory: %s", scriptDir) | ||
| } | ||
|
|
||
| // Check dataset file exists | ||
| if _, err := os.Stat(datasetFile); os.IsNotExist(err) { | ||
| return errors.Errorf("dataset file does not exist: %s", datasetFile) | ||
| } | ||
|
|
||
| return nil | ||
| } | ||
|
|
||
| // createTrainingScript creates the shell script content to run inside the container. | ||
| func createTrainingScript(customArgs []string) (string, error) { | ||
| // Validate custom arguments format (key=value) before building script | ||
| for _, arg := range customArgs { | ||
| if !strings.Contains(arg, "=") { | ||
| return "", errors.Errorf("invalid custom argument format: %s (expected key=value)", arg) | ||
| } | ||
|
|
||
| // Validate that the key portion only contains safe characters | ||
| parts := strings.SplitN(arg, "=", 2) | ||
| key := parts[0] | ||
| if !isValidArgumentKey(key) { | ||
|
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I tried to clean the input, let me know if there's a better way to do this. |
||
| return "", errors.Errorf("invalid argument key: %s (only alphanumeric characters, underscores, and hyphens are allowed)", key) | ||
| } | ||
| } | ||
|
|
||
| var script strings.Builder | ||
|
|
||
| // Script header and setup | ||
| script.WriteString("#!/bin/bash\n") | ||
| script.WriteString("set -e\n\n") | ||
| script.WriteString("echo \"Installing training script package...\"\n") | ||
etai-shuchatowitz marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| script.WriteString("pip3 install --no-cache-dir .\n\n") | ||
| script.WriteString("echo \"Running training...\"\n") | ||
|
|
||
| // Build Python training command | ||
| script.WriteString("python3 -m model.training") | ||
| script.WriteString(" --dataset_file=/dataset.jsonl") | ||
| script.WriteString(" --model_output_directory=/model_output") | ||
|
|
||
| // Add custom arguments | ||
| for _, arg := range customArgs { | ||
| parts := strings.SplitN(arg, "=", 2) | ||
| key := parts[0] | ||
| value := parts[1] | ||
|
|
||
| script.WriteString(" --") | ||
| script.WriteString(key) | ||
| script.WriteString("=") | ||
| // Use strconv.Quote for proper shell quoting - it returns a double-quoted string | ||
| // with all special characters properly escaped | ||
| script.WriteString(strconv.Quote(value)) | ||
| } | ||
| script.WriteString("\n\n") | ||
|
|
||
| // Script footer | ||
| script.WriteString("echo \"Training completed successfully!\"\n") | ||
|
|
||
| return script.String(), nil | ||
| } | ||
|
|
||
| // isValidArgumentKey validates that an argument key only contains safe characters. | ||
| // Allowed characters: letters (a-z, A-Z), digits (0-9), underscores (_), and hyphens (-). | ||
| func isValidArgumentKey(key string) bool { | ||
| if key == "" { | ||
| return false | ||
| } | ||
|
|
||
| for _, char := range key { | ||
| if !((char >= 'a' && char <= 'z') || | ||
| (char >= 'A' && char <= 'Z') || | ||
| (char >= '0' && char <= '9') || | ||
| char == '_' || | ||
| char == '-') { | ||
| return false | ||
| } | ||
| } | ||
|
|
||
| return true | ||
| } | ||
|
|
||
| // getContainerImageURI returns the full container image URI based on the version. | ||
| func getContainerImageURI(c *viamClient, version string) (string, error) { | ||
| res, err := c.mlTrainingClient.ListSupportedContainers(context.Background(), &mltrainingpb.ListSupportedContainersRequest{}) | ||
| if err != nil { | ||
| return "", errors.Wrapf(err, "failed to list supported containers") | ||
| } | ||
|
|
||
| containerKeyList := []string{} | ||
| for key := range res.ContainerMap { | ||
| containerKeyList = append(containerKeyList, key) | ||
| } | ||
| slices.Sort(containerKeyList) | ||
|
|
||
| container, ok := res.ContainerMap[version] | ||
| if !ok { | ||
| if dockerVertexImageRegex.MatchString(version) { | ||
etai-shuchatowitz marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| return version, nil | ||
| } | ||
| return "", errors.Errorf("container version %s not found. Supported versions: %s", version, strings.Join(containerKeyList, ", ")) | ||
| } | ||
| return container.Uri, nil | ||
| } | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe this goes in documentation? I also want to add that if the containers really slow, it could be because the dataset root or training script directory has a bunch of extra files. Apparently mounting volumes can be expensive.