Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Add CLI command for downloading dataset
  • Loading branch information
tahiyasalam committed Jul 8, 2024
commit 3fc61c36d9564dc85b5d908c3ebfaaa526c0a27c
6 changes: 6 additions & 0 deletions cli/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -844,6 +844,12 @@ var app = &cli.App{
Usage: "option to include JSON Lines files for local testing",
Value: false,
},
&cli.UintFlag{
Name: dataFlagParallelDownloads,
Required: false,
Usage: "number of download requests to make in parallel",
Value: 100,
},
},
Action: DatasetDownloadAction,
},
Expand Down
4 changes: 2 additions & 2 deletions cli/data.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ func (c *viamClient) dataExportAction(cCtx *cli.Context) error {

switch cCtx.String(dataFlagDataType) {
case dataTypeBinary:
if err := c.binaryData(cCtx.Path(dataFlagDestination), filter, cCtx.Uint(dataFlagParallelDownloads), false); err != nil {
if err := c.binaryData(cCtx.Path(dataFlagDestination), filter, cCtx.Uint(dataFlagParallelDownloads)); err != nil {
return err
}
case dataTypeTabular:
Expand Down Expand Up @@ -264,7 +264,7 @@ func createDataFilter(c *cli.Context) (*datapb.Filter, error) {
}

// BinaryData downloads binary data matching filter to dst.
func (c *viamClient) binaryData(dst string, filter *datapb.Filter, parallelDownloads uint, includeJSONL bool) error {
func (c *viamClient) binaryData(dst string, filter *datapb.Filter, parallelDownloads uint) error {
if err := c.ensureLoggedIn(); err != nil {
return err
}
Expand Down
67 changes: 43 additions & 24 deletions cli/dataset.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import (
"github.com/urfave/cli/v2"
"go.uber.org/multierr"
datapb "go.viam.com/api/app/data/v1"
v1 "go.viam.com/api/app/data/v1"
datasetpb "go.viam.com/api/app/dataset/v1"
)

Expand Down Expand Up @@ -132,7 +131,7 @@ func (c *viamClient) listDatasetByOrg(orgID string) error {
return nil
}

// DatasetDeleteAction is the corresponding action for 'dataset rename'.
// DatasetDeleteAction is the corresponding action for 'dataset delete'.
func DatasetDeleteAction(c *cli.Context) error {
client, err := newViamClient(c)
if err != nil {
Expand All @@ -158,33 +157,34 @@ func (c *viamClient) deleteDataset(datasetID string) error {
return nil
}

// DatasetDeleteAction is the corresponding action for 'dataset rename'.
// DatasetDownloadAction is the corresponding action for 'dataset download'.
func DatasetDownloadAction(c *cli.Context) error {
client, err := newViamClient(c)
if err != nil {
return err
}
if err := client.downloadDataset(c.Path(dataFlagDestination), c.String(datasetFlagDatasetID), c.Bool(datasetFlagIncludeJSONLines)); err != nil {
if err := client.downloadDataset(c.Path(dataFlagDestination), c.String(datasetFlagDatasetID),
c.Bool(datasetFlagIncludeJSONLines), c.Uint(dataFlagParallelDownloads)); err != nil {
return err
}
return nil
}

// downloadDataset downloads a dataset with the specified ID.
func (c *viamClient) downloadDataset(dst, datasetID string, includeJSONLines bool) error {
func (c *viamClient) downloadDataset(dst, datasetID string, includeJSONLines bool, parallelDownloads uint) error {
if err := c.ensureLoggedIn(); err != nil {
return err
}

var datasetFile *os.File
var err error
if includeJSONLines {
//nolint:gosec
datasetPath := filepath.Join(dst, "dataset.jsonl")
if err := os.MkdirAll(filepath.Dir(datasetPath), 0o700); err != nil {
return errors.Wrapf(err, "could not create dataset directory %s", filepath.Dir(datasetPath))
}
//nolint:gosec
datasetFile, err := os.Create(datasetPath)
datasetFile, err = os.Create(datasetPath)
if err != nil {
return err
}
Expand All @@ -206,7 +206,7 @@ func (c *viamClient) downloadDataset(dst, datasetID string, includeJSONLines boo
},
&datapb.Filter{
DatasetId: datasetID,
}, 100,
}, parallelDownloads,
func(i int32) {
printf(c.c.App.Writer, "Downloaded %d files", i)
},
Expand All @@ -218,17 +218,11 @@ type Annotation struct {
AnnotationLabel string `json:"annotation_label"`
}

// ObjectDetection defines the format of the data in jsonlines for object detection.
type ObjectDetection struct {
ImageGCSURI string `json:"image_gcs_uri"`
BBoxAnnotations []BBoxAnnotation `json:"bounding_box_annotations"`
}

// ImageMetadata defines the format of the data in jsonlines for custom training.
type ImageMetadata struct {
ImagePath string `json:"image_path"`
ClassificationAnnotations []Annotation `json:"classification_annotations"`
BBoxAnnotations []*datapb.BoundingBox `json:"bounding_box_annotations"`
ImagePath string `json:"image_path"`
ClassificationAnnotations []Annotation `json:"classification_annotations"`
BBoxAnnotations []BBoxAnnotation `json:"bounding_box_annotations"`
}

// BBoxAnnotation holds the information associated with each bounding box.
Expand All @@ -240,31 +234,42 @@ type BBoxAnnotation struct {
YMaxNormalized float64 `json:"y_max_normalized"`
}

func binaryDataToJSONLines(ctx context.Context, client v1.DataServiceClient, file *os.File,
func binaryDataToJSONLines(ctx context.Context, client datapb.DataServiceClient, file *os.File,
id *datapb.BinaryID,
) error {
resp, err := client.BinaryDataByIDs(ctx, &datapb.BinaryDataByIDsRequest{
BinaryIds: []*datapb.BinaryID{id},
IncludeBinary: false,
})
var resp *datapb.BinaryDataByIDsResponse
var err error
for count := 0; count < maxRetryCount; count++ {
resp, err = client.BinaryDataByIDs(ctx, &datapb.BinaryDataByIDsRequest{
BinaryIds: []*datapb.BinaryID{id},
IncludeBinary: false,
})
if err == nil {
break
}
}
if err != nil {
return errors.Wrapf(err, serverErrorMessage)
}

data := resp.GetData()
if len(data) != 1 {
return errors.Errorf("expected a single response, received %d", len(data))
}
datum := data[0]

// Make JSONLines file
// Make JSONLines
var jsonl interface{}

annotations := []Annotation{}
for _, tag := range datum.GetMetadata().GetCaptureMetadata().GetTags() {
annotations = append(annotations, Annotation{AnnotationLabel: tag})
}
bboxAnnotations := convertBoundingBoxes(datum.GetMetadata().GetAnnotations().GetBboxes())
jsonl = ImageMetadata{
ImagePath: filenameForDownload(datum.GetMetadata()),
ClassificationAnnotations: annotations,
BBoxAnnotations: datum.GetMetadata().GetAnnotations().GetBboxes(),
BBoxAnnotations: bboxAnnotations,
}

line, err := json.Marshal(jsonl)
Expand All @@ -279,3 +284,17 @@ func binaryDataToJSONLines(ctx context.Context, client v1.DataServiceClient, fil

return nil
}

func convertBoundingBoxes(protoBBoxes []*datapb.BoundingBox) []BBoxAnnotation {
bboxes := make([]BBoxAnnotation, len(protoBBoxes))
for i, box := range protoBBoxes {
bboxes[i] = BBoxAnnotation{
AnnotationLabel: box.GetLabel(),
XMinNormalized: box.GetXMinNormalized(),
XMaxNormalized: box.GetXMaxNormalized(),
YMinNormalized: box.GetYMinNormalized(),
YMaxNormalized: box.GetYMaxNormalized(),
}
}
return bboxes
}