Skip to content

Commit f40a063

Browse files
bixia1tensorflower-gardener
authored andcommitted
[TF:TRT] Enhance InstantiateBuildAndRun to support the case where the input
type and output type are not the same. This is to prepare for a change to enhance the TF-TRT bridge to support the Cast operations that can be represented via IIdentityLayer. PiperOrigin-RevId: 312077452 Change-Id: Iab6bfb54d6a346eef158785f61a1311559cee855
1 parent ea113ef commit f40a063

File tree

1 file changed

+29
-8
lines changed

1 file changed

+29
-8
lines changed

tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc

Lines changed: 29 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1712,7 +1712,7 @@ INSTANTIATE_TEST_CASE_P(
17121712

17131713
// Builds and runs the converted network. Checks output tensor shape. Tests
17141714
// output values using a matcher.
1715-
template <DataType dtype>
1715+
template <DataType input_dtype, DataType output_dtype>
17161716
void BuildAndRunConvertedNetwork(const string& name, OpConverterTest* test,
17171717
const TestParamBase& p,
17181718
const std::vector<float>& input_vec,
@@ -1731,20 +1731,22 @@ void BuildAndRunConvertedNetwork(const string& name, OpConverterTest* test,
17311731
// runtime errors.
17321732
return;
17331733
}
1734-
typedef typename EnumToDataType<dtype>::Type T;
1734+
typedef typename EnumToDataType<input_dtype>::Type Tin;
17351735
TensorShape shape;
17361736
TF_EXPECT_OK(TensorShapeUtils::MakeShape(p.input_dims, &shape));
17371737
const DataVec input_data{
1738-
{"input", test->AsTensor<T>(CastTestVector<float, T>(input_vec), shape)}};
1739-
DataVec output_data{{name, test->ConstructTensor<T>(6)}};
1738+
{"input",
1739+
test->AsTensor<Tin>(CastTestVector<float, Tin>(input_vec), shape)}};
1740+
typedef typename EnumToDataType<output_dtype>::Type Tout;
1741+
DataVec output_data{{name, test->ConstructTensor<Tout>(6)}};
17401742
test->BuildAndRun(input_data, &output_data);
17411743
// Check the shape of the actual output tensor
17421744
TF_EXPECT_OK(TensorShapeUtils::MakeShape(p.expected_output_dims, &shape));
17431745
EXPECT_TRUE(output_data[0].tensor.shape() == shape)
17441746
<< "Expected shape: " << shape.DebugString() << ", actual shape"
17451747
<< output_data[0].tensor.shape().DebugString();
17461748
// Cast the output to float and compare to expected output
1747-
auto out_span = GetSpanForData<T>(output_data[0]);
1749+
auto out_span = GetSpanForData<Tout>(output_data[0]);
17481750
std::vector<float> casted_output(out_span.begin(), out_span.end());
17491751
EXPECT_THAT(casted_output, matcher);
17501752
}
@@ -1754,16 +1756,35 @@ void InstantiateBuildAndRun(DataType tf_dtype, const string& name,
17541756
const std::vector<float>& input_vec,
17551757
const Matcher<std::vector<float>>& matcher) {
17561758
if (tf_dtype == DT_FLOAT) {
1757-
BuildAndRunConvertedNetwork<DT_FLOAT>(name, test, p, input_vec, matcher);
1759+
BuildAndRunConvertedNetwork<DT_FLOAT, DT_FLOAT>(name, test, p, input_vec,
1760+
matcher);
17581761
} else if (tf_dtype == DT_HALF) {
1759-
BuildAndRunConvertedNetwork<DT_HALF>(name, test, p, input_vec, matcher);
1762+
BuildAndRunConvertedNetwork<DT_HALF, DT_HALF>(name, test, p, input_vec,
1763+
matcher);
17601764
} else if (tf_dtype == DT_INT32) {
1761-
BuildAndRunConvertedNetwork<DT_INT32>(name, test, p, input_vec, matcher);
1765+
BuildAndRunConvertedNetwork<DT_INT32, DT_INT32>(name, test, p, input_vec,
1766+
matcher);
17621767
} else {
17631768
FAIL() << "Test not supported for " << tf_dtype;
17641769
}
17651770
}
17661771

1772+
void InstantiateBuildAndRun(DataType input_tf_dtype, DataType output_tf_dtype,
1773+
const string& name, OpConverterTest* test,
1774+
const TestParamBase& p,
1775+
const std::vector<float>& input_vec,
1776+
const Matcher<std::vector<float>>& matcher) {
1777+
if (input_tf_dtype == output_tf_dtype) {
1778+
InstantiateBuildAndRun(input_tf_dtype, name, test, p, input_vec, matcher);
1779+
} else if (input_tf_dtype == DT_HALF && output_tf_dtype) {
1780+
BuildAndRunConvertedNetwork<DT_HALF, DT_FLOAT>(name, test, p, input_vec,
1781+
matcher);
1782+
} else {
1783+
FAIL() << "Test not supported for input " << input_tf_dtype << " output "
1784+
<< output_tf_dtype;
1785+
}
1786+
}
1787+
17671788
template <typename T>
17681789
void CopyTensorElements(const Tensor& tensor, protobuf::RepeatedField<T>* out) {
17691790
out->Clear();

0 commit comments

Comments
 (0)