-
Notifications
You must be signed in to change notification settings - Fork 126
APP-10775: viam training-script test-local #5524
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 1 commit
Commits
Show all changes
18 commits
Select commit
Hold shift + click to select a range
fdff64b
first pass
vpandiarajan20 fb04bd4
security and linting
vpandiarajan20 214911d
descrip text
vpandiarajan20 9c22fec
Merge branch 'viamrobotics:main' into APP-10775
vpandiarajan20 c76a430
dataset paths
vpandiarajan20 a8f8f8d
CLI help text
vpandiarajan20 3f4da57
lint
vpandiarajan20 e28a0d1
clean up
vpandiarajan20 753882e
Merge branch 'main' into APP-10775
vpandiarajan20 390ce70
Merge branch 'main' into APP-10775
vpandiarajan20 704f696
warning message
vpandiarajan20 1fad4ba
warning message 2
vpandiarajan20 001c886
change help text
vpandiarajan20 994b80e
remove validation
vpandiarajan20 c9d36d7
lint
vpandiarajan20 b6efdeb
to make windows work
vpandiarajan20 916fa04
Merge branch 'viamrobotics:main' into APP-10775
vpandiarajan20 4bc5312
help text
vpandiarajan20 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next
Next commit
first pass
- Loading branch information
commit fdff64b5cc05f76d64c91d017bc19c14f91c157c
Some comments aren't visible on the classic Files Changed page.
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -3,8 +3,14 @@ package cli | |
| import ( | ||
| "context" | ||
| "fmt" | ||
| "os" | ||
| "os/exec" | ||
| "os/signal" | ||
| "path/filepath" | ||
| "regexp" | ||
| "slices" | ||
| "strings" | ||
| "syscall" | ||
| "time" | ||
|
|
||
| "github.com/pkg/errors" | ||
|
|
@@ -27,8 +33,17 @@ const ( | |
| trainFlagModelLabels = "model-labels" | ||
|
|
||
| trainingStatusPrefix = "TRAINING_STATUS_" | ||
|
|
||
| // Flags for test-local command | ||
| trainFlagContainerVersion = "container-version" | ||
| trainFlagDatasetFile = "dataset-file" | ||
| trainFlagModelOutputDirectory = "model-output-directory" | ||
| trainFlagCustomArgs = "custom-args" | ||
| trainFlagTrainingScriptDirectory = "training-script-directory" | ||
| ) | ||
|
|
||
| var dockerVertexImageRegex = regexp.MustCompile(`^us-docker\.pkg\.dev/vertex-ai/training/[^:@\s]+(:[^@\s]+)?(@sha256:[A-Fa-f0-9]{64})?$`) | ||
|
|
||
| type mlSubmitCustomTrainingJobArgs struct { | ||
| DatasetID string | ||
| OrgID string | ||
|
|
@@ -597,3 +612,236 @@ func convertVisibilityToProto(visibility string) (*v1.Visibility, error) { | |
|
|
||
| return &visibilityProto, nil | ||
| } | ||
|
|
||
| type mlTrainingScriptTestLocalArgs struct { | ||
| ContainerVersion string | ||
| DatasetFile string | ||
| ModelOutputDirectory string | ||
| TrainingScriptDirectory string | ||
| CustomArgs []string | ||
| } | ||
|
|
||
| // MLTrainingScriptTestLocalAction runs training locally in a Docker container. | ||
| func MLTrainingScriptTestLocalAction(c *cli.Context, args mlTrainingScriptTestLocalArgs) error { | ||
| client, err := newViamClient(c) | ||
| if err != nil { | ||
| return err | ||
| } | ||
|
|
||
| // Check if Docker is available | ||
| if err := checkDockerAvailable(); err != nil { | ||
| return err | ||
| } | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. For example, just adding a "Docker is available. Using image BLAH" |
||
|
|
||
| // Validate required paths exist | ||
| if err := validateLocalTrainingPaths(args.TrainingScriptDirectory, args.DatasetFile); err != nil { | ||
| return err | ||
| } | ||
| // Get absolute paths for volume mounting | ||
| scriptDirAbs, err := filepath.Abs(args.TrainingScriptDirectory) | ||
| if err != nil { | ||
| return errors.Wrapf(err, "failed to get absolute path for training script directory") | ||
| } | ||
|
|
||
| datasetFileAbs, err := filepath.Abs(args.DatasetFile) | ||
| if err != nil { | ||
| return errors.Wrapf(err, "failed to get absolute path for dataset file") | ||
| } | ||
|
|
||
| // Determine output directory (default to current directory if not specified) | ||
| outputDir := args.ModelOutputDirectory | ||
| if outputDir == "" { | ||
| outputDir = "." | ||
| } | ||
| outputDirAbs, err := filepath.Abs(outputDir) | ||
| if err != nil { | ||
| return errors.Wrapf(err, "failed to get absolute path for model output directory") | ||
| } | ||
|
|
||
| // Ensure output directory exists | ||
| if err := os.MkdirAll(outputDirAbs, 0o755); err != nil { | ||
| return errors.Wrapf(err, "failed to create model output directory") | ||
| } | ||
|
|
||
| // Create temporary script to run inside container | ||
| scriptContent, err := createTrainingScript(args.CustomArgs, datasetFileAbs) | ||
| if err != nil { | ||
| return err | ||
| } | ||
|
|
||
| tmpScript, err := os.CreateTemp("", "viam-training-*.sh") | ||
| if err != nil { | ||
| return errors.Wrap(err, "failed to create temporary script file") | ||
| } | ||
| defer os.Remove(tmpScript.Name()) | ||
|
|
||
| if _, err := tmpScript.WriteString(scriptContent); err != nil { | ||
| return errors.Wrap(err, "failed to write to temporary script file") | ||
| } | ||
| if err := tmpScript.Close(); err != nil { | ||
| return errors.Wrap(err, "failed to close temporary script file") | ||
| } | ||
|
|
||
| // Make script executable | ||
| if err := os.Chmod(tmpScript.Name(), 0o755); err != nil { | ||
| return errors.Wrap(err, "failed to make script executable") | ||
| } | ||
|
|
||
| // Get container image name | ||
| containerImageURI, err := getContainerImageURI(client, args.ContainerVersion) | ||
| if err != nil { | ||
| return err | ||
| } | ||
|
|
||
| // Build docker run command | ||
| dockerArgs := []string{ | ||
| "run", | ||
| "-i", // Interactive mode to ensure signals are properly handled | ||
| "--entrypoint", "/bin/bash", | ||
| "--platform", "linux/x86_64", | ||
| "--rm", | ||
| "-v", fmt.Sprintf("%s:/training_script", scriptDirAbs), | ||
| "-v", fmt.Sprintf("%s:/dataset.jsonl", datasetFileAbs), | ||
| "-v", fmt.Sprintf("%s:/model_output", outputDirAbs), | ||
| "-v", fmt.Sprintf("%s:/run_training.sh", tmpScript.Name()), | ||
| "-w", "/training_script", | ||
| containerImageURI, | ||
| "/run_training.sh", | ||
| } | ||
|
|
||
| // Setup context with signal handling for Ctrl+C | ||
| ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) | ||
| defer stop() | ||
|
|
||
| //nolint:gosec | ||
| cmd := exec.CommandContext(ctx, "docker", dockerArgs...) | ||
| cmd.Stdout = c.App.Writer | ||
| cmd.Stderr = c.App.ErrWriter | ||
|
|
||
| if err := cmd.Run(); err != nil { | ||
| // Check if the command was interrupted | ||
| if ctx.Err() == context.Canceled { | ||
| printf(c.App.Writer, "\nTraining interrupted by user") | ||
| return errors.New("training interrupted") | ||
| } | ||
| return errors.Wrap(err, "failed to run training in Docker container") | ||
| } | ||
|
|
||
etai-shuchatowitz marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| return nil | ||
| } | ||
|
|
||
| // checkDockerAvailable checks if Docker is installed and running. | ||
| func checkDockerAvailable() error { | ||
| // Create a context with timeout for Docker commands | ||
| ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) | ||
| defer cancel() | ||
|
|
||
| cmd := exec.CommandContext(ctx, "docker", "--version") | ||
| if err := cmd.Run(); err != nil { | ||
| if ctx.Err() == context.DeadlineExceeded { | ||
| return errors.New("Docker command timed out. Please check if Docker is responding") | ||
| } | ||
| return errors.New("Docker is not available. Please install Docker and ensure it is running. " + | ||
| "Visit https://docs.docker.com/get-docker/ for installation instructions") | ||
| } | ||
|
|
||
| // Check if Docker daemon is running | ||
| ctx, cancel = context.WithTimeout(context.Background(), 10*time.Second) | ||
| defer cancel() | ||
|
|
||
| cmd = exec.CommandContext(ctx, "docker", "ps") | ||
| if err := cmd.Run(); err != nil { | ||
| if ctx.Err() == context.DeadlineExceeded { | ||
| return errors.New("Docker daemon is not responding. It may be starting up - please wait and try again") | ||
| } | ||
| return errors.New("Docker daemon is not running. Please start Docker and try again") | ||
| } | ||
|
|
||
| return nil | ||
| } | ||
|
|
||
| // validateLocalTrainingPaths validates that required paths exist. | ||
| func validateLocalTrainingPaths(scriptDir, datasetFile string) error { | ||
| // Check training script directory exists | ||
| if _, err := os.Stat(scriptDir); os.IsNotExist(err) { | ||
| return errors.Errorf("training script directory does not exist: %s", scriptDir) | ||
| } | ||
|
|
||
| // Check for required files in training script directory | ||
| setupPyPath := filepath.Join(scriptDir, "setup.py") | ||
| if _, err := os.Stat(setupPyPath); os.IsNotExist(err) { | ||
| return errors.Errorf("setup.py not found in training script directory: %s", scriptDir) | ||
| } | ||
|
|
||
| trainingPyPath := filepath.Join(scriptDir, "model", "training.py") | ||
| if _, err := os.Stat(trainingPyPath); os.IsNotExist(err) { | ||
| return errors.Errorf("model/training.py not found in training script directory: %s", scriptDir) | ||
| } | ||
|
|
||
| // Check dataset file exists | ||
| if _, err := os.Stat(datasetFile); os.IsNotExist(err) { | ||
| return errors.Errorf("dataset file does not exist: %s", datasetFile) | ||
| } | ||
|
|
||
| return nil | ||
| } | ||
|
|
||
| // createTrainingScript creates the shell script content to run inside the container. | ||
| func createTrainingScript(customArgs []string, datasetPath string) (string, error) { | ||
| script := `#!/bin/bash | ||
| set -e | ||
|
|
||
| echo "Installing training script package..." | ||
| pip3 install --no-cache-dir . | ||
|
|
||
| echo "Running training..." | ||
| ` | ||
|
|
||
| // Build the python command with arguments | ||
| pythonCmd := "python3 -m model.training" | ||
|
|
||
| // Add dataset file argument | ||
| pythonCmd += " --dataset_file=/dataset.jsonl" | ||
|
|
||
| // Add model output directory argument | ||
| pythonCmd += " --model_output_directory=/model_output" | ||
|
|
||
| // Add custom arguments | ||
| for _, arg := range customArgs { | ||
| // Validate argument format (should be key=value) | ||
| if !strings.Contains(arg, "=") { | ||
| return "", errors.Errorf("invalid custom argument format: %s (expected key=value)", arg) | ||
| } | ||
| pythonCmd += " --" + arg | ||
| } | ||
|
|
||
| script += pythonCmd + "\n" | ||
| script += ` | ||
| echo "Training completed successfully!" | ||
| ` | ||
|
|
||
| return script, nil | ||
| } | ||
|
|
||
| // getContainerImageURI returns the full container image URI based on the version. | ||
| func getContainerImageURI(c *viamClient, version string) (string, error) { | ||
| res, err := c.mlTrainingClient.ListSupportedContainers(context.Background(), &mltrainingpb.ListSupportedContainersRequest{}) | ||
| if err != nil { | ||
| return "", errors.Wrapf(err, "failed to list supported containers") | ||
| } | ||
|
|
||
| containerKeyList := []string{} | ||
| for key := range res.ContainerMap { | ||
| containerKeyList = append(containerKeyList, key) | ||
| } | ||
| slices.Sort(containerKeyList) | ||
|
|
||
| container, ok := res.ContainerMap[version] | ||
| if !ok { | ||
| if dockerVertexImageRegex.MatchString(version) { | ||
etai-shuchatowitz marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| return version, nil | ||
| } | ||
| return "", errors.Errorf("container version %s not found. Supported versions: %s", version, strings.Join(containerKeyList, ", ")) | ||
| } | ||
| return container.Uri, nil | ||
| } | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This function is probably gonna take awhile. I think maybe we should add info logging all over the place so people know that it's working.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The only part of this function that takes a significant amount of time is actually running the docker command to run training at the bottom and the logs from the docker container show up in the terminal. I also added timeouts for checkDockerAvailable so I think that should be fine? I'd suggest running this command, sorry I know it's annoying.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The issue here is that the container is really really really big. In the office downloading this can take 30 minutes. I know that it logs the docker stuff but giving people a heads up somewhere would be helpful.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
oop that is a really good point, thanks!