Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
dataset paths
  • Loading branch information
vpandiarajan20 committed Nov 24, 2025
commit c76a430d2a3d7a0ee436f19bcd599bdbc5e6cd35
27 changes: 22 additions & 5 deletions cli/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -3402,28 +3402,45 @@ This won't work unless you have an existing installation of our GitHub app on yo
Name: "test-local",
Usage: "test training script locally using Docker",
UsageText: createUsageText("training-script test-local", []string{
trainFlagDatasetFile, trainFlagTrainingScriptDirectory,
trainFlagContainerVersion, trainFlagModelOutputDirectory,
trainFlagDatasetRoot, trainFlagTrainingScriptDirectory,
trainFlagDatasetFile, trainFlagContainerVersion, trainFlagModelOutputDirectory,
}, true, false),
Description: `Test your training script locally before submitting to the cloud. This runs your training script
in a Docker container using the same environment as cloud training.

REQUIREMENTS:
- Docker must be installed and running
- Training script directory must contain setup.py and model/training.py
- Dataset must be in JSONL format
- Dataset root directory must contain:
* dataset.jsonl (or the file specified with --dataset-file)
* All image files referenced in the dataset (using relative paths from dataset root)

DATASET ORGANIZATION:
The dataset root should be organized so that image paths in dataset.jsonl are relative to it.
For example:
dataset_root/
├── dataset.jsonl (contains paths like "data/images/cat.jpg")
└── data/
└── images/
└── cat.jpg

NOTES:
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe this goes in the documentation? I also want to add that if the container is really slow, it could be because the dataset root or training script directory has a bunch of extra files. Apparently mounting volumes can be expensive.

- Training containers only support linux/x86_64 (amd64) architecture
- Ensure Docker Desktop has sufficient resources allocated (memory, CPU)
- The container's working directory will be set to the dataset root, so relative paths resolve correctly
- Model output will be saved to the specified output directory on your host machine
`,
Flags: []cli.Flag{
&cli.StringFlag{
Name: trainFlagDatasetFile,
Usage: "path to the dataset file (JSONL format)",
Name: trainFlagDatasetRoot,
Usage: "path to the dataset root directory (where dataset.jsonl and image files are located)",
Required: true,
},
&cli.StringFlag{
Name: trainFlagDatasetFile,
Usage: "relative path to the dataset file from the dataset root. Defaults to dataset.jsonl",
Value: "dataset.jsonl",
},
&cli.StringFlag{
Name: trainFlagTrainingScriptDirectory,
Usage: "path to the training script directory (must contain setup.py and model/training.py)," +
Expand Down
59 changes: 45 additions & 14 deletions cli/ml_training.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ const (
// Flags for test-local command
trainFlagContainerVersion = "container-version"
trainFlagDatasetFile = "dataset-file"
trainFlagDatasetRoot = "dataset-root"
trainFlagModelOutputDirectory = "model-output-directory"
trainFlagCustomArgs = "custom-args"
trainFlagTrainingScriptDirectory = "training-script-directory"
Expand Down Expand Up @@ -617,6 +618,7 @@ func convertVisibilityToProto(visibility string) (*v1.Visibility, error) {
type mlTrainingScriptTestLocalArgs struct {
ContainerVersion string
DatasetFile string
DatasetRoot string
ModelOutputDirectory string
TrainingScriptDirectory string
CustomArgs []string
Expand All @@ -634,21 +636,37 @@ func MLTrainingScriptTestLocalAction(c *cli.Context, args mlTrainingScriptTestLo
return err
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For example, just adding a message like "Docker is available. Using image BLAH".


// Get absolute path for dataset root directory
datasetRootAbs, err := filepath.Abs(args.DatasetRoot)
if err != nil {
return errors.Wrapf(err, "failed to get absolute path for dataset root directory")
}

// Dataset file is relative to dataset root (defaults to "dataset.jsonl")
datasetFileRelative := args.DatasetFile
if datasetFileRelative == "" {
datasetFileRelative = "dataset.jsonl"
}

// Clean the path to normalize it (remove .., ., extra slashes, etc.)
datasetFileRelative = filepath.Clean(datasetFileRelative)

// Ensure the path doesn't try to escape the dataset root
if strings.HasPrefix(datasetFileRelative, "..") || filepath.IsAbs(datasetFileRelative) {
return errors.Errorf("dataset file path must be relative to dataset root and cannot escape it: %s", datasetFileRelative)
}

// Validate required paths exist
if err := validateLocalTrainingPaths(args.TrainingScriptDirectory, args.DatasetFile); err != nil {
if err := validateLocalTrainingPaths(args.TrainingScriptDirectory, datasetRootAbs, filepath.Join(datasetRootAbs, datasetFileRelative)); err != nil {
return err
}
// Get absolute paths for volume mounting

// Get absolute path for training script directory
scriptDirAbs, err := filepath.Abs(args.TrainingScriptDirectory)
if err != nil {
return errors.Wrapf(err, "failed to get absolute path for training script directory")
}

datasetFileAbs, err := filepath.Abs(args.DatasetFile)
if err != nil {
return errors.Wrapf(err, "failed to get absolute path for dataset file")
}

// Determine output directory (default to current directory if not specified)
outputDir := args.ModelOutputDirectory
if outputDir == "" {
Expand All @@ -665,7 +683,7 @@ func MLTrainingScriptTestLocalAction(c *cli.Context, args mlTrainingScriptTestLo
}

// Create temporary script to run inside container
scriptContent, err := createTrainingScript(args.CustomArgs)
scriptContent, err := createTrainingScript(args.CustomArgs, datasetFileRelative)
if err != nil {
return err
}
Expand Down Expand Up @@ -706,10 +724,10 @@ func MLTrainingScriptTestLocalAction(c *cli.Context, args mlTrainingScriptTestLo
"--platform", "linux/x86_64",
"--rm",
"-v", fmt.Sprintf("%s:/training_script", scriptDirAbs),
"-v", fmt.Sprintf("%s:/dataset.jsonl", datasetFileAbs),
"-v", fmt.Sprintf("%s:/dataset_root", datasetRootAbs),
"-v", fmt.Sprintf("%s:/model_output", outputDirAbs),
"-v", fmt.Sprintf("%s:/run_training.sh", tmpScript.Name()),
"-w", "/training_script",
"-w", "/dataset_root", // Set working directory to dataset root so relative paths resolve correctly
containerImageURI,
"/run_training.sh",
}
Expand Down Expand Up @@ -775,7 +793,7 @@ func checkDockerAvailable() error {
}

// validateLocalTrainingPaths validates that required paths exist.
func validateLocalTrainingPaths(scriptDir, datasetFile string) error {
func validateLocalTrainingPaths(scriptDir, datasetRoot, datasetFile string) error {
// Check training script directory exists
if _, err := os.Stat(scriptDir); os.IsNotExist(err) {
return errors.Errorf("training script directory does not exist: %s", scriptDir)
Expand All @@ -792,6 +810,17 @@ func validateLocalTrainingPaths(scriptDir, datasetFile string) error {
return errors.Errorf("model/training.py not found in training script directory: %s", scriptDir)
}

// Check dataset root directory exists
if _, err := os.Stat(datasetRoot); os.IsNotExist(err) {
return errors.Errorf("dataset root directory does not exist: %s", datasetRoot)
}

// Check that the data folder exists
dataFolderPath := filepath.Join(datasetRoot, "data")
if _, err := os.Stat(dataFolderPath); os.IsNotExist(err) {
return errors.Errorf("data folder does not exist in dataset root directory: %s", datasetRoot)
}

// Check dataset file exists
if _, err := os.Stat(datasetFile); os.IsNotExist(err) {
return errors.Errorf("dataset file does not exist: %s", datasetFile)
Expand All @@ -801,7 +830,8 @@ func validateLocalTrainingPaths(scriptDir, datasetFile string) error {
}

// createTrainingScript creates the shell script content to run inside the container.
func createTrainingScript(customArgs []string) (string, error) {
// datasetFileRelative is the path to the dataset file relative to the dataset root (which is the CWD).
func createTrainingScript(customArgs []string, datasetFileRelative string) (string, error) {
// Validate custom arguments format (key=value) before building script
for _, arg := range customArgs {
if !strings.Contains(arg, "=") {
Expand All @@ -822,12 +852,13 @@ func createTrainingScript(customArgs []string) (string, error) {
script.WriteString("#!/bin/bash\n")
script.WriteString("set -e\n\n")
script.WriteString("echo \"Installing training script package...\"\n")
script.WriteString("pip3 install --no-cache-dir .\n\n")
script.WriteString("pip3 install --no-cache-dir /training_script\n\n")
script.WriteString("echo \"Running training...\"\n")

// Build Python training command
script.WriteString("python3 -m model.training")
script.WriteString(" --dataset_file=/dataset.jsonl")
script.WriteString(" --dataset_file=")
script.WriteString(datasetFileRelative)
script.WriteString(" --model_output_directory=/model_output")

// Add custom arguments
Expand Down
Loading