@@ -1712,7 +1712,7 @@ INSTANTIATE_TEST_CASE_P(
17121712
17131713// Builds and runs the converted network. Checks output tensor shape. Tests
17141714// output values using a matcher.
1715- template <DataType dtype >
1715+ template <DataType input_dtype, DataType output_dtype >
17161716void BuildAndRunConvertedNetwork (const string& name, OpConverterTest* test,
17171717 const TestParamBase& p,
17181718 const std::vector<float >& input_vec,
@@ -1731,20 +1731,22 @@ void BuildAndRunConvertedNetwork(const string& name, OpConverterTest* test,
17311731 // runtime errors.
17321732 return ;
17331733 }
1734- typedef typename EnumToDataType<dtype >::Type T ;
1734+ typedef typename EnumToDataType<input_dtype >::Type Tin ;
17351735 TensorShape shape;
17361736 TF_EXPECT_OK (TensorShapeUtils::MakeShape (p.input_dims , &shape));
17371737 const DataVec input_data{
1738- {" input" , test->AsTensor <T>(CastTestVector<float , T>(input_vec), shape)}};
1739- DataVec output_data{{name, test->ConstructTensor <T>(6 )}};
1738+ {" input" ,
1739+ test->AsTensor <Tin>(CastTestVector<float , Tin>(input_vec), shape)}};
1740+ typedef typename EnumToDataType<output_dtype>::Type Tout;
1741+ DataVec output_data{{name, test->ConstructTensor <Tout>(6 )}};
17401742 test->BuildAndRun (input_data, &output_data);
17411743 // Check the shape of the actual output tensor
17421744 TF_EXPECT_OK (TensorShapeUtils::MakeShape (p.expected_output_dims , &shape));
17431745 EXPECT_TRUE (output_data[0 ].tensor .shape () == shape)
17441746 << " Expected shape: " << shape.DebugString () << " , actual shape"
17451747 << output_data[0 ].tensor .shape ().DebugString ();
17461748 // Cast the output to float and compare to expected output
1747- auto out_span = GetSpanForData<T >(output_data[0 ]);
1749+ auto out_span = GetSpanForData<Tout >(output_data[0 ]);
17481750 std::vector<float > casted_output (out_span.begin (), out_span.end ());
17491751 EXPECT_THAT (casted_output, matcher);
17501752}
@@ -1754,16 +1756,35 @@ void InstantiateBuildAndRun(DataType tf_dtype, const string& name,
17541756 const std::vector<float >& input_vec,
17551757 const Matcher<std::vector<float >>& matcher) {
17561758 if (tf_dtype == DT_FLOAT) {
1757- BuildAndRunConvertedNetwork<DT_FLOAT>(name, test, p, input_vec, matcher);
1759+ BuildAndRunConvertedNetwork<DT_FLOAT, DT_FLOAT>(name, test, p, input_vec,
1760+ matcher);
17581761 } else if (tf_dtype == DT_HALF) {
1759- BuildAndRunConvertedNetwork<DT_HALF>(name, test, p, input_vec, matcher);
1762+ BuildAndRunConvertedNetwork<DT_HALF, DT_HALF>(name, test, p, input_vec,
1763+ matcher);
17601764 } else if (tf_dtype == DT_INT32) {
1761- BuildAndRunConvertedNetwork<DT_INT32>(name, test, p, input_vec, matcher);
1765+ BuildAndRunConvertedNetwork<DT_INT32, DT_INT32>(name, test, p, input_vec,
1766+ matcher);
17621767 } else {
17631768 FAIL () << " Test not supported for " << tf_dtype;
17641769 }
17651770}
17661771
1772+ void InstantiateBuildAndRun (DataType input_tf_dtype, DataType output_tf_dtype,
1773+ const string& name, OpConverterTest* test,
1774+ const TestParamBase& p,
1775+ const std::vector<float >& input_vec,
1776+ const Matcher<std::vector<float >>& matcher) {
1777+ if (input_tf_dtype == output_tf_dtype) {
1778+ InstantiateBuildAndRun (input_tf_dtype, name, test, p, input_vec, matcher);
1779+ } else if (input_tf_dtype == DT_HALF && output_tf_dtype) {
1780+ BuildAndRunConvertedNetwork<DT_HALF, DT_FLOAT>(name, test, p, input_vec,
1781+ matcher);
1782+ } else {
1783+ FAIL () << " Test not supported for input " << input_tf_dtype << " output "
1784+ << output_tf_dtype;
1785+ }
1786+ }
1787+
17671788template <typename T>
17681789void CopyTensorElements (const Tensor& tensor, protobuf::RepeatedField<T>* out) {
17691790 out->Clear ();
0 commit comments