Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 74 additions & 0 deletions cli/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -1576,6 +1576,10 @@ var app = &cli.App{
Usage: "number of seconds to wait for large file downloads",
Value: 30,
},
// --force-linux-path: write forward-slash (Linux-style) paths into the
// exported dataset.jsonl. Useful on Windows hosts, where the default
// OS-native separators would not resolve inside a Linux training container
// (see the 'training-script test-local' notes).
&cli.BoolFlag{
Name: datasetFlagForceLinuxPath,
Usage: "force the use of Linux-style paths for the dataset.jsonl file",
},
},
Action: createCommandWithT[datasetDownloadArgs](DatasetDownloadAction),
},
Expand Down Expand Up @@ -3434,6 +3438,76 @@ This won't work unless you have an existing installation of our GitHub app on yo
},
Action: createCommandWithT[mlTrainingUpdateArgs](MLTrainingUpdateAction),
},
// test-local runs a training script inside a local Docker container that
// mirrors the cloud training environment, so scripts can be validated
// before submission. Requires Docker; containers are linux/amd64 only.
{
	Name:  "test-local",
	Usage: "test training script locally using Docker",
	UsageText: createUsageText("training-script test-local", []string{
		trainFlagDatasetRoot, trainFlagTrainingScriptDirectory,
		trainFlagDatasetFile, trainFlagContainerVersion, trainFlagModelOutputDirectory,
	}, true, false),
	Description: `Test your training script locally before submitting to the cloud. This runs your training script
in a Docker container using the same environment as cloud training.

REQUIREMENTS:
- Docker must be installed and running
- Training script directory must contain model/training.py and setup.py.
- Dataset root directory must contain:
  * dataset.jsonl (or the file specified with --dataset-file)
  * All image files referenced in the dataset (using relative paths from dataset root)

DATASET ORGANIZATION:
The dataset root should be organized so that image paths in dataset.jsonl are relative to it.
If downloaded with the 'viam dataset export' command, this will happen automatically.
For example:
  dataset_root/
  ├── dataset.jsonl (contains paths like "data/images/cat.jpg")
  └── data/
      └── images/
          └── cat.jpg

NOTES:
- Training containers only support linux/x86_64 (amd64) architecture
- Ensure Docker Desktop has sufficient resources allocated (memory, CPU)
- The container's working directory will be set to the dataset root, so relative paths resolve correctly
- Model output will be saved to the specified output directory on your host machine
- If using Windows, ensure the dataset file path is in Linux format by passing --force-linux-path to the 'viam dataset export' command
`,
	Flags: []cli.Flag{
		&cli.StringFlag{
			Name: trainFlagDatasetRoot,
			Usage: "path to the dataset root directory (where dataset.jsonl and image files are located)." +
				" This is where you ran the 'viam dataset export' command from. The container will be mounted to this directory",
			Required: true,
		},
		&cli.StringFlag{
			Name:  trainFlagDatasetFile,
			Usage: "relative path to the dataset file from the dataset root. Defaults to dataset.jsonl",
			Value: "dataset.jsonl",
		},
		&cli.StringFlag{
			Name:     trainFlagTrainingScriptDirectory,
			Usage:    "path to the training script directory (must contain setup.py and model/training.py)",
			Required: true,
		},
		&cli.StringFlag{
			Name: trainFlagContainerVersion,
			Usage: `ml training container version to use.
Must be one of the supported container names found by
calling ListSupportedContainers`,
			Required: true,
		},
		&cli.StringFlag{
			Name:  trainFlagModelOutputDirectory,
			Usage: "directory where the trained model will be saved. Defaults to current directory",
			Value: ".",
		},
		&cli.StringSliceFlag{
			Name:  trainFlagCustomArgs,
			Usage: "custom arguments to pass to the training script (format: key=value)",
		},
	},
	Action: createCommandWithT[mlTrainingScriptTestLocalArgs](MLTrainingScriptTestLocalAction),
},
},
},
{
Expand Down
35 changes: 20 additions & 15 deletions cli/dataset.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,13 @@ import (
)

// CLI flag names shared by the dataset subcommands.
const (
	datasetFlagName           = "name"
	datasetFlagDatasetID      = "dataset-id"
	datasetFlagDatasetIDs     = "dataset-ids"
	dataFlagLocationID        = "location-id"
	dataFlagBinaryDataIDs     = "binary-data-ids"
	datasetFlagOnlyJSONLines  = "only-jsonl"
	// datasetFlagForceLinuxPath makes 'dataset export' write forward-slash
	// paths into dataset.jsonl so the file works inside Linux containers
	// even when exported on Windows.
	datasetFlagForceLinuxPath = "force-linux-path"
)

type datasetCreateArgs struct {
Expand Down Expand Up @@ -196,11 +197,12 @@ func (c *viamClient) deleteDataset(datasetID string) error {
}

// datasetDownloadArgs holds the parsed CLI arguments for 'dataset export'.
type datasetDownloadArgs struct {
	Destination string // directory the dataset and binaries are written to
	DatasetID   string // ID of the dataset to export
	OnlyJSONl   bool   // when true, write only dataset.jsonl (skip binary downloads)
	// ForceLinuxPath forces forward-slash paths in dataset.jsonl
	// (for Windows hosts mounting the dataset into a Linux container).
	ForceLinuxPath bool
	Parallel       uint // number of concurrent downloads
	Timeout        uint // per-download timeout in seconds
}

// DatasetDownloadAction is the corresponding action for 'dataset export'.
Expand All @@ -210,14 +212,14 @@ func DatasetDownloadAction(c *cli.Context, args datasetDownloadArgs) error {
return err
}
if err := client.downloadDataset(args.Destination, args.DatasetID,
args.OnlyJSONl, args.Parallel, args.Timeout); err != nil {
args.OnlyJSONl, args.ForceLinuxPath, args.Parallel, args.Timeout); err != nil {
return err
}
return nil
}

// downloadDataset downloads a dataset with the specified ID.
func (c *viamClient) downloadDataset(dst, datasetID string, onlyJSONLines bool, parallelDownloads, timeout uint) error {
func (c *viamClient) downloadDataset(dst, datasetID string, onlyJSONLines, forceLinuxPath bool, parallelDownloads, timeout uint) error {
var datasetFile *os.File
var err error
datasetPath := filepath.Join(dst, "dataset.jsonl")
Expand Down Expand Up @@ -252,7 +254,7 @@ func (c *viamClient) downloadDataset(dst, datasetID string, onlyJSONLines bool,
downloadErr = c.downloadBinary(dst, timeout, id)
datasetFilePath = filepath.Join(dst, dataDir)
}
datasetErr := binaryDataToJSONLines(c.c.Context, c.dataClient, datasetFilePath, datasetFile, id)
datasetErr := binaryDataToJSONLines(c.c.Context, c.dataClient, datasetFilePath, datasetFile, id, forceLinuxPath)

return multierr.Combine(downloadErr, datasetErr)
},
Expand Down Expand Up @@ -294,7 +296,7 @@ type BBoxAnnotation struct {
}

func binaryDataToJSONLines(ctx context.Context, client datapb.DataServiceClient, dst string, file *os.File,
id string,
id string, forceLinuxPath bool,
) error {
var resp *datapb.BinaryDataByIDsResponse
var err error
Expand All @@ -318,6 +320,9 @@ func binaryDataToJSONLines(ctx context.Context, client datapb.DataServiceClient,
datum := data[0]

fileName := filepath.Join(dst, filenameForDownload(datum.GetMetadata()))
if forceLinuxPath {
fileName = filepath.ToSlash(fileName)
}
ext := datum.GetMetadata().GetFileExt()
// If the file is gzipped, unzip.
if ext != gzFileExt && filepath.Ext(fileName) != ext {
Expand Down
Loading