Skip to content
This repository was archived by the owner on Jan 13, 2022. It is now read-only.

Commit 60523b0

Browse files
committed
update
1 parent 66662b2 commit 60523b0

File tree

1 file changed

+30
-9
lines changed

1 file changed

+30
-9
lines changed

test/enet/main.cpp

Lines changed: 30 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -69,26 +69,47 @@ int main()
6969
interpreter->SetNumThreads(num_of_threads);
7070
TFLITE_MINIMAL_CHECK(interpreter->AllocateTensors() == kTfLiteOk);
7171
printf("=== Pre-invoke Interpreter State ===\n");
72-
tflite::PrintInterpreterState(interpreter.get());
72+
//tflite::PrintInterpreterState(interpreter.get());
7373

7474
// Set data to input tensor
75+
printf("=== MEM copy ===\n");
7576
float* input = interpreter->typed_input_tensor<float>(0);
76-
//memcpy(input, image.reshape(0, 1).data, sizeof(float) * 1 * 360 * 480 * 3);
7777
memcpy(input, image.data, sizeof(float) * input_array_size);
78-
printf("\n\n=== MEM copied ===\n");
7978

8079
// Run inference
81-
printf("\n\n=== Pre-invoke Interpreter State ===\n");
80+
printf("=== Pre-invoke ===\n");
81+
const auto& start_time = std::chrono::steady_clock::now();
8282
TFLITE_MINIMAL_CHECK(interpreter->Invoke() == kTfLiteOk);
83-
printf("\n\n=== Post-invoke Interpreter State ===\n");
84-
tflite::PrintInterpreterState(interpreter.get());
83+
const std::chrono::duration<double, std::milli>& time_span = std::chrono::steady_clock::now() - start_time;
84+
std::cout << "Inference time: " << time_span.count() << " ms" << std::endl;
85+
printf("=== Post-invoke ===\n");
8586

8687
// Get data from output tensor
87-
float* probs = interpreter->typed_output_tensor<float>(0);
88-
for (int i = 0; i < 10; i++) {
89-
printf("prob of %d: %.3f\n", i, probs[i]);
88+
std::vector<float> output_data;
89+
const auto& output_indices = interpreter->outputs();
90+
const int num_outputs = output_indices.size();
91+
int out_idx = 0;
92+
for (int i = 0; i < num_outputs; ++i)
93+
{
94+
const auto* out_tensor = interpreter->tensor(output_indices[i]);
95+
assert(out_tensor != nullptr);
96+
const int num_values = out_tensor->bytes / sizeof(float);
97+
output_data.resize(out_idx + num_values);
98+
const float* output = interpreter->typed_output_tensor<float>(i);
99+
for (int j = 0; j < num_values; ++j)
100+
{
101+
output_data[out_idx++] = output[j];
102+
}
90103
}
91104

105+
// Create segmantation map.
106+
cv::Mat seg_im(cv::Size(input_tensor_shape[1], input_tensor_shape[2]), CV_8UC3);
107+
LabelToColorMap(output_data, *color_map.get(), seg_im);
108+
109+
// output tensor size => camera resolution
110+
cv::resize(seg_im, seg_im, cv::Size(480, 360));
111+
seg_im = (image / 2) + (seg_im / 2);
112+
92113
cv::waitKey(0);
93114
return 0;
94115
}

0 commit comments

Comments
 (0)