Skip to content

Commit be61711

Browse files
sirakiintensorflower-gardener
authored andcommitted
Add API to interpreter experimental to retrieve async signature runner.
PiperOrigin-RevId: 507869775
1 parent 287820d commit be61711

File tree

9 files changed

+69
-62
lines changed

9 files changed

+69
-62
lines changed

tensorflow/lite/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -245,6 +245,7 @@ if(NOT "${CMAKE_SYSTEM_NAME}" STREQUAL "iOS")
245245
endif()
246246
populate_tflite_source_vars("core" TFLITE_CORE_SRCS)
247247
populate_tflite_source_vars("core/api" TFLITE_CORE_API_SRCS)
248+
populate_tflite_source_vars("core/async" TFLITE_CORE_ASYNC_SRCS)
248249
populate_tflite_source_vars("core/c" TFLITE_CORE_C_SRCS)
249250
populate_tflite_source_vars("core/experimental/acceleration/configuration" TFLITE_CORE_EXPERIMENTAL_SRCS)
250251
populate_tflite_source_vars("core/kernels" TFLITE_CORE_KERNELS_SRCS)
@@ -512,6 +513,7 @@ set(_ALL_TFLITE_SRCS
512513
${TFLITE_CORE_EXPERIMENTAL_SRCS}
513514
${TFLITE_CORE_KERNELS_SRCS}
514515
${TFLITE_CORE_SRCS}
516+
${TFLITE_CORE_ASYNC_SRCS}
515517
${TFLITE_CORE_TOOLS_SRCS}
516518
${TFLITE_C_SRCS}
517519
${TFLITE_DELEGATES_FLEX_SRCS}

tensorflow/lite/core/BUILD

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ cc_library(
4545
"//tensorflow/lite:type_to_tflitetype",
4646
"//tensorflow/lite/core/api",
4747
"//tensorflow/lite/core/api:error_reporter",
48+
"//tensorflow/lite/core/async:async_signature_runner",
4849
"//tensorflow/lite/core/c:common",
4950
"//tensorflow/lite/experimental/remat:metadata_util",
5051
"//tensorflow/lite/experimental/resource",
@@ -109,6 +110,7 @@ cc_library(
109110
"//tensorflow/lite/c:common_internal",
110111
"//tensorflow/lite/core/api",
111112
"//tensorflow/lite/core/api:verifier",
113+
"//tensorflow/lite/core/async:async_signature_runner",
112114
"//tensorflow/lite/core/c:common",
113115
"//tensorflow/lite/experimental/remat:metadata_util",
114116
"//tensorflow/lite/experimental/resource",
@@ -180,6 +182,7 @@ cc_library(
180182
"//tensorflow/lite/c:common_internal",
181183
"//tensorflow/lite/core/api",
182184
"//tensorflow/lite/core/api:verifier",
185+
"//tensorflow/lite/core/async:async_signature_runner",
183186
"//tensorflow/lite/core/c:common",
184187
"//tensorflow/lite/experimental/remat:metadata_util",
185188
"//tensorflow/lite/experimental/resource",
@@ -237,6 +240,7 @@ cc_library(
237240
"//tensorflow/lite/c:common_internal",
238241
"//tensorflow/lite/core/api",
239242
"//tensorflow/lite/core/api:verifier",
243+
"//tensorflow/lite/core/async:async_signature_runner",
240244
"//tensorflow/lite/core/c:c_api_types",
241245
"//tensorflow/lite/core/c:common",
242246
"//tensorflow/lite/delegates:telemetry",
@@ -313,6 +317,7 @@ cc_library(
313317
"//tensorflow/lite/c:common_internal",
314318
"//tensorflow/lite/core/api",
315319
"//tensorflow/lite/core/api:verifier",
320+
"//tensorflow/lite/core/async:async_signature_runner",
316321
"//tensorflow/lite/core/c:c_api_types",
317322
"//tensorflow/lite/core/c:common",
318323
"//tensorflow/lite/experimental/remat:metadata_util",

tensorflow/lite/core/async/BUILD

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -54,12 +54,13 @@ cc_library(
5454
name = "async_subgraph",
5555
srcs = ["async_subgraph.cc"],
5656
hdrs = ["async_subgraph.h"],
57+
compatible_with = get_compatible_with_portable(),
5758
deps = [
5859
":async_kernel_internal",
5960
":common",
6061
":task_internal",
61-
"//tensorflow/lite:framework",
6262
"//tensorflow/lite:minimal_logging",
63+
"//tensorflow/lite/core:subgraph",
6364
"//tensorflow/lite/core/async/interop/c:types",
6465
"//tensorflow/lite/core/c:c_api_types",
6566
"//tensorflow/lite/core/c:common",
@@ -103,14 +104,13 @@ cc_library(
103104
name = "async_signature_runner",
104105
srcs = ["async_signature_runner.cc"],
105106
hdrs = ["async_signature_runner.h"],
107+
compatible_with = get_compatible_with_portable(),
106108
deps = [
107109
":async_kernel_internal",
108110
":async_subgraph",
109111
":common",
110112
":task_internal",
111-
"//tensorflow/lite:framework",
112-
"//tensorflow/lite/c:c_api_without_op_resolver",
113-
"//tensorflow/lite/core/c:c_api",
113+
"//tensorflow/lite/core:subgraph",
114114
"//tensorflow/lite/core/c:c_api_types",
115115
"//tensorflow/lite/core/c:common",
116116
"//tensorflow/lite/internal:signature_def",
@@ -126,7 +126,7 @@ cc_test(
126126
":backend_async_kernel_interface",
127127
":common",
128128
"//tensorflow/lite:framework",
129-
"//tensorflow/lite/c:c_api_for_testing",
129+
"//tensorflow/lite:interpreter_test_util",
130130
"//tensorflow/lite/core:headers",
131131
"//tensorflow/lite/core/async/c:task",
132132
"//tensorflow/lite/core/async/c:types",

tensorflow/lite/core/async/async_signature_runner.cc

Lines changed: 0 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -26,23 +26,8 @@ limitations under the License.
2626
#include "tensorflow/lite/core/async/async_subgraph.h"
2727
#include "tensorflow/lite/core/async/common.h"
2828
#include "tensorflow/lite/core/async/task_internal.h"
29-
#include "tensorflow/lite/signature_runner.h"
3029

3130
namespace tflite {
32-
33-
// This is a temporary helper class that will be removed after this API is
34-
// moved out of experimental.
35-
class SignatureRunnerHelper {
36-
public:
37-
static Subgraph* GetSubgraph(SignatureRunner* runner) {
38-
return runner->subgraph_;
39-
}
40-
static const internal::SignatureDef* GetSignatureDef(
41-
SignatureRunner* runner) {
42-
return runner->signature_def_;
43-
}
44-
};
45-
4631
namespace async {
4732

4833
namespace {
@@ -79,11 +64,6 @@ int AsyncSignatureRunner::GetTensorIndex(TfLiteIoType io_type,
7964
return tensor_index;
8065
}
8166

82-
AsyncSignatureRunner::AsyncSignatureRunner(SignatureRunner* signature_runner)
83-
: AsyncSignatureRunner(
84-
SignatureRunnerHelper::GetSignatureDef(signature_runner),
85-
SignatureRunnerHelper::GetSubgraph(signature_runner)) {}
86-
8767
AsyncSignatureRunner::AsyncSignatureRunner(
8868
const internal::SignatureDef* signature_def, Subgraph* subgraph)
8969
: signature_def_(signature_def), subgraph_(subgraph) {

tensorflow/lite/core/async/async_signature_runner.h

Lines changed: 2 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,12 @@ limitations under the License.
1818
#include <memory>
1919
#include <vector>
2020

21-
#include "tensorflow/lite/core/c/c_api.h"
2221
#include "tensorflow/lite/core/c/common.h"
2322
#include "tensorflow/lite/core/subgraph.h"
2423
#include "tensorflow/lite/core/async/async_kernel_internal.h"
2524
#include "tensorflow/lite/core/async/async_subgraph.h"
2625
#include "tensorflow/lite/core/async/common.h"
2726
#include "tensorflow/lite/internal/signature_def.h"
28-
#include "tensorflow/lite/signature_runner.h"
2927

3028
namespace tflite {
3129
namespace async {
@@ -39,20 +37,8 @@ class AsyncSignatureRunnerTest;
3937
// SignatureDef.
4038
class AsyncSignatureRunner {
4139
public:
42-
// TODO(b/191883048): Move ctor to private and use `Create` function as
43-
// factory method.
44-
// Currently we don't have way to expose signature def from interpreter
45-
// without changes to interpreter.
46-
//
47-
// static AsyncSignatureRunner* Create(const TfLiteInterpreter* interpreter,
48-
// const char* signature_key);
49-
// WARNING: This is a temporary constructor before we stablize the API.
50-
// This if for avoiding making intrusive changes to non experimental code.
51-
// For now, users can construct AsyncSignatureRunner as follows:
52-
// std::unique_ptr<tflite::Interpreter> interpreter;
53-
// InterpreterBuilder(model, resolver)(&interpreter);
54-
// AsyncSignatureRunner runner(interpreter->GetSignatureRunner("func"));
55-
explicit AsyncSignatureRunner(SignatureRunner* signature_runner);
40+
// Builds the AsyncSignatureRunner given the provided signature_def and
41+
// subgraph.
5642
AsyncSignatureRunner(const internal::SignatureDef* signature_def,
5743
Subgraph* subgraph);
5844

tensorflow/lite/core/async/async_signature_runner_test.cc

Lines changed: 19 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -18,22 +18,23 @@ limitations under the License.
1818

1919
#include <gmock/gmock.h>
2020
#include <gtest/gtest.h>
21-
#include "tensorflow/lite/core/c/c_api_types.h"
22-
#include "tensorflow/lite/core/c/common.h"
23-
#include "tensorflow/lite/core/interpreter.h"
2421
#include "tensorflow/lite/core/async/async_kernel_internal.h"
2522
#include "tensorflow/lite/core/async/backend_async_kernel_interface.h"
2623
#include "tensorflow/lite/core/async/c/task.h"
2724
#include "tensorflow/lite/core/async/c/types.h"
2825
#include "tensorflow/lite/core/async/common.h"
2926
#include "tensorflow/lite/core/async/testing/mock_async_kernel.h"
3027
#include "tensorflow/lite/core/async/testing/test_backend.h"
28+
#include "tensorflow/lite/core/c/c_api_types.h"
29+
#include "tensorflow/lite/core/c/common.h"
30+
#include "tensorflow/lite/core/interpreter.h"
31+
#include "tensorflow/lite/interpreter_test_util.h"
3132
#include "tensorflow/lite/kernels/builtin_op_kernels.h"
3233

3334
namespace tflite {
3435
namespace async {
3536

36-
class AsyncSignatureRunnerTest : public ::testing::Test {
37+
class AsyncSignatureRunnerTest : public InterpreterTest {
3738
protected:
3839
void SetUp() override {
3940
kernel_ =
@@ -53,41 +54,39 @@ class AsyncSignatureRunnerTest : public ::testing::Test {
5354
void* builtin_data_1 = malloc(sizeof(int));
5455
interpreter_->AddNodeWithParameters({0, 0}, {1}, nullptr, 0, builtin_data_1,
5556
reg);
56-
signature_def_.signature_key = "serving_default";
57-
signature_def_.inputs["input"] = 0;
58-
signature_def_.outputs["output"] = 1;
59-
signature_def_.subgraph_index = 0;
60-
}
61-
62-
void BuildAsyncSignatureRunner() {
57+
const char kSignatureKey[] = "serving_default";
58+
BuildSignature(kSignatureKey, {{"input", 0}}, {{"output", 1}});
6359
interpreter_->ModifyGraphWithDelegate(backend_->get_delegate());
64-
signature_runner_ = std::make_unique<AsyncSignatureRunner>(
65-
&signature_def_, interpreter_->subgraph(0));
6660
}
6761

6862
int GetTensorIndex(TfLiteIoType io_type, const char* name) {
6963
return signature_runner_->GetTensorIndex(io_type, name);
7064
}
7165

72-
void TearDown() override { signature_runner_.reset(); }
73-
7466
protected:
7567
std::unique_ptr<::testing::StrictMock<testing::MockAsyncKernel>> kernel_;
7668
std::unique_ptr<testing::TestBackend> backend_;
77-
std::unique_ptr<Interpreter> interpreter_;
7869
internal::SignatureDef signature_def_;
79-
std::unique_ptr<AsyncSignatureRunner> signature_runner_;
70+
AsyncSignatureRunner* signature_runner_ = nullptr;
8071
};
8172

73+
TEST_F(AsyncSignatureRunnerTest, GetAsyncSignatureRunner) {
74+
EXPECT_EQ(nullptr, signature_runner_);
75+
signature_runner_ = interpreter_->GetAsyncSignatureRunner("serving_default");
76+
EXPECT_NE(nullptr, signature_runner_);
77+
78+
EXPECT_EQ(nullptr, interpreter_->GetAsyncSignatureRunner("foo"));
79+
}
80+
8281
TEST_F(AsyncSignatureRunnerTest, InputNameTest) {
83-
BuildAsyncSignatureRunner();
82+
signature_runner_ = interpreter_->GetAsyncSignatureRunner("serving_default");
8483
EXPECT_EQ(0, GetTensorIndex(TfLiteIoType::kTfLiteIoInput, "input"));
8584
EXPECT_EQ(-1, GetTensorIndex(TfLiteIoType::kTfLiteIoInput, "output"));
8685
EXPECT_EQ(-1, GetTensorIndex(TfLiteIoType::kTfLiteIoInput, "foo"));
8786
}
8887

8988
TEST_F(AsyncSignatureRunnerTest, OutputNameTest) {
90-
BuildAsyncSignatureRunner();
89+
signature_runner_ = interpreter_->GetAsyncSignatureRunner("serving_default");
9190
EXPECT_EQ(1, GetTensorIndex(TfLiteIoType::kTfLiteIoOutput, "output"));
9291
EXPECT_EQ(-1, GetTensorIndex(TfLiteIoType::kTfLiteIoOutput, "input"));
9392
EXPECT_EQ(-1, GetTensorIndex(TfLiteIoType::kTfLiteIoOutput, "foo"));
@@ -96,7 +95,7 @@ TEST_F(AsyncSignatureRunnerTest, OutputNameTest) {
9695
TEST_F(AsyncSignatureRunnerTest, CreateTaskTest) {
9796
EXPECT_CALL(*kernel_, Finish(::testing::_, ::testing::_));
9897

99-
BuildAsyncSignatureRunner();
98+
signature_runner_ = interpreter_->GetAsyncSignatureRunner("serving_default");
10099
auto* task = signature_runner_->CreateTask();
101100
EXPECT_NE(nullptr, task);
102101

tensorflow/lite/core/async/async_subgraph.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ TfLiteAsyncKernel* AsyncSubgraph::async_kernel() const {
4747
AsyncSubgraph::AsyncSubgraph(Subgraph* subgraph) : subgraph_(subgraph) {
4848
// Currently we only support one delegate and fully delegated subgph.
4949
if (!IsFullyDelegated()) {
50-
subgraph->ReportError("Model is no fully delegated by 1 backend.");
50+
subgraph->ReportError("Model is not fully delegated by 1 backend.");
5151
return;
5252
}
5353
// TODO(b/191883048): Add/Check delegate flag to indicate kernel support.

tensorflow/lite/core/interpreter.h

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ limitations under the License.
4141

4242
#include "tensorflow/lite/allocation.h"
4343
#include "tensorflow/lite/core/api/error_reporter.h"
44+
#include "tensorflow/lite/core/async/async_signature_runner.h"
4445
#include "tensorflow/lite/core/api/profiler.h"
4546
#include "tensorflow/lite/core/c/common.h" // IWYU pragma: export
4647
#include "tensorflow/lite/core/subgraph.h"
@@ -344,6 +345,14 @@ class Interpreter {
344345
/// and the SignatureRunner class is *not* thread-safe.
345346
SignatureRunner* GetSignatureRunner(const char* signature_key);
346347

348+
/// \warning Experimental interface, subject to change. \n
349+
/// \brief Returns a pointer to the AsyncSignatureRunner instance to run the
350+
/// part of the graph identified by a SignatureDef. The nullptr is returned if
351+
/// the given signature key is not valid.
352+
/// The async delegate should be applied before calling this function.
353+
async::AsyncSignatureRunner* GetAsyncSignatureRunner(
354+
const char* signature_key);
355+
347356
/// \warning Experimental interface, subject to change. \n
348357
/// \brief Return the subgraph index that corresponds to a SignatureDef,
349358
/// defined by 'signature_key'.
@@ -946,6 +955,12 @@ class Interpreter {
946955
// its SignatureDef.
947956
std::map<std::string, SignatureRunner> signature_runner_map_;
948957

958+
// Map of signature key to its corresponding AsyncSignatureRunner object.
959+
// An AsyncSignatureRunner is basically a wrapper of the AsyncSubgraph
960+
// corresponding to its SignatureDef.
961+
std::map<std::string, async::AsyncSignatureRunner>
962+
async_signature_runner_map_;
963+
949964
// Model metadata stored as mapping of name (key) to buffer (value).
950965
// Data is mapped from the Metadata in TFLite flatbuffer model.
951966
std::map<std::string, std::string> metadata_;

tensorflow/lite/core/interpreter_experimental.cc

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ limitations under the License.
2727
#include "tensorflow/lite/c/common_internal.h"
2828
#include "tensorflow/lite/core/api/error_reporter.h"
2929
#include "tensorflow/lite/core/api/profiler.h"
30+
#include "tensorflow/lite/core/async/async_signature_runner.h"
3031
#include "tensorflow/lite/core/c/c_api_types.h"
3132
#include "tensorflow/lite/core/c/common.h"
3233
#include "tensorflow/lite/core/interpreter.h"
@@ -165,4 +166,23 @@ SignatureRunner* Interpreter::GetSignatureRunner(const char* signature_key) {
165166
return nullptr;
166167
}
167168

169+
async::AsyncSignatureRunner* Interpreter::GetAsyncSignatureRunner(
170+
const char* signature_key) {
171+
auto iter = async_signature_runner_map_.find(signature_key);
172+
if (iter != async_signature_runner_map_.end()) {
173+
return &(iter->second);
174+
}
175+
176+
for (const auto& signature : signature_defs_) {
177+
if (signature.signature_key == signature_key) {
178+
auto status = async_signature_runner_map_.insert(
179+
{signature_key, async::AsyncSignatureRunner(
180+
&signature, subgraph(signature.subgraph_index))});
181+
return &(status.first->second);
182+
}
183+
}
184+
185+
return nullptr;
186+
}
187+
168188
} // namespace tflite

0 commit comments

Comments
 (0)