onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc (119 additions, 0 deletions)
@@ -85,6 +85,108 @@ struct ShutdownProtobuf {

namespace onnxruntime {

// Helper function to check if a data type is supported by NvTensorRTRTX EP
static bool IsSupportedDataType(ONNXTensorElementDataType data_type) {
switch (data_type) {
case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: // kFLOAT - 32-bit floating point
case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16: // kHALF - IEEE 16-bit floating-point
case ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16: // kBF16 - Brain float 16
case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL: // kBOOL - 8-bit boolean
case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT4: // kINT4 - 4-bit signed integer
case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8: // kINT8 - 8-bit signed integer
case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: // kUINT8 - 8-bit unsigned integer
case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: // kINT32 - 32-bit signed integer
case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: // kINT64 - 64-bit signed integer
case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FN: // kFP8 - 8-bit floating point
return true;
default:
return false;
}
}

// Helper function to get data type name as string
static std::string GetDataTypeName(ONNXTensorElementDataType data_type) {
switch (data_type) {
case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT:
return "FLOAT";
case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16:
return "FLOAT16";
case ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16:
return "BFLOAT16";
case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL:
return "BOOL";
case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT4:
return "INT4";
case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8:
return "INT8";
case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8:
return "UINT8";
case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32:
return "INT32";
case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64:
return "INT64";
case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FN:
return "FLOAT8E4M3FN";
case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE:
return "DOUBLE";
case ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING:
return "STRING";
case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16:
return "UINT16";
case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32:
return "UINT32";
case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64:
return "UINT64";
case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16:
return "INT16";
case ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX64:
return "COMPLEX64";
case ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128:
return "COMPLEX128";
default:
return "UNKNOWN(" + std::to_string(static_cast<int>(data_type)) + ")";
}
}

// Helper function to check if a node has supported data types
static bool CheckNodeDataTypes(const Node* node) {
// Check input data types
for (const auto* input_def : node->InputDefs()) {
if (input_def->Exists()) {
const auto* type_proto = input_def->TypeAsProto();
if (type_proto && type_proto->has_tensor_type()) {
auto data_type = static_cast<ONNXTensorElementDataType>(type_proto->tensor_type().elem_type());
if (!IsSupportedDataType(data_type)) {
LOGS_DEFAULT(WARNING) << "[NvTensorRTRTX EP] Node '" << node->Name()
<< "' (OpType: " << node->OpType()
<< ") has unsupported input data type: " << GetDataTypeName(data_type)
<< " for input '" << input_def->Name() << "'";
return false;
}
}
}
}

// Check output data types
for (const auto* output_def : node->OutputDefs()) {
if (output_def->Exists()) {
const auto* type_proto = output_def->TypeAsProto();
if (type_proto && type_proto->has_tensor_type()) {
auto data_type = static_cast<ONNXTensorElementDataType>(type_proto->tensor_type().elem_type());
if (!IsSupportedDataType(data_type)) {
LOGS_DEFAULT(WARNING) << "[NvTensorRTRTX EP] Node '" << node->Name()
<< "' (OpType: " << node->OpType()
<< ") has unsupported output data type: " << GetDataTypeName(data_type)
<< " for output '" << output_def->Name() << "'";
return false;
}
}
}
}

return true;
}

void* OutputAllocator::reallocateOutputAsync(char const* /*tensorName*/, void* /*currentMemory*/, uint64_t size,
uint64_t /*alignment*/, cudaStream_t /*stream*/) noexcept {
// Some memory allocators return nullptr when allocating zero bytes, but TensorRT requires a non-null ptr
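Taken together, the three helpers added above gate each node on the element types of its defined inputs and outputs, and the WARNING makes the rejection reason visible in the log. For a hypothetical node named "MatMul_3" (OpType MatMul) whose input "X" is DOUBLE, the stream in CheckNodeDataTypes renders as:

[NvTensorRTRTX EP] Node 'MatMul_3' (OpType: MatMul) has unsupported input data type: DOUBLE for input 'X'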
@@ -478,10 +580,12 @@ Status BindContextInput(Ort::KernelContext& ctx,
CASE_GET_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16, uint16_t)
CASE_GET_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16, uint16_t)
CASE_GET_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL, bool)
CASE_GET_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT4, uint8_t)
CASE_GET_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8, int8_t)
CASE_GET_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8, uint8_t)
CASE_GET_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32, int32_t)
CASE_GET_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, int64_t)
CASE_GET_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FN, uint8_t)
default: {
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
"NvTensorRTRTX EP input onnx tensor data type: " + std::to_string(tensor_type) + " not supported.");
@@ -562,10 +666,12 @@ Status BindContextOutput(Ort::KernelContext& ctx,
CASE_GET_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16, uint16_t)
CASE_GET_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16, uint16_t)
CASE_GET_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL, bool)
CASE_GET_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT4, uint8_t)
CASE_GET_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8, int8_t)
CASE_GET_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8, uint8_t)
CASE_GET_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32, int32_t)
CASE_GET_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, int64_t)
CASE_GET_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FN, uint8_t)
default: {
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
"NvTensorRTRTX EP output tensor data type: " + std::to_string(output_type) + " not supported.");
@@ -624,10 +730,12 @@ Status BindKernelOutput(Ort::KernelContext& ctx,
CASE_COPY_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16, uint16_t)
CASE_COPY_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16, uint16_t)
CASE_COPY_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL, bool)
CASE_COPY_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT4, uint8_t)
CASE_COPY_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8, int8_t)
CASE_COPY_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8, uint8_t)
CASE_COPY_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32, int32_t)
CASE_COPY_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, int64_t)
CASE_COPY_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FN, uint8_t)
default: {
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
"NvTensorRTRTX EP output tensor data type: " + std::to_string(output_type) + " not supported.");
@@ -1878,6 +1986,7 @@ NvExecutionProvider::GetCapability(const GraphViewer& graph,
/* Iterate all the nodes and exclude the node if:
* 1. It's a control flow op and its subgraph(s) is not fully TRT eligible.
* 2. It's a DDS (data-dependent shape) op.
* 3. It has unsupported data types.
*/
for (const auto& index : nodes_vector) {
const auto& node = graph.GetNode(node_index[index]);
Expand Down Expand Up @@ -1917,6 +2026,16 @@ NvExecutionProvider::GetCapability(const GraphViewer& graph,
supported_node = false;
}

// Check data types and print warnings for unsupported types
if (supported_node) {
if (!CheckNodeDataTypes(node)) {
supported_node = false;
LOGS_DEFAULT(INFO) << "[NvTensorRTRTX EP] Node '" << node->Name()
<< "' (OpType: " << node->OpType()
<< ") excluded due to unsupported data types";
}
}

if (supported_node) {
if (new_subgraph) {
parser_nodes_vector.emplace_back();
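With this check wired into GetCapability, an unsupported node produces two log lines: the WARNING from CheckNodeDataTypes naming the offending tensor and type, followed by this INFO summary (again using the hypothetical MatMul_3):

[NvTensorRTRTX EP] Node 'MatMul_3' (OpType: MatMul) excluded due to unsupported data types

The node is then left for another registered execution provider to claim instead of failing the TensorRT engine build.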