@@ -18,22 +18,23 @@ limitations under the License.
1818
1919#include < gmock/gmock.h>
2020#include < gtest/gtest.h>
21- #include " tensorflow/lite/core/c/c_api_types.h"
22- #include " tensorflow/lite/core/c/common.h"
23- #include " tensorflow/lite/core/interpreter.h"
2421#include " tensorflow/lite/core/async/async_kernel_internal.h"
2522#include " tensorflow/lite/core/async/backend_async_kernel_interface.h"
2623#include " tensorflow/lite/core/async/c/task.h"
2724#include " tensorflow/lite/core/async/c/types.h"
2825#include " tensorflow/lite/core/async/common.h"
2926#include " tensorflow/lite/core/async/testing/mock_async_kernel.h"
3027#include " tensorflow/lite/core/async/testing/test_backend.h"
28+ #include " tensorflow/lite/core/c/c_api_types.h"
29+ #include " tensorflow/lite/core/c/common.h"
30+ #include " tensorflow/lite/core/interpreter.h"
31+ #include " tensorflow/lite/interpreter_test_util.h"
3132#include " tensorflow/lite/kernels/builtin_op_kernels.h"
3233
3334namespace tflite {
3435namespace async {
3536
36- class AsyncSignatureRunnerTest : public ::testing::Test {
37+ class AsyncSignatureRunnerTest : public InterpreterTest {
3738 protected:
3839 void SetUp () override {
3940 kernel_ =
@@ -53,41 +54,39 @@ class AsyncSignatureRunnerTest : public ::testing::Test {
5354 void * builtin_data_1 = malloc (sizeof (int ));
5455 interpreter_->AddNodeWithParameters ({0 , 0 }, {1 }, nullptr , 0 , builtin_data_1,
5556 reg);
56- signature_def_.signature_key = " serving_default" ;
57- signature_def_.inputs [" input" ] = 0 ;
58- signature_def_.outputs [" output" ] = 1 ;
59- signature_def_.subgraph_index = 0 ;
60- }
61-
62- void BuildAsyncSignatureRunner () {
57+ const char kSignatureKey [] = " serving_default" ;
58+ BuildSignature (kSignatureKey , {{" input" , 0 }}, {{" output" , 1 }});
6359 interpreter_->ModifyGraphWithDelegate (backend_->get_delegate ());
64- signature_runner_ = std::make_unique<AsyncSignatureRunner>(
65- &signature_def_, interpreter_->subgraph (0 ));
6660 }
6761
6862 int GetTensorIndex (TfLiteIoType io_type, const char * name) {
6963 return signature_runner_->GetTensorIndex (io_type, name);
7064 }
7165
72- void TearDown () override { signature_runner_.reset (); }
73-
7466 protected:
7567 std::unique_ptr<::testing::StrictMock<testing::MockAsyncKernel>> kernel_;
7668 std::unique_ptr<testing::TestBackend> backend_;
77- std::unique_ptr<Interpreter> interpreter_;
7869 internal::SignatureDef signature_def_;
79- std::unique_ptr< AsyncSignatureRunner> signature_runner_;
70+ AsyncSignatureRunner* signature_runner_ = nullptr ;
8071};
8172
73+ TEST_F (AsyncSignatureRunnerTest, GetAsyncSignatureRunner) {
74+ EXPECT_EQ (nullptr , signature_runner_);
75+ signature_runner_ = interpreter_->GetAsyncSignatureRunner (" serving_default" );
76+ EXPECT_NE (nullptr , signature_runner_);
77+
78+ EXPECT_EQ (nullptr , interpreter_->GetAsyncSignatureRunner (" foo" ));
79+ }
80+
8281TEST_F (AsyncSignatureRunnerTest, InputNameTest) {
83- BuildAsyncSignatureRunner ( );
82+ signature_runner_ = interpreter_-> GetAsyncSignatureRunner ( " serving_default " );
8483 EXPECT_EQ (0 , GetTensorIndex (TfLiteIoType::kTfLiteIoInput , " input" ));
8584 EXPECT_EQ (-1 , GetTensorIndex (TfLiteIoType::kTfLiteIoInput , " output" ));
8685 EXPECT_EQ (-1 , GetTensorIndex (TfLiteIoType::kTfLiteIoInput , " foo" ));
8786}
8887
8988TEST_F (AsyncSignatureRunnerTest, OutputNameTest) {
90- BuildAsyncSignatureRunner ( );
89+ signature_runner_ = interpreter_-> GetAsyncSignatureRunner ( " serving_default " );
9190 EXPECT_EQ (1 , GetTensorIndex (TfLiteIoType::kTfLiteIoOutput , " output" ));
9291 EXPECT_EQ (-1 , GetTensorIndex (TfLiteIoType::kTfLiteIoOutput , " input" ));
9392 EXPECT_EQ (-1 , GetTensorIndex (TfLiteIoType::kTfLiteIoOutput , " foo" ));
@@ -96,7 +95,7 @@ TEST_F(AsyncSignatureRunnerTest, OutputNameTest) {
9695TEST_F (AsyncSignatureRunnerTest, CreateTaskTest) {
9796 EXPECT_CALL (*kernel_, Finish (::testing::_, ::testing::_));
9897
99- BuildAsyncSignatureRunner ( );
98+ signature_runner_ = interpreter_-> GetAsyncSignatureRunner ( " serving_default " );
10099 auto * task = signature_runner_->CreateTask ();
101100 EXPECT_NE (nullptr , task);
102101
0 commit comments