Skip to content
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Pass tensor & metato tensor_spec. rm tensor count.
  • Loading branch information
jonpsy authored Dec 22, 2021
commit 31c2b31f4c20db4c5529658126d50cba4aa0e07b
62 changes: 10 additions & 52 deletions tensorflow_lite_support/cc/task/vision/utils/image_tensor_specs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,35 +40,20 @@ using ::tflite::support::TfLiteSupportStatus;
using ::tflite::task::core::TfLiteEngine;

StatusOr<const TensorMetadata*> GetTensorMetadataIfAny(
const ModelMetadataExtractor& metadata_extractor, const bool is_input) {
const ModelMetadataExtractor& metadata_extractor,
const TensorMetadata* tensor_metadata) {
if (metadata_extractor.GetModelMetadata() == nullptr ||
metadata_extractor.GetModelMetadata()->subgraph_metadata() == nullptr) {
// Some models have no metadata at all (or very partial), so exit early.
return nullptr;
} else if (is_input && metadata_extractor.GetInputTensorCount() != 1) {
return CreateStatusWithPayload(
StatusCode::kInvalidArgument,
"Models are assumed to have a single input TensorMetadata.",
TfLiteSupportStatus::kInvalidNumInputTensorsError);
} else if (!is_input && metadata_extractor.GetOutputTensorCount() != 1) {
return CreateStatusWithPayload(
StatusCode::kInvalidArgument,
"Models are assumed to have a single output TensorMetadata.",
TfLiteSupportStatus::kInvalidNumOutputTensorsError);
}

const TensorMetadata* input_metadata =
metadata_extractor.GetInputTensorMetadata(0);
const TensorMetadata* output_metadata =
metadata_extractor.GetOutputTensorMetadata(0);

if (input_metadata == nullptr) {
if (tensor_metadata == nullptr) {
// Should never happen.
return CreateStatusWithPayload(StatusCode::kInternal,
"Input TensorMetadata is null.");
"Provided TensorMetadata is null.");
}

return is_input ? input_metadata : output_metadata;
return tensor_metadata;
}
} // namespace

Expand Down Expand Up @@ -143,41 +128,15 @@ StatusOr<absl::optional<NormalizationOptions>> GetNormalizationOptionsIfAny(

StatusOr<ImageTensorSpecs> BuildImageTensorSpecs(
const TfLiteEngine::Interpreter& interpreter,
const tflite::metadata::ModelMetadataExtractor& metadata_extractor,
const bool is_input) {
ASSIGN_OR_RETURN(const TensorMetadata* metadata,
GetTensorMetadataIfAny(metadata_extractor, is_input));

const TensorMetadata* tensor_metadata, const TfLiteTensor* tensor) {
const ImageProperties* props = nullptr;
absl::optional<NormalizationOptions> normalization_options;
if (metadata != nullptr) {
ASSIGN_OR_RETURN(props, GetImagePropertiesIfAny(*metadata));
if (tensor_metadata != nullptr) {
ASSIGN_OR_RETURN(props, GetImagePropertiesIfAny(*tensor_metadata));
ASSIGN_OR_RETURN(normalization_options,
GetNormalizationOptionsIfAny(*metadata));
if (!is_input && !normalization_options.has_value()) {
ASSIGN_OR_RETURN(normalization_options,
GetNormalizationOptionsIfAny(
*metadata_extractor.GetInputTensorMetadata(0)));
}
GetNormalizationOptionsIfAny(*tensor_metadata));
}

if (is_input && TfLiteEngine::InputCount(&interpreter) != 1) {
return CreateStatusWithPayload(
StatusCode::kInvalidArgument,
"Models are assumed to have a single input.",
TfLiteSupportStatus::kInvalidNumInputTensorsError);
} else if (!is_input && TfLiteEngine::OutputCount(&interpreter) != 1) {
return CreateStatusWithPayload(
StatusCode::kInvalidArgument,
"Models are assumed to have a single output.",
TfLiteSupportStatus::kInvalidNumOutputTensorsError);
}

// Input/output-related specifications.
const TfLiteTensor* tensor = is_input
? TfLiteEngine::GetInput(&interpreter, 0)
: TfLiteEngine::GetOutput(&interpreter, 0);

if (tensor->dims->size != 4) {
return CreateStatusWithPayload(
StatusCode::kInvalidArgument,
Expand Down Expand Up @@ -265,8 +224,7 @@ StatusOr<ImageTensorSpecs> BuildImageTensorSpecs(
result.image_height = height;
result.color_space = ColorSpaceType_RGB;
result.tensor_type = tensor_type;
result.normalization_options =
normalization_options;
result.normalization_options = normalization_options;

return result;
}
Expand Down