Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Added text searcher files
  • Loading branch information
priankakariat committed Oct 20, 2022
commit 41c6518b7dda0f05cc00a617be8196ddf80aa14a
28 changes: 28 additions & 0 deletions tensorflow_lite_support/ios/task/text/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package(
default_visibility = ["//tensorflow_lite_support:internal"],
licenses = ["notice"], # Apache 2.0
)

objc_library(
name = "TFLTextSearcher",
srcs = [
"sources/TFLTextSearcher.mm",
],
hdrs = [
"sources/TFLTextSearcher.h",
],
copts = [
"-ObjC++",
"-std=c++17",
],
features = ["-layering_check"],
module_name = "TFLTextSearcher",
deps = [
"//tensorflow_lite_support/cc/task/text:text_searcher",
"//tensorflow_lite_support/ios:TFLCommonUtils",
"//tensorflow_lite_support/ios/task/core:TFLBaseOptionsCppHelpers",
"//tensorflow_lite_support/ios/task/processor:TFLEmbeddingOptionsHelpers",
"//tensorflow_lite_support/ios/task/processor:TFLSearchOptionsHelpers",
"//tensorflow_lite_support/ios/task/processor:TFLSearchResultHelpers",
],
)
104 changes: 104 additions & 0 deletions tensorflow_lite_support/ios/task/text/sources/TFLTextSearcher.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#import <Foundation/Foundation.h>

#import "tensorflow_lite_support/ios/task/core/sources/TFLBaseOptions.h"
#import "tensorflow_lite_support/ios/task/processor/sources/TFLEmbeddingOptions.h"
#import "tensorflow_lite_support/ios/task/processor/sources/TFLSearchOptions.h"
#import "tensorflow_lite_support/ios/task/processor/sources/TFLSearchResult.h"

NS_ASSUME_NONNULL_BEGIN

/**
* Options to configure TFLTextSearcher.
*/
NS_SWIFT_NAME(TextSearcherOptions)
@interface TFLTextSearcherOptions : NSObject

/**
* Base options for configuring the TextSearcher. This specifies the TFLite
* model to use for embedding extraction, as well as hardware acceleration
* options to use as inference time.
*/
@property(nonatomic, copy) TFLBaseOptions *baseOptions;

/**
* Options controlling the behavior of the embedding model specified in the
* base options.
*/
@property(nonatomic, copy) TFLEmbeddingOptions *embeddingOptions;

/**
* Options specifying the index to search into and controlling the search behavior.
*/
@property(nonatomic, copy) TFLSearchOptions *searchOptions;

/**
* Initializes a new `TFLTextSearcherOptions` with the absolute path to the model file
* stored locally on the device, set to the given the model path.
*
* @discussion The external model file must be a single standalone TFLite file. It could be packed
* with TFLite Model Metadata[1] and associated files if they exist. Failure to provide the
* necessary metadata and associated files might result in errors. Check the [documentation]
* (https://www.tensorflow.org/lite/convert/metadata) for each task about the specific requirement.
*
* @param modelPath An absolute path to a TensorFlow Lite model file stored locally on the device.
*
* @return An instance of `TFLTextSearcherOptions` initialized to the given model path.
*/
- (instancetype)initWithModelPath:(NSString *)modelPath;

@end

/**
* A TensorFlow Lite Task Text Searcher.
*/
NS_SWIFT_NAME(TextSearcher)
@interface TFLTextSearcher : NSObject

/**
* Creates a new instance of `TFLTextSearcher` from the given `TFLTextSearcherOptions`.
*
* @param options The options to use for configuring the `TFLTextSearcher`.
* @param error An optional error parameter populated when there is an error in initializing
* the text searcher.
*
* @return A new instance of `TextSearcher` with the given options. `nil` if there is an error
* in initializing the text searcher.
*/
+ (nullable instancetype)textSearcherWithOptions:(TFLTextSearcherOptions *)options
error:(NSError **)error
NS_SWIFT_NAME(searcher(options:));

+ (instancetype)new NS_UNAVAILABLE;

/**
* Performs embedding extraction on the given text, followed by nearest-neighbor search in the
* index.
*
* @param text An string on which embedding extraction is to be performed, followed by
* nearest-neighbor search in the index.
*
* @return A `TFLSearchResult`. `nil` if there is an error encountered during embedding extraction
* and nearest neighbor search. Please see `TFLSearchResult` for more details.
*/
- (nullable TFLSearchResult *)searchWithText:(NSString *)text
error:(NSError **)error NS_SWIFT_NAME(search(text:));

- (instancetype)init NS_UNAVAILABLE;

@end

NS_ASSUME_NONNULL_END
113 changes: 113 additions & 0 deletions tensorflow_lite_support/ios/task/text/sources/TFLTextSearcher.mm
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#import "tensorflow_lite_support/ios/task/text/sources/TFLTextSearcher.h"
#import "tensorflow_lite_support/ios/sources/TFLCommon.h"
#import "tensorflow_lite_support/ios/sources/TFLCommonUtils.h"
#import "tensorflow_lite_support/ios/task/core/sources/TFLBaseOptions+CppHelpers.h"
#import "tensorflow_lite_support/ios/task/processor/sources/TFLEmbeddingOptions+Helpers.h"
#import "tensorflow_lite_support/ios/task/processor/sources/TFLSearchOptions+Helpers.h"
#import "tensorflow_lite_support/ios/task/processor/sources/TFLSearchResult+Helpers.h"

#include "tensorflow_lite_support/cc/task/text/text_searcher.h"

namespace {
using TextSearcherCpp = ::tflite::task::text::TextSearcher;
using TextSearcherOptionsCpp = ::tflite::task::text::TextSearcherOptions;
using SearchResultCpp = ::tflite::task::processor::SearchResult;
using ::tflite::support::StatusOr;
} // namespace

@interface TFLTextSearcher () {
/** TextSearcher backed by C API */
std::unique_ptr<TextSearcherCpp> _cppTextSearcher;
}
@end

@implementation TFLTextSearcherOptions

- (instancetype)init {
self = [super init];
if (self) {
_baseOptions = [[TFLBaseOptions alloc] init];
_embeddingOptions = [[TFLEmbeddingOptions alloc] init];
_searchOptions = [[TFLSearchOptions alloc] init];
}
return self;
}

- (instancetype)initWithModelPath:(NSString *)modelPath {
self = [self init];
if (self) {
_baseOptions.modelFile.filePath = modelPath;
}
return self;
}

- (TextSearcherOptionsCpp)cppOptions {
TextSearcherOptionsCpp cppOptions = {};
[self.baseOptions copyToCppOptions:cppOptions.mutable_base_options()];
[self.embeddingOptions copyToCppOptions:cppOptions.mutable_embedding_options()];
[self.searchOptions copyToCppOptions:cppOptions.mutable_search_options()];

return cppOptions;
}

@end

@implementation TFLTextSearcher

- (nullable instancetype)initWithCppTextSearcherOptions:(TextSearcherOptionsCpp)cppOptions {
self = [super init];
if (self) {
StatusOr<std::unique_ptr<TextSearcherCpp>> cppTextSearcher =
TextSearcherCpp::CreateFromOptions(cppOptions);
if (cppTextSearcher.ok()) {
_cppTextSearcher = std::move(cppTextSearcher.value());
} else {
return nil;
}
}
return self;
}

+ (nullable instancetype)textSearcherWithOptions:(TFLTextSearcherOptions *)options
error:(NSError **)error {
if (!options) {
[TFLCommonUtils createCustomError:error
withCode:TFLSupportErrorCodeInvalidArgumentError
description:@"TFLTextSearcherOptions argument cannot be nil."];
return nil;
}

TextSearcherOptionsCpp cppOptions = [options cppOptions];

return [[TFLTextSearcher alloc] initWithCppTextSearcherOptions:cppOptions];
}

- (nullable TFLSearchResult *)searchWithText:(NSString *)text error:(NSError **)error {
if (!text) {
[TFLCommonUtils createCustomError:error
withCode:TFLSupportErrorCodeInvalidArgumentError
description:@"GMLImage argument cannot be nil."];
return nil;
}

StatusOr<SearchResultCpp> cppSearchResultStatus = _cppTextSearcher->Search(
std::string([text UTF8String], [text lengthOfBytesUsingEncoding:NSUTF8StringEncoding]));

return [TFLSearchResult searchResultWithCppResult:cppSearchResultStatus error:error];
}

@end
31 changes: 31 additions & 0 deletions tensorflow_lite_support/ios/test/task/text/text_searcher/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
load("@org_tensorflow//tensorflow/lite/ios:ios.bzl", "TFL_DEFAULT_TAGS", "TFL_DISABLED_SANITIZER_TAGS", "TFL_MINIMUM_OS_VERSION")
load("@build_bazel_rules_apple//apple:ios.bzl", "ios_unit_test")
load("@org_tensorflow//tensorflow/lite:special_rules.bzl", "tflite_ios_lab_runner")

package(
default_visibility = ["//visibility:private"],
licenses = ["notice"], # Apache 2.0
)

objc_library(
name = "TFLTextSearcherObjcTestLibrary",
testonly = 1,
srcs = ["TFLTextSearcherTests.m"],
data = [
"//tensorflow_lite_support/cc/test/testdata/task/text:test_searchers",
],
tags = TFL_DEFAULT_TAGS,
deps = [
"//tensorflow_lite_support/ios/task/text:TFLTextSearcher",
],
)

ios_unit_test(
name = "TFLTextSearcherObjcTest",
minimum_os_version = TFL_MINIMUM_OS_VERSION,
runner = tflite_ios_lab_runner("IOS_LATEST"),
tags = TFL_DEFAULT_TAGS + TFL_DISABLED_SANITIZER_TAGS,
deps = [
":TFLTextSearcherObjcTestLibrary",
],
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#import <XCTest/XCTest.h>

#import "tensorflow_lite_support/ios/task/text/sources/TFLTextSearcher.h"

NS_ASSUME_NONNULL_BEGIN

#define VerifySearchResultCount(searchResult, expectedNearestNeighborsCount) \
XCTAssertEqual(searchResult.nearestNeighbors.count, expectedNearestNeighborsCount);

#define VerifyNearestNeighbor(nearestNeighbor, expectedMetadata, expectedDistance) \
XCTAssertEqualObjects(nearestNeighbor.metadata, expectedMetadata); \
XCTAssertEqualWithAccuracy(nearestNeighbor.distance, expectedDistance, 1e-6);

@interface TFLTextSearcherTests : XCTestCase
@property(nonatomic, nullable) NSString *modelPath;
@end

@implementation TFLTextSearcherTests

- (void)setUp {
[super setUp];
self.modelPath =
[[NSBundle bundleForClass:self.class] pathForResource:@"regex_searcher"
ofType:@"tflite"];
XCTAssertNotNil(self.modelPath);
}

- (TFLTextSearcher *)testSuccessfulCreationOfTextSearcherWithSearchContent:(NSString *)modelPath {
TFLTextSearcherOptions *textSearcherOptions =
[[TFLTextSearcherOptions alloc] initWithModelPath:self.modelPath];

TFLTextSearcher *textSearcher = [TFLTextSearcher textSearcherWithOptions:textSearcherOptions
error:nil];
XCTAssertNotNil(textSearcher);

return textSearcher;
}

- (void)verifySearchResultForInferenceWithSearchContent:(TFLSearchResult *)searchResult {
VerifySearchResultCount(searchResult,
5 // expectedNearestNeighborsCount
);

VerifyNearestNeighbor(searchResult.nearestNeighbors[0],
@"burger", // expectedMetadata
198.456329 // expectedDistance
);
VerifyNearestNeighbor(searchResult.nearestNeighbors[1],
@"car", // expectedMetadata
226.022186 // expectedDistance
);
VerifyNearestNeighbor(searchResult.nearestNeighbors[2],
@"bird", // expectedMetadata
227.297668 // expectedDistance
);
VerifyNearestNeighbor(searchResult.nearestNeighbors[3],
@"dog", // expectedMetadata
229.133789 // expectedDistance
);
VerifyNearestNeighbor(searchResult.nearestNeighbors[4],
@"cat", // expectedMetadata
229.718948 // expectedDistance
);
}

- (void)testSuccessfullInferenceWithSearchContentOnText {
TFLTextSearcher *textSearcher =
[self testSuccessfulCreationOfTextSearcherWithSearchContent:self.modelPath];
// GMLImage *gmlImage =
// [GMLImage imageFromBundleWithClass:self.class fileName:@"burger" ofType:@"jpg"];
// XCTAssertNotNil(gmlImage);

TFLSearchResult *searchResult = [textSearcher searchWithText:@"The weather was excellent." error:nil];
[self verifySearchResultForInferenceWithSearchContent:searchResult];
}

@end

NS_ASSUME_NONNULL_END
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
==============================================================================*/
#import <XCTest/XCTest.h>

#import "tensorflow_lite_support/ios/task/vision/sources/TFLImageSearcher.h"
#import "tensorflow_lite_support/ios/task/text/sources/TFLImageSearcher.h"
#import "tensorflow_lite_support/ios/task/vision/utils/sources/GMLImage+Utils.h"

NS_ASSUME_NONNULL_BEGIN
Expand Down