diff --git a/include/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_options.h b/include/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_options.h index 11cc6f131dab3..dc27204017caa 100644 --- a/include/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_options.h +++ b/include/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_options.h @@ -32,9 +32,8 @@ constexpr const char* kProfilesMinShapes = "nv_profile_min_shapes"; constexpr const char* kProfilesMaxShapes = "nv_profile_max_shapes"; constexpr const char* kProfilesOptShapes = "nv_profile_opt_shapes"; constexpr const char* kCudaGraphEnable = "nv_cuda_graph_enable"; -constexpr const char* kONNXBytestream = "nv_onnx_bytestream"; -constexpr const char* kONNXBytestreamSize = "nv_onnx_bytestream_size"; constexpr const char* kMultiProfileEnable = "nv_multi_profile_enable"; +constexpr const char* kUseExternalDataInitializer = "nv_use_external_data_initializer"; } // namespace provider_option_names namespace run_option_names { diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc index cc9d9f3da1d81..4619faddba150 100644 --- a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc @@ -11,6 +11,7 @@ #include "core/common/common.h" #include "core/common/narrow.h" #include "core/common/safeint.h" +#include "core/framework/ort_value.h" #include "nv_execution_provider.h" #include "nv_execution_provider_utils.h" #include "nv_execution_provider_custom_ops.h" @@ -487,7 +488,7 @@ Status BindContextInput(Ort::KernelContext& ctx, if (!trt_context->setTensorAddress(input_name, &shape_tensor_values[input_name][0])) { std::string error_input_name = input_name; std::string error_msg = - "Nv EP failed to call nvinfer1::IExecutionContext::setTensorAddress() for shape input '" + + "NvTensorRTRTX EP failed to call nvinfer1::IExecutionContext::setTensorAddress() for shape input '" + error_input_name + "'"; ORT_THROW_IF_ERROR(ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, error_msg)); } @@ -510,7 +511,7 @@ Status BindContextInput(Ort::KernelContext& ctx, if (!trt_context->setTensorAddress(input_name, &shape_tensor_values_int64[input_name][0])) { std::string error_input_name = input_name; std::string error_msg = - "Nv EP failed to call nvinfer1::IExecutionContext::setTensorAddress() for shape input '" + + "NvTensorRTRTX EP failed to call nvinfer1::IExecutionContext::setTensorAddress() for shape input '" + error_input_name + "'"; ORT_THROW_IF_ERROR(ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, error_msg)); } @@ -532,7 +533,7 @@ Status BindContextInput(Ort::KernelContext& ctx, if (!trt_context->setInputShape(input_name, dims)) { std::string error_input_name = input_name; ORT_THROW_IF_ERROR(ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, - "Nv EP failed to call nvinfer1::IExecutionContext::setInputShape() for input '" + error_input_name + "'")); + "NvTensorRTRTX EP failed to call nvinfer1::IExecutionContext::setInputShape() for input '" + error_input_name + "'")); } // Bind "execution tensor" input buffer @@ -553,7 +554,7 @@ Status BindContextInput(Ort::KernelContext& ctx, CASE_GET_CAST_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE, double, float) default: { return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, - "Nv EP input onnx tensor data type: " + std::to_string(tensor_type) + " not supported."); + "NvTensorRTRTX EP input onnx tensor data type: " + std::to_string(tensor_type) + " not 
supported."); } } trt_context->setTensorAddress(input_name, data); @@ -644,7 +645,7 @@ Status BindContextOutput(Ort::KernelContext& ctx, CASE_GET_CAST_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE, double, float) default: { return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, - "Nv EP output tensor data type: " + std::to_string(output_type) + " not supported."); + "NvTensorRTRTX EP output tensor data type: " + std::to_string(output_type) + " not supported."); } } trt_context->setTensorAddress(output_name, buffers[output_name]); @@ -707,7 +708,7 @@ Status BindKernelOutput(Ort::KernelContext& ctx, CASE_CAST_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE, float, double) default: { return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, - "Nv EP output tensor data type: " + std::to_string(output_type) + " not supported."); + "NvTensorRTRTX EP output tensor data type: " + std::to_string(output_type) + " not supported."); } } return Status::OK(); @@ -836,7 +837,12 @@ NvExecutionProvider::NvExecutionProvider(const NvExecutionProviderInfo& info) cudaDeviceProp prop; CUDA_CALL_THROW(cudaGetDeviceProperties(&prop, device_id_)); - compute_capability_ = GetComputeCapacity(prop); + auto cc = prop.major * 10 + prop.minor; + if (!(cc == 86 || cc == 89 || cc >= 120)) { + ORT_THROW_IF_ERROR(ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, + "[NvTensorRTRTX EP] The execution provider only supports RTX devices with compute capabilities 86, 89, 120 and above")); + } + compute_capability_ = GetComputeCapability(prop); if (info.has_user_compute_stream) { external_stream_ = true; stream_ = static_cast(info.user_compute_stream); @@ -866,6 +872,15 @@ NvExecutionProvider::NvExecutionProvider(const NvExecutionProviderInfo& info) "When providing either 'trt_onnx_bytestream_size' or " "'trt_onnx_bytestream' both have to be provided")); } + use_external_data_initializer_ = info.use_external_data_initializer; + onnx_external_data_bytestream_ = info.external_data_bytestream; + onnx_external_data_bytestream_size_ = info.external_data_bytestream_size; + if ((onnx_external_data_bytestream_ != nullptr && onnx_external_data_bytestream_size_ == 0) || + (onnx_external_data_bytestream_ == nullptr && onnx_external_data_bytestream_size_ != 0)) { + ORT_THROW_IF_ERROR(ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, + "When providing either 'onnx_external_data_bytestream_size' or " + "'onnx_external_data_bytestream' both have to be provided")); + } detailed_build_log_ = info.detailed_build_log; dump_ep_context_model_ = info.dump_ep_context_model; ep_context_file_path_ = info.ep_context_file_path; @@ -979,13 +994,13 @@ NvExecutionProvider::NvExecutionProvider(const NvExecutionProviderInfo& info) LIBTYPE handle = OPENLIB(engine_decryption_lib_path_.c_str()); if (handle == nullptr) { ORT_THROW_IF_ERROR(ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, - "Nv EP could not open shared library from " + engine_decryption_lib_path_)); + "NvTensorRTRTX EP could not open shared library from " + engine_decryption_lib_path_)); } engine_decryption_ = (int (*)(const char*, char*, size_t*))LIBFUNC(handle, "decrypt"); engine_encryption_ = (int (*)(const char*, char*, size_t))LIBFUNC(handle, "encrypt"); if (engine_decryption_ == nullptr) { ORT_THROW_IF_ERROR(ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, - "Nv EP could not find decryption function in shared library from " + engine_decryption_lib_path_)); + "NvTensorRTRTX EP could not find decryption function in shared library from " + engine_decryption_lib_path_)); } } @@ -1029,6 +1044,8 @@ NvExecutionProvider::NvExecutionProvider(const 
NvExecutionProviderInfo& info) << ", nv_ep_context_embed_mode: " << ep_context_embed_mode_ << ", nv_cache_prefix: " << cache_prefix_ << ", nv_onnx_model_bytestream_size_: " << onnx_model_bytestream_size_ + << ", nv_onnx_external_bytestream_size_: " << onnx_external_data_bytestream_size_ + << ", nv_use_external_data_initializer_: " << use_external_data_initializer_ << ", nv_op_types_to_exclude: " << op_types_to_exclude_; } @@ -1140,6 +1157,9 @@ nvinfer1::IBuilder* NvExecutionProvider::GetBuilder(TensorrtLogger& trt_logger) { auto lock = GetApiLock(); builder_ = std::unique_ptr(nvinfer1::createInferBuilder(trt_logger)); + unsigned int num_threads = std::thread::hardware_concurrency(); + builder_->setMaxThreads(num_threads / 2); + LOGS_DEFAULT(INFO) << "[NvTensorRTRTX EP] Set threads that the builder can use to:" << builder_->getMaxThreads(); } } return builder_.get(); @@ -1450,8 +1470,11 @@ SubGraphCollection_t NvExecutionProvider::GetSupportedList(SubGraphCollection_t SetAllGraphInputs(graph_build); } - ORT_ENFORCE(graph_build.Resolve().IsOK()); - + auto status = graph_build.Resolve(); + if (!status.IsOK()) { + LOGS_DEFAULT(ERROR) << status.ErrorMessage(); + ORT_THROW_IF_ERROR(ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "ONNX graph resolve failed: " + status.ErrorMessage())); + } // Add parent graph output to the subgraph int i = 0; std::vector subgraph_outputs; @@ -1502,7 +1525,37 @@ SubGraphCollection_t NvExecutionProvider::GetSupportedList(SubGraphCollection_t // When creating model proto from graph viewer, let ORT use priority-based topological sort based on node index. // The reason is, in some cases, for example ResNet50, using default topological sort will end up with generating // the model proto that has different node ordering compared to original onnx model. 
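For reference, a minimal usage sketch (not part of this diff) of how an application could opt into the new nv_use_external_data_initializer provider option through the ORT C++ API. Only the option key comes from nv_provider_options.h above; the execution provider name string and the surrounding session setup are assumptions for illustration.

// Hypothetical usage sketch: enable the new provider option when creating a session.
#include <onnxruntime_cxx_api.h>
#include <string>
#include <unordered_map>

int main() {
  Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "nv_trt_rtx_example");
  Ort::SessionOptions session_options;

  // Option key taken from nv_provider_options.h in this change; the provider
  // name string passed to AppendExecutionProvider is an assumption here.
  std::unordered_map<std::string, std::string> provider_options{
      {"nv_use_external_data_initializer", "1"}};
  session_options.AppendExecutionProvider("NvTensorRtRtx", provider_options);

  Ort::Session session(env, ORT_TSTR("model_with_external_data.onnx"), session_options);
  return 0;
}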
- graph_viewer->ToProto(*model_proto->mutable_graph(), true, true, 1 /*priority-based topological sort*/); + + // save user provided external data in memory instead of writing to ModelProto + // needed for models > 2GB + std::vector userWeights; + if (use_external_data_initializer_) { + auto c_api = Ort::GetApi(); + const InitializedTensorSet& allInitializers = graph_viewer->GetAllInitializedTensors(); + userWeights.reserve(allInitializers.size()); + for (auto& entry : allInitializers) { + OrtValue initializer_value; + auto* tp = entry.second; + if (utils::HasRawData(*tp)) { + userWeights.emplace_back(TensorrtUserWeights(tp->name(), tp->raw_data().data(), tp->raw_data().size())); + } else if (graph_viewer->GetOrtValueInitializer(tp->name(), initializer_value)) { + // the initializer was marked as external data by the ORT graph at load time since it was provided in memory + size_t size = 0; + const void* ptr = nullptr; + c_api.GetTensorSizeInBytes(&initializer_value, &size); + c_api.GetTensorData(&initializer_value, &ptr); + userWeights.emplace_back(tp->name(), ptr, size); + } else if (utils::HasExternalDataInMemory(*tp)) { + // only copy and take ownership of the data if none of the above conditions are met + std::unique_ptr full_init; + ORT_THROW_IF_ERROR(utils::GetTensorProtoWithDataIfInMemory(*tp, full_init)); + userWeights.emplace_back(std::move(full_init->name()), std::move(full_init->raw_data())); + } + } + } + + graph_viewer->ToProto(*model_proto->mutable_graph(), true, true, 1 /*priority-based topological sort*/, !use_external_data_initializer_ /*include raw initializers*/); + model_proto->set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION); std::string string_buf; @@ -1521,11 +1574,25 @@ SubGraphCollection_t NvExecutionProvider::GetSupportedList(SubGraphCollection_t auto network_flags = 1U << static_cast(nvinfer1::NetworkDefinitionCreationFlag::kSTRONGLY_TYPED); auto trt_network = std::unique_ptr(trt_builder->createNetworkV2(network_flags)); + bool is_model_supported = false; + // limit the scope of trt_parser so that model gets unloaded from memory asap { auto trt_parser = tensorrt_ptr::unique_pointer(nvonnxparser::createParser(*trt_network, trt_logger)); - auto is_model_supported = trt_parser->supportsModelV2(string_buf.data(), string_buf.size(), model_path_); + if (use_external_data_initializer_) { +#if TRT_MAJOR_RTX > 1 || TRT_MINOR_RTX >= 1 + trt_parser->loadModelProto(string_buf.data(), string_buf.size(), model_path_); + for (auto const& userWeight : userWeights) { + trt_parser->loadInitializer(userWeight.Name(), userWeight.Data(), userWeight.Size()); + } + is_model_supported = trt_parser->parseModelProto(); +#else + ORT_THROW("'nv_use_external_data_initializer' is only supported on TensorRT RTX 1.1.x.x and above."); +#endif + } else { + is_model_supported = trt_parser->supportsModelV2(string_buf.data(), string_buf.size(), model_path_); + } // Note: Calling getNbSubgraphs or getSubgraphNodes before calling supportsModelV2 results in undefined behavior. auto num_subgraphs = trt_parser->getNbSubgraphs(); @@ -1708,21 +1775,33 @@ NvExecutionProvider::GetCapability(const GraphViewer& graph, #endif model_path_[sizeof(model_path_) - 1] = '\0'; - // If the model consists of only a single "EPContext" contrib op, it means TRT EP can fetch the precompiled engine info from the node and - // load the engine directly without having to go through the processes of graph proto reconstruction, calling TRT parser and engine compilation. - // So, simply return the ComputeCapability here. 
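The parse path above (loadModelProto, one loadInitializer call per weight, then parseModelProto) replaces supportsModelV2 whenever raw initializers are excluded from the serialized ModelProto. A condensed sketch of that flow as a standalone helper; it assumes a TensorRT RTX parser >= 1.1 and reuses the TensorrtUserWeights helper introduced in this change, with network and logger creation elided.

// Sketch of the external-initializer parse path used above (TensorRT RTX >= 1.1).
#include <memory>
#include <string>
#include <vector>
#include <NvInfer.h>
#include <NvOnnxParser.h>

bool ParseWithExternalInitializers(nvinfer1::INetworkDefinition& network,
                                   nvinfer1::ILogger& logger,
                                   const std::string& serialized_model,
                                   const char* model_path,
                                   const std::vector<TensorrtUserWeights>& weights) {
  auto parser = std::unique_ptr<nvonnxparser::IParser>(
      nvonnxparser::createParser(network, logger));
  // 1) Load the ModelProto that was serialized without raw initializer data.
  if (!parser->loadModelProto(serialized_model.data(), serialized_model.size(), model_path)) {
    return false;
  }
  // 2) Register every weight by name so the parser can resolve it from memory.
  for (const auto& w : weights) {
    parser->loadInitializer(w.Name(), w.Data(), w.Size());
  }
  // 3) Parse the previously loaded ModelProto with the supplied weights.
  return parser->parseModelProto();
}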
- if (graph.NumberOfNodes() == 1 && GraphHasCtxNode(graph)) { - SubGraph_t supported_node_vector = {{0}, true}; - std::unique_ptr sub_graph = GetSubGraph(supported_node_vector, graph, TRTGenerateId(graph, std::to_string(trt_version_), std::to_string(cuda_version_)), 0); - result.push_back(ComputeCapability::Create(std::move(sub_graph))); - return result; - } + const int number_of_ort_nodes = graph.NumberOfNodes(); + const std::vector& node_index = graph.GetNodesInTopologicalOrder(1 /*priority-based topological sort*/); // Generate unique kernel name for TRT graph HashValue model_hash = TRTGenerateId(graph, std::to_string(trt_version_), std::to_string(cuda_version_)); - // Get supported node list from TensorRT parser - const int number_of_ort_nodes = graph.NumberOfNodes(); + // If there are "EPContext" contrib op nodes, it means TRT EP can fetch the precompiled engine info from the node and + // load the engine directly without having to go through the processes of graph proto reconstruction, calling TRT + // parser and engine compilation. So, simply return subgraphs consisting of single EP context nodes here. + int subgraph_idx = 0; + for (size_t node_idx : node_index) { + const auto& node = graph.GetNode(node_idx); + const bool is_context_node = node && !node->OpType().empty() && node->OpType() == EPCONTEXT_OP; + if (is_context_node) { + SubGraph_t supported_node_vector(std::make_pair(std::vector{node_idx}, true)); + std::unique_ptr sub_graph = GetSubGraph(supported_node_vector, graph, model_hash, subgraph_idx++); + + result.push_back(ComputeCapability::Create(std::move(sub_graph))); + } + } + // return early if context nodes were found + if (!result.empty()) { + return result; + } + + // For regular ONNX nodes, get supported node list from TensorRT parser + std::vector nodes_vector(number_of_ort_nodes); std::iota(std::begin(nodes_vector), std::end(nodes_vector), 0); @@ -1741,7 +1820,6 @@ NvExecutionProvider::GetCapability(const GraphViewer& graph, auto exclude_ops_set = get_exclude_ops_set(op_types_to_exclude_); SubGraphCollection_t parser_nodes_vector, supported_nodes_vector; - const std::vector& node_index = graph.GetNodesInTopologicalOrder(1 /*priority-based topological sort*/); bool new_subgraph = true; /* Iterate all the nodes and exclude the node if: @@ -1932,14 +2010,16 @@ NvExecutionProvider::GetCapability(const GraphViewer& graph, */ common::Status NvExecutionProvider::RefitEngine(std::string onnx_model_filename, std::string& onnx_model_folder_path, - std::string& weight_stripped_engine_cath_path, bool path_check, const void* onnx_model_bytestream, size_t onnx_model_bytestream_size, + const void* onnx_external_data_bytestream, + size_t onnx_external_data_bytestream_size, nvinfer1::ICudaEngine* trt_engine, - bool serialize_refitted_engine, bool detailed_build_log) { bool refit_from_file = onnx_model_bytestream == nullptr && onnx_model_bytestream_size == 0; + bool refit_with_external_data = onnx_external_data_bytestream != nullptr && onnx_external_data_bytestream_size != 0; + bool refit_complete = false; std::filesystem::path onnx_model_path{onnx_model_folder_path}; if (refit_from_file) { if (!onnx_model_filename.empty()) { @@ -1976,34 +2056,145 @@ common::Status NvExecutionProvider::RefitEngine(std::string onnx_model_filename, auto refitter = std::unique_ptr(nvinfer1::createInferRefitter(*trt_engine, trt_logger)); auto parser_refitter = std::unique_ptr( nvonnxparser::createParserRefitter(*refitter, trt_logger)); - if (refit_from_file) { - LOGS_DEFAULT(VERBOSE) << "[NvTensorRTRTX 
EP] Refitting from file on disk: " << onnx_model_path.string(); - if (!parser_refitter->refitFromFile(onnx_model_path.string().c_str())) { + + // New refit APIs + if (refit_with_external_data) { +#if TRT_MAJOR_RTX > 1 || TRT_MINOR_RTX >= 1 + // A valid model bytestream must be passed. + if (refit_from_file) { return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, - "Nv EP's IParserRefitter could not refit deserialized weight-stripped engine with weights contained in: " + onnx_model_path.string()); + "NvTensorRTRTX EP's refit with external data must be called with a valid ONNX model bytestream"); } - } else { - LOGS_DEFAULT(VERBOSE) << "[NvTensorRTRTX EP] Refitting from byte array"; - if (!parser_refitter->refitFromBytes(onnx_model_bytestream, onnx_model_bytestream_size)) { + + if (!parser_refitter->loadModelProto(onnx_model_bytestream, onnx_model_bytestream_size, nullptr)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, + "NvTensorRTRTX EP's IParserRefitter could not load model from provided onnx_model_bytestream"); + } + + // Extract weight information from the Refitter. + int required_weights = refitter->getAllWeights(0, nullptr); + std::vector refit_names_prealocated(required_weights); + refitter->getAllWeights(required_weights, refit_names_prealocated.data()); + LOGS_DEFAULT(VERBOSE) << "[NvTensorRTRTX EP] Refitter requires " << required_weights << " weights"; + std::unordered_set refit_names(std::make_move_iterator(refit_names_prealocated.begin()), + std::make_move_iterator(refit_names_prealocated.end())); + + // Vectors to keep track of data pointers. + std::vector names; + names.reserve(required_weights); + std::vector bytes; + bytes.reserve(required_weights); + std::vector sizes; + sizes.reserve(required_weights); + + auto onnx_model = ModelProto::Create(); + TensorProtos* allInitializers_byte_stream; + + // Reconstruct onnx model view. + const auto onnx_model_view = std::string((const char*)onnx_model_bytestream, + onnx_model_bytestream_size); + if (!onnx_model->ParseFromString(onnx_model_view)) { return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, - "Nv EP's IParserRefitter could not refit deserialized weight-stripped engine with weights contained in the provided bytestraem"); + "The provided ONNX bytestream to refit could not be parsed."); + } + + // Extract graph and initializer information. + auto const& graph = onnx_model->mutable_graph(); + allInitializers_byte_stream = graph->mutable_initializer(); + LOGS_DEFAULT(VERBOSE) << "[NvTensorRTRTX EP] Initializers that were found " << allInitializers_byte_stream->size(); + + // Loop through all initializers + int missing_initializer_data = 0; + for (int initializer_idx = 0; initializer_idx < allInitializers_byte_stream->size(); ++initializer_idx) { + auto& proto = allInitializers_byte_stream->at(initializer_idx); + auto& proto_name = proto.name(); + if (refit_names.find(proto_name) != refit_names.end()) { + if (proto.has_data_location()) { + if (proto.data_location() == TensorProto_DataLocation_EXTERNAL) { + // Default values for reading into external_data blob. 
+ int64_t offset = 0; + size_t length = 0; + auto external_data = proto.mutable_external_data(); + const std::string kOffset = "offset", kLength = "length"; + for (int entry_idx = 0; entry_idx < external_data->size(); ++entry_idx) { + auto current_key = external_data->at(entry_idx).mutable_key(); + auto current_value = external_data->at(entry_idx).mutable_value(); + if (*current_key == kOffset && !current_value->empty()) { + offset = std::stoll(*current_value); + } else if (*current_key == kLength && !current_value->empty()) { + length = std::stoul(*current_value); + } + } + names.push_back(proto.name()); + bytes.push_back(static_cast(onnx_external_data_bytestream) + offset); + sizes.push_back(length); + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, + "[NvTensorRTRTX EP] Proto: " + proto_name + " expected to have an external data location, but the default data location was provided instead."); + } + } else if (proto.has_raw_data()) { + auto& raw_data = proto.raw_data(); + names.push_back(proto.name()); + bytes.push_back(raw_data.c_str()); + sizes.push_back(raw_data.size()); + } else { + LOGS_DEFAULT(WARNING) << "[NvTensorRTRTX EP] Proto: " + proto_name + " has neither raw nor external data."; + ++missing_initializer_data; + } + } else { + LOGS_DEFAULT(VERBOSE) << "[NvTensorRTRTX EP] Initializer with name: " << proto_name << " was not marked as refittable"; + } + } + if (missing_initializer_data) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, + "[NvTensorRTRTX EP] RefitEngine is missing " + std::to_string(missing_initializer_data) + " initializers."); + } + + // Load extracted initializers into the parser + if (!names.empty()) { + LOGS_DEFAULT(VERBOSE) << "[NvTensorRTRTX EP] Number of initializers submitted to refitter " << names.size(); + for (size_t i = 0; i < names.size(); i++) { + bool refloadInit = parser_refitter->loadInitializer(names[i].c_str(), bytes[i], sizes[i]); + if (!refloadInit) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, + "NvTensorRTRTX EP's IParserRefitter could not refit deserialized weight-stripped engine with weights contained in the provided bytestream"); + } + } + } + // Perform refit. + if (!parser_refitter->refitModelProto()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, + "NvTensorRTRTX EP's IParserRefitter refitModelProto() failed with the provided external data bytestream."); + } + refit_complete = true; +#else + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "Refit with external data is only supported on TensorRT RTX 1.1.x.x and above."); +#endif + } + + // If new refit flow was not completed, then fall back to refit_from_file. 
+ if (!refit_complete) { + if (refit_from_file) { + LOGS_DEFAULT(VERBOSE) << "[NvTensorRTRTX EP] Refitting from file on disk: " << onnx_model_path.string(); + if (!parser_refitter->refitFromFile(onnx_model_path.string().c_str())) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, + "NvTensorRTRTX EP's IParserRefitter could not refit deserialized weight-stripped engine with weights contained in: " + onnx_model_path.string()); + } + } else { + LOGS_DEFAULT(VERBOSE) << "[NvTensorRTRTX EP] Refitting from byte array"; + if (!parser_refitter->refitFromBytes(onnx_model_bytestream, onnx_model_bytestream_size)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, + "NvTensorRTRTX EP's IParserRefitter could not refit deserialized weight-stripped engine with weights contained in the provided bytestream"); + } } } if (refitter->refitCudaEngine()) { LOGS_DEFAULT(VERBOSE) << "[NvTensorRTRTX EP] Successfully refitted the weight-stripped engine."; } else { return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, - "Nv EP's IRefitter could not refit deserialized weight-stripped engine with weights contained in: " + onnx_model_path.string()); + "NvTensorRTRTX EP's IRefitter could not refit deserialized weight-stripped engine with weights contained in: " + onnx_model_path.string()); } - // serialize the refitted engine to disk - if (serialize_refitted_engine) { - std::string refitted_engine_cache = GetWeightRefittedEnginePath(weight_stripped_engine_cath_path); - nvinfer1::IHostMemory* serialized_engine = trt_engine->serialize(); - std::ofstream engine_file(refitted_engine_cache, std::ios::binary | std::ios::out); - engine_file.write(reinterpret_cast(serialized_engine->data()), serialized_engine->size()); - LOGS_DEFAULT(VERBOSE) << "[NvTensorRTRTX EP] Serialize the refitted engine to " << refitted_engine_cache; - } return Status::OK(); } @@ -2029,8 +2220,10 @@ common::Status NvExecutionProvider::Compile(const std::vector } Status status; - if (GraphHasCtxNode(graph_body_viewer)) { + size_t node_idx = 0; + if (GraphHasCtxNode(graph_body_viewer, node_idx)) { status = CreateNodeComputeInfoFromPrecompiledEngine(graph_body_viewer, + node_idx, fused_node, input_map, output_map, @@ -2135,6 +2328,16 @@ static bool IsIOBindingRequired(TRTState* const trt_state, const Ort::KernelCont return require_io_binding; } +const InlinedVector NvExecutionProvider::GetEpContextNodes() const { + InlinedVector ep_context_nodes; + if (ep_context_model_) { + for (auto* node : ep_context_model_->MainGraph().Nodes()) { + ep_context_nodes.push_back(node); + } + } + return ep_context_nodes; +} + Status NvExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphViewer& graph_body_viewer, const Node& fused_node, std::unordered_map& input_map, @@ -2144,11 +2347,38 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphViewer& gr auto model = graph_body_viewer.CreateModel(*GetLogger()); auto model_proto = model->ToProto(); + // exclude weights if external + std::vector userWeights; + if (use_external_data_initializer_) { + auto c_api = Ort::GetApi(); + const InitializedTensorSet& allInitializers = graph_body_viewer.GetAllInitializedTensors(); + userWeights.reserve(allInitializers.size()); + for (auto& entry : allInitializers) { + OrtValue initializer_value; + auto* tp = entry.second; + if (utils::HasRawData(*tp)) { + userWeights.emplace_back(TensorrtUserWeights(tp->name(), tp->raw_data().data(), tp->raw_data().size())); + } else if (graph_body_viewer.GetOrtValueInitializer(tp->name(), initializer_value)) { + // the initializer was marked 
as external data by the ORT graph at load time since it was provided in memory + size_t size = 0; + const void* ptr = nullptr; + c_api.GetTensorSizeInBytes(&initializer_value, &size); + c_api.GetTensorData(&initializer_value, &ptr); + userWeights.emplace_back(tp->name(), ptr, size); + } else if (utils::HasExternalDataInMemory(*tp)) { + // only copy and take ownership of the data if none of the above conditions are met + std::unique_ptr full_init; + ORT_THROW_IF_ERROR(utils::GetTensorProtoWithDataIfInMemory(*tp, full_init)); + userWeights.emplace_back(TensorrtUserWeights(std::move(full_init->name()), std::move(full_init->raw_data()))); + } + } + } + // ORT's default topological sort is using reversed DFS. // When creating model proto from graph viewer, let ORT use priority-based topological sort based on node index. // The reason is, in some cases, for example ResNet50, using default topological sort will end up with generating // the model proto that has different node ordering compared to original onnx model. - graph_body_viewer.ToProto(*model_proto->mutable_graph(), true, true, 1 /*priority-based topological sort*/); + graph_body_viewer.ToProto(*model_proto->mutable_graph(), true, true, 1 /*priority-based topological sort*/, !use_external_data_initializer_ /*include raw initializers*/); model_proto->set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION); std::string string_buf; model_proto->SerializeToString(string_buf); @@ -2165,7 +2395,21 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphViewer& gr auto trt_network = std::unique_ptr(trt_builder->createNetworkV2(network_flags)); auto trt_config = std::unique_ptr(trt_builder->createBuilderConfig()); auto trt_parser = tensorrt_ptr::unique_pointer(nvonnxparser::createParser(*trt_network, trt_logger)); - trt_parser->parse(string_buf.data(), string_buf.size(), model_path_); + + if (use_external_data_initializer_) { +#if TRT_MAJOR_RTX > 1 || TRT_MINOR_RTX >= 1 + trt_parser->loadModelProto(string_buf.data(), string_buf.size(), model_path_); + for (auto const& userWeight : userWeights) { + trt_parser->loadInitializer(userWeight.Name(), userWeight.Data(), userWeight.Size()); + } + trt_parser->parseModelProto(); +#else + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "'nv_use_external_data_initializer' is only supported on TensorRT RTX 1.1.x.x and above."); +#endif + } else { + trt_parser->parse(string_buf.data(), string_buf.size(), model_path_); + } + if (max_workspace_size_ > 0) { trt_config->setMemoryPoolLimit(nvinfer1::MemoryPoolType::kWORKSPACE, max_workspace_size_); } @@ -2329,7 +2573,6 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphViewer& gr ; } } - std::string trt_node_name_with_precision = fused_node.Name() + "_strong_typed"; // enable sparse weights if (sparsity_enable_) { @@ -2358,32 +2601,6 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphViewer& gr std::unique_ptr trt_engine; std::unique_ptr trt_context; - std::string cache_path = ""; - std::string cache_suffix = ""; - // Customize cache prefix if assigned - if (!cache_prefix_.empty()) { - // Generate cache suffix in case user would like to customize cache prefix - cache_suffix = "_" + GetCacheSuffix(fused_node.Name(), trt_node_name_with_precision); - cache_path = GetCachePath(cache_path_, cache_prefix_) + cache_suffix; - } else { - cache_path = GetCachePath(cache_path_, trt_node_name_with_precision); - } - - // Name the engine cache based on GPU compute capacity and reduce the chance of loading an incompatible 
cache - // Note: Engine cache generated on a GPU with large memory might not be loadable on a GPU with smaller memory, even if they share the same compute capacity - const std::string cache_path_prefix = cache_path; - std::string engine_cache_path = cache_path_prefix + ".engine"; - const std::string encrypted_engine_cache_path = engine_cache_path + ".encrypted"; - const std::string profile_cache_path = cache_path_prefix + ".profile"; - - // If weight-stripped engine is enabled and refitted engine cache is not present, - // TRT EP will use the engine cache with ".stripped.engine" appended to the end. - const std::filesystem::path engine_cache_fs_path = engine_cache_path; - if (weight_stripped_engine_enable_ && !std::filesystem::exists(engine_cache_fs_path)) { - engine_cache_path = cache_path_prefix + ".stripped.engine"; - weight_stripped_engine_refit_ = true; - } - // Generate file name for dumping ep context model if (dump_ep_context_model_ && ctx_model_path_.empty()) { ctx_model_path_ = GetCtxModelPath(ep_context_file_path_, model_path_); @@ -2398,49 +2615,63 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphViewer& gr std::unique_ptr serialized_engine{trt_builder->buildSerializedNetwork(*trt_network, *trt_config)}; if (serialized_engine == nullptr) { return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, - "Nv EP failed to create engine from network for fused node: " + fused_node.Name()); + "NvTensorRTRTX EP failed to create engine from network for fused node: " + fused_node.Name()); } trt_engine = std::unique_ptr(runtime_->deserializeCudaEngine(serialized_engine->data(), serialized_engine->size())); if (trt_engine == nullptr) { return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, - "Nv EP failed to deserialize engine for fused node: " + fused_node.Name()); + "NvTensorRTRTX EP failed to deserialize engine for fused node: " + fused_node.Name()); } if (detailed_build_log_) { auto engine_build_stop = std::chrono::steady_clock::now(); - LOGS_DEFAULT(INFO) << "TensorRT engine build for " << trt_node_name_with_precision << " took: " << std::chrono::duration_cast(engine_build_stop - engine_build_start).count() << "ms" << std::endl; + LOGS_DEFAULT(INFO) << "TensorRT engine build for " << fused_node.Name() << " took: " << std::chrono::duration_cast(engine_build_stop - engine_build_start).count() << "ms" << std::endl; } // dump EP context node model if (dump_ep_context_model_) { // "ep_cache_context" node attribute should be a relative path to context model directory - if (ep_cache_context_attr_.empty()) { - auto cache_file_name = std::filesystem::path(engine_cache_path).filename(); - ep_cache_context_attr_ = std::filesystem::path(engine_cache_relative_path_to_context_model_dir).append(cache_file_name.string()).string(); + + std::string cache_path = ""; + // Customize cache prefix if assigned + if (!cache_prefix_.empty()) { + // Generate cache suffix in case user would like to customize cache prefix + cache_path = GetCachePath(cache_path_, cache_prefix_) + fused_node.Name() + ".engine"; + ; + } else { + cache_path = GetCachePath(cache_path_, fused_node.Name()) + ".engine"; + ; + } + // NV TRT EP per default generates hardware compatible engines for any RTX device with compute capability > 80 + std::string compute_capability_hw_compat = "80+"; + if (!ep_context_model_) { + ep_context_model_ = Model::Create("nv_trt_rtx_ep_context_model", false, *GetLogger()); + } + + auto status = CreateCtxNode(graph_body_viewer, + ep_context_model_->MainGraph(), + cache_path, + 
reinterpret_cast(serialized_engine->data()), + serialized_engine->size(), + ep_context_embed_mode_, + compute_capability_hw_compat, + model_path_, + fused_node.Name(), + trt_version_); + if (status != Status::OK()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, status.ErrorMessage()); } - std::string compute_capability_hw_compat = compute_capability_ + "+"; - std::unique_ptr model_proto{CreateCtxModel(graph_body_viewer, - ep_cache_context_attr_, - reinterpret_cast(serialized_engine->data()), - serialized_engine->size(), - ep_context_embed_mode_, - compute_capability_hw_compat, - model_path_, - GetLogger())}; - DumpCtxModel(model_proto.get(), ctx_model_path_); } } if (weight_stripped_engine_refit_) { LOGS_DEFAULT(VERBOSE) << "[NvTensorRTRTX EP] Refit engine from main ONNX file after engine build"; - char* onnx = string_buf.data(); - size_t onnx_size = string_buf.size(); auto status = RefitEngine(model_path_, onnx_model_folder_path_, - engine_cache_path, false /* path check for security */, - onnx, - onnx_size, + onnx_model_bytestream_, + onnx_model_bytestream_size_, + onnx_external_data_bytestream_, + onnx_external_data_bytestream_size_, trt_engine.get(), - false /* serialize refitted engine to disk */, detailed_build_log_); if (status != Status::OK()) { return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, status.ErrorMessage()); @@ -2453,7 +2684,7 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphViewer& gr trt_context = std::unique_ptr(trt_engine->createExecutionContext(nvinfer1::ExecutionContextAllocationStrategy::kUSER_MANAGED)); if (!trt_context) { return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, - "Nv EP could not build execution context for fused node: " + fused_node.Name()); + "NvTensorRTRTX EP could not build execution context for fused node: " + fused_node.Name()); } bool is_dynamic_shape_context = false; @@ -2499,12 +2730,12 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphViewer& gr *p = {context->allocate_func, context->release_func, context->allocator_handle, context->node_name, builder_.get(), &parsers_[context->node_name], &engines_[context->node_name], &contexts_[context->node_name], &networks_[context->node_name], input_info_[context->node_name], output_info_[context->node_name], - input_shape_ranges_[context->node_name], &tensorrt_mu_, trt_node_name_with_precision, + input_shape_ranges_[context->node_name], &tensorrt_mu_, engine_cache_enable_, cache_path_, runtime_.get(), profiles_[context->node_name], engine_decryption_enable_, engine_decryption_, engine_encryption_, detailed_build_log_, sparsity_enable_, - auxiliary_streams_, cuda_graph_enable_, is_dynamic_shape_context, cache_prefix_, cache_suffix}; + auxiliary_streams_, cuda_graph_enable_, is_dynamic_shape_context, cache_prefix_}; *state = p.release(); return 0; }; @@ -2552,7 +2783,7 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphViewer& gr if (multi_profile_enable_ == true) { if (!trt_context->setOptimizationProfileAsync(nv_profile_index_, stream)) - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Nv EP select an optimization profile for the current context failed"); + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "NvTensorRTRTX EP select an optimization profile for the current context failed"); } // Check before using trt_engine @@ -2666,7 +2897,7 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphViewer& gr // Run TRT inference if (!trt_context->enqueueV3(stream)) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Nv EP execution context 
enqueue failed."); + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "NvTensorRTRTX EP execution context enqueue failed."); } /* @@ -2743,6 +2974,7 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphViewer& gr } Status NvExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngine(const GraphViewer& graph_body_viewer, + size_t node_idx, const Node& fused_node, std::unordered_map& input_map, std::unordered_map& output_map, @@ -2762,8 +2994,10 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngine(const Gra onnx_model_folder_path_, onnx_model_bytestream_, onnx_model_bytestream_size_, + onnx_external_data_bytestream_, + onnx_external_data_bytestream_size_, detailed_build_log_); - auto status = trt_cache_model_handler.GetEpContextFromGraph(graph_body_viewer); + auto status = trt_cache_model_handler.GetEpContextFromGraph(*graph_body_viewer.GetNode(node_idx)); if (status != Status::OK()) { return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, status.ErrorMessage()); } @@ -2775,7 +3009,7 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngine(const Gra trt_context = std::unique_ptr(trt_engine->createExecutionContext(nvinfer1::ExecutionContextAllocationStrategy::kUSER_MANAGED)); if (!trt_context) { return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, - "Nv EP could not build execution context for fused node: " + fused_node.Name()); + "NvTensorRTRTX EP could not build execution context for fused node: " + fused_node.Name()); } bool is_dynamic_shape_context = false; @@ -2980,7 +3214,7 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngine(const Gra // Run TRT inference if (!trt_context->enqueueV3(stream)) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Nv EP execution context enqueue failed."); + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "NvTensorRTRTX EP execution context enqueue failed."); } /* diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.h b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.h index 83b89a2e9d1fb..e3dd38eb837ff 100644 --- a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.h +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.h @@ -153,6 +153,41 @@ struct TensorParams { } }; +// Data structure to hold user weights when ModelProtos are serialized with external data +class TensorrtUserWeights { + public: + TensorrtUserWeights(const std::string& name, const std::string& data) : name_(name), + data_cpy_(data) { + }; + + TensorrtUserWeights(const std::string& name, const void* data, size_t size) : name_(name), data_(data), size_(size) { + }; + + const char* Name() const { + return name_.c_str(); + }; + + const void* Data() const { + if (!data_cpy_.empty()) { + return data_cpy_.data(); + } + return data_; + } + + int64_t Size() const { + if (!data_cpy_.empty()) { + return static_cast(data_cpy_.size()); + } + return static_cast(size_); + } + + private: + std::string name_{}; + std::string data_cpy_{}; + void const* data_; + size_t size_; +}; + // Information to construct kernel function state. 
struct TensorrtFuncState { AllocateFunc test_allocate_func = nullptr; @@ -168,7 +203,6 @@ struct TensorrtFuncState { std::vector> output_info; std::unordered_map>>> input_shape_ranges; std::mutex* tensorrt_mu_ptr = nullptr; - std::string trt_node_name_with_precision; bool engine_cache_enable = false; std::string engine_cache_path; nvinfer1::IRuntime* runtime = nullptr; @@ -183,6 +217,7 @@ struct TensorrtFuncState { bool is_dynamic_shape = false; std::string cache_prefix; std::string cache_suffix; + // runtime parameters std::vector> scratch_buffers; std::vector input_tensors; std::vector output_tensors; @@ -204,6 +239,7 @@ struct TensorrtShortFuncState { std::vector> output_info; std::mutex* tensorrt_mu_ptr = nullptr; bool is_dynamic_shape = false; + // runtime parameters std::vector> scratch_buffers; std::vector input_tensors; std::vector output_tensors; @@ -275,14 +311,16 @@ class NvExecutionProvider : public IExecutionProvider { static common::Status RefitEngine(std::string onnx_model_filename, std::string& onnx_model_folder_path, - std::string& weight_stripped_engine_cath_path, bool path_check, const void* onnx_model_bytestream, size_t onnx_model_bytestream_size, + const void* onnx_external_data_bytestream, + size_t onnx_external_data_bytestream_size, nvinfer1::ICudaEngine* trt_engine, - bool serialize_refitted_engine, bool detailed_build_log); + const InlinedVector GetEpContextNodes() const override; + private: mutable NvExecutionProviderInfo info_; bool external_stream_ = false; @@ -299,6 +337,9 @@ class NvExecutionProvider : public IExecutionProvider { std::string onnx_model_folder_path_; const void* onnx_model_bytestream_; size_t onnx_model_bytestream_size_; + bool use_external_data_initializer_ = false; + const void* onnx_external_data_bytestream_ = nullptr; + size_t onnx_external_data_bytestream_size_ = 0; bool sparsity_enable_ = false; int auxiliary_streams_ = -1; std::string cache_path_, engine_decryption_lib_path_; @@ -317,6 +358,7 @@ class NvExecutionProvider : public IExecutionProvider { std::string cache_prefix_; std::string op_types_to_exclude_; int nv_profile_index_ = 0; + std::unique_ptr ep_context_model_; // The format is as for TENSORRT_VERSION: (MAJOR * 100 + MINOR) * 100 + PATCH int32_t trt_version_; @@ -331,7 +373,6 @@ class NvExecutionProvider : public IExecutionProvider { std::string ep_context_file_path_; int ep_context_embed_mode_ = 0; std::string ctx_model_path_; - std::string ep_cache_context_attr_; std::string engine_cache_relative_path_to_context_model_dir; std::unordered_set control_flow_op_set_ = {"If", "Loop", "Scan"}; @@ -550,6 +591,7 @@ class NvExecutionProvider : public IExecutionProvider { * going through the time-consuming processes of model parsing and engine building. 
*/ Status CreateNodeComputeInfoFromPrecompiledEngine(const GraphViewer& graph_body_viewer, + size_t node_idx, const Node& fused_node, std::unordered_map& input_map, std::unordered_map& output_map, diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_info.cc b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_info.cc index f90bf24ef4975..527a37f6c2b57 100644 --- a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_info.cc +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_info.cc @@ -17,6 +17,7 @@ NvExecutionProviderInfo NvExecutionProviderInfo::FromProviderOptions(const Provi NvExecutionProviderInfo info{}; void* user_compute_stream = nullptr; void* onnx_bytestream = nullptr; + void* external_data_bytestream = nullptr; ORT_THROW_IF_ERROR( ProviderOptionsParser{} .AddValueParser( @@ -48,21 +49,14 @@ NvExecutionProviderInfo NvExecutionProviderInfo::FromProviderOptions(const Provi .AddAssignmentToReference(nv::provider_option_names::kProfilesMaxShapes, info.profile_max_shapes) .AddAssignmentToReference(nv::provider_option_names::kProfilesOptShapes, info.profile_opt_shapes) .AddAssignmentToReference(nv::provider_option_names::kCudaGraphEnable, info.cuda_graph_enable) + .AddAssignmentToReference(nv::provider_option_names::kUseExternalDataInitializer, info.use_external_data_initializer) .AddAssignmentToReference(nv::provider_option_names::kMultiProfileEnable, info.multi_profile_enable) - .AddValueParser( - nv::provider_option_names::kONNXBytestream, - [&onnx_bytestream](const std::string& value_str) -> Status { - size_t address; - ORT_RETURN_IF_ERROR(ParseStringWithClassicLocale(value_str, address)); - onnx_bytestream = reinterpret_cast(address); - return Status::OK(); - }) - .AddAssignmentToReference(nv::provider_option_names::kONNXBytestreamSize, info.onnx_bytestream_size) .Parse(options)); // add new provider option here. 
info.user_compute_stream = user_compute_stream; info.has_user_compute_stream = (user_compute_stream != nullptr); info.onnx_bytestream = onnx_bytestream; + info.external_data_bytestream = external_data_bytestream; // EP context settings // when EP context is enabled, default is to embed the engine in the context model @@ -73,7 +67,8 @@ NvExecutionProviderInfo NvExecutionProviderInfo::FromProviderOptions(const Provi info.dump_ep_context_model = false; } else if (ep_context_enable == "1") { info.dump_ep_context_model = true; - info.weight_stripped_engine_enable = true; + // We want to reenable weightless engines as soon constant initializers are supported as inputs + info.weight_stripped_engine_enable = false; } else { ORT_THROW("Invalid ", kOrtSessionOptionEpContextEnable, " must 0 or 1"); } @@ -110,9 +105,7 @@ ProviderOptions NvExecutionProviderInfo::ToProviderOptions(const NvExecutionProv {nv::provider_option_names::kProfilesMaxShapes, MakeStringWithClassicLocale(info.profile_max_shapes)}, {nv::provider_option_names::kProfilesOptShapes, MakeStringWithClassicLocale(info.profile_opt_shapes)}, {nv::provider_option_names::kCudaGraphEnable, MakeStringWithClassicLocale(info.cuda_graph_enable)}, - {nv::provider_option_names::kONNXBytestream, MakeStringWithClassicLocale(info.onnx_bytestream)}, - {nv::provider_option_names::kONNXBytestreamSize, MakeStringWithClassicLocale(info.onnx_bytestream_size)}, - }; + {nv::provider_option_names::kUseExternalDataInitializer, MakeStringWithClassicLocale(info.use_external_data_initializer)}}; return options; } } // namespace onnxruntime diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_info.h b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_info.h index 4d6c6fe116076..b826925361b05 100644 --- a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_info.h +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_info.h @@ -31,6 +31,9 @@ struct NvExecutionProviderInfo { std::string onnx_model_folder_path{""}; const void* onnx_bytestream{nullptr}; size_t onnx_bytestream_size{0}; + bool use_external_data_initializer{false}; + const void* external_data_bytestream{nullptr}; + size_t external_data_bytestream_size{0}; bool engine_decryption_enable{false}; std::string engine_decryption_lib_path{""}; bool force_sequential_engine_build{false}; diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_utils.h b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_utils.h index ea586ba445ba2..c564fe65c3d5c 100644 --- a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_utils.h +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_utils.h @@ -386,22 +386,11 @@ std::string GetCachePath(const std::string& root, const std::string& name) { * Get compute capability * */ -std::string GetComputeCapacity(const cudaDeviceProp& prop) { +std::string GetComputeCapability(const cudaDeviceProp& prop) { const std::string compute_capability = std::to_string(prop.major * 10 + prop.minor); return compute_capability; } -/* - * Get Timing by compute capability - * - */ -std::string GetTimingCachePath(const std::string& root, std::string& compute_cap) { - // append compute capability of the GPU as this invalidates the cache and TRT will throw when loading the cache - const std::string timing_cache_name = "NvExecutionProvider_cache_sm" + - compute_cap + ".timing"; - return GetCachePath(root, timing_cache_name); -} - /* * Get cache by type * diff --git 
a/onnxruntime/core/providers/nv_tensorrt_rtx/onnx_ctx_model_helper.cc b/onnxruntime/core/providers/nv_tensorrt_rtx/onnx_ctx_model_helper.cc index 21d964b0c341f..1f34a0f25877d 100644 --- a/onnxruntime/core/providers/nv_tensorrt_rtx/onnx_ctx_model_helper.cc +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/onnx_ctx_model_helper.cc @@ -20,10 +20,11 @@ extern TensorrtLogger& GetTensorrtLogger(bool verbose_log); * * Note: Please see more details about "EPContext" contrib op in contrib_defs.cc */ -bool GraphHasCtxNode(const GraphViewer& graph_viewer) { +bool GraphHasCtxNode(const GraphViewer& graph_viewer, size_t& node_idx) { for (int i = 0; i < graph_viewer.MaxNodeIndex(); ++i) { auto node = graph_viewer.GetNode(i); if (node != nullptr && node->OpType() == EPCONTEXT_OP) { + node_idx = i; return true; } } @@ -63,19 +64,18 @@ void UpdateCtxNodeModelEngineContext(ONNX_NAMESPACE::ModelProto* model_proto, } /* - * Create "EP context node" model where engine information is embedded + * Create EP context node where engine information is embedded */ -ONNX_NAMESPACE::ModelProto* CreateCtxModel(const GraphViewer& graph_viewer, - const std::string engine_cache_path, - char* engine_data, - size_t size, - const int64_t embed_mode, - const std::string compute_capability, - const std::string onnx_model_path, - const logging::Logger* logger) { - auto model_build = graph_viewer.CreateModel(*logger); - auto& graph_build = model_build->MainGraph(); - +Status CreateCtxNode(const GraphViewer& graph_viewer, + Graph& graph_build, + const std::string engine_cache_path, + char* engine_data, + size_t size, + const int64_t embed_mode, + const std::string compute_capability, + const std::string onnx_model_path, + const std::string& ep_context_node_name, + int32_t trt_version) { // Get graph inputs and outputs std::vector inputs, outputs; for (auto input : graph_viewer.GetInputs()) { @@ -89,55 +89,71 @@ ONNX_NAMESPACE::ModelProto* CreateCtxModel(const GraphViewer& graph_viewer, } // Create EP context node attributes - auto attr_0 = ONNX_NAMESPACE::AttributeProto::Create(); // embed_mode - auto attr_1 = ONNX_NAMESPACE::AttributeProto::Create(); // ep_cache_context - auto attr_2 = ONNX_NAMESPACE::AttributeProto::Create(); // hardware_architecture - auto attr_3 = ONNX_NAMESPACE::AttributeProto::Create(); // onnx_model_filename + auto attr_embed_mode = ONNX_NAMESPACE::AttributeProto::Create(); + auto attr_main_context = ONNX_NAMESPACE::AttributeProto::Create(); + auto attr_ep_cache_context = ONNX_NAMESPACE::AttributeProto::Create(); + auto attr_sdk_version = ONNX_NAMESPACE::AttributeProto::Create(); + auto attr_hw_architecture = ONNX_NAMESPACE::AttributeProto::Create(); + auto attr_onnx_filename = ONNX_NAMESPACE::AttributeProto::Create(); + auto attr_partition_name = ONNX_NAMESPACE::AttributeProto::Create(); std::string engine_data_str = ""; - attr_0->set_name(EMBED_MODE); - attr_0->set_type(onnx::AttributeProto_AttributeType_INT); - attr_0->set_i(embed_mode); - attr_1->set_name(EP_CACHE_CONTEXT); - attr_1->set_type(onnx::AttributeProto_AttributeType_STRING); + attr_main_context->set_name(MAIN_CONTEXT); + attr_main_context->set_type(onnx::AttributeProto_AttributeType_INT); + attr_main_context->set_i(0); // we do not support a main context node but each has it's own engine payload + attr_embed_mode->set_name(EMBED_MODE); + attr_embed_mode->set_type(onnx::AttributeProto_AttributeType_INT); + attr_embed_mode->set_i(embed_mode); + attr_ep_cache_context->set_name(EP_CACHE_CONTEXT); + 
attr_ep_cache_context->set_type(onnx::AttributeProto_AttributeType_STRING); if (embed_mode) { if (size > 0) { engine_data_str.assign(engine_data, size); } - attr_1->set_s(engine_data_str); - // TODO(maximilianm) we might want to disable this warning as we only support weightless engines that are really small - // the reason we had this was that the field will be hashed and storing a large bytestream has significant overhead - LOGS_DEFAULT(WARNING) << EPCONTEXT_WARNING; + attr_ep_cache_context->set_s(engine_data_str); } else { - attr_1->set_s(engine_cache_path); + std::string engine_cache_filename = std::filesystem::path(engine_cache_path).filename().string(); + attr_ep_cache_context->set_s(engine_cache_filename); + std::fstream engine_cache_file(engine_cache_path, std::ios::binary | std::ios::out); + if (engine_cache_file.is_open()) { + engine_cache_file.write(engine_data, size); + engine_cache_file.close(); + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, + "NvTensorRTRTX EP could not write cache to ", engine_cache_path); + } } - attr_2->set_name(COMPUTE_CAPABILITY); - attr_2->set_type(onnx::AttributeProto_AttributeType_STRING); - attr_2->set_s(compute_capability); - attr_3->set_name(ONNX_MODEL_FILENAME); - attr_3->set_type(onnx::AttributeProto_AttributeType_STRING); - attr_3->set_s(std::filesystem::path(onnx_model_path).filename().string()); + + attr_hw_architecture->set_name(COMPUTE_CAPABILITY); + attr_hw_architecture->set_type(onnx::AttributeProto_AttributeType_STRING); + attr_hw_architecture->set_s(compute_capability); + + attr_partition_name->set_name(PARTITION_NAME); + attr_partition_name->set_type(onnx::AttributeProto_AttributeType_STRING); + attr_partition_name->set_s(ep_context_node_name); // includes hash of the subgraph that was built + + attr_onnx_filename->set_name(ONNX_MODEL_FILENAME); + attr_onnx_filename->set_type(onnx::AttributeProto_AttributeType_STRING); + attr_onnx_filename->set_s(std::filesystem::path(onnx_model_path).filename().string()); + + attr_sdk_version->set_name(SDK_VERSION); + attr_sdk_version->set_type(onnx::AttributeProto_AttributeType_STRING); + attr_sdk_version->set_s(std::to_string(trt_version)); auto node_attributes = ONNX_NAMESPACE::NodeAttributes::Create(); constexpr int num_attributes = 4; node_attributes->reserve(num_attributes); - node_attributes->emplace(EMBED_MODE, *attr_0); - node_attributes->emplace(EP_CACHE_CONTEXT, *attr_1); - node_attributes->emplace(COMPUTE_CAPABILITY, *attr_2); - node_attributes->emplace(ONNX_MODEL_FILENAME, *attr_3); + node_attributes->emplace(MAIN_CONTEXT, *attr_main_context); + node_attributes->emplace(EMBED_MODE, *attr_embed_mode); + node_attributes->emplace(EP_CACHE_CONTEXT, *attr_ep_cache_context); + node_attributes->emplace(COMPUTE_CAPABILITY, *attr_hw_architecture); + node_attributes->emplace(PARTITION_NAME, *attr_partition_name); + node_attributes->emplace(ONNX_MODEL_FILENAME, *attr_onnx_filename); + node_attributes->emplace(SDK_VERSION, *attr_sdk_version); // Create EP context node - graph_build.AddNode(EPCONTEXT_OP, EPCONTEXT_OP, "", inputs, outputs, node_attributes.get(), EPCONTEXT_OP_DOMAIN); + graph_build.AddNode(ep_context_node_name, EPCONTEXT_OP, "", inputs, outputs, node_attributes.get(), EPCONTEXT_OP_DOMAIN); ORT_ENFORCE(graph_build.Resolve().IsOK()); - - // Serialize modelproto to string - auto new_graph_viewer = graph_build.CreateGraphViewer(); - auto& metadata = graph_viewer.GetGraph().GetModel().MetaData(); - auto model = new_graph_viewer->CreateModel(*logger, metadata); - auto model_proto = 
model->ToProto(); - new_graph_viewer->ToProto(*model_proto->mutable_graph(), true, true); - model_proto->set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION); - - return model_proto.release(); + return Status::OK(); } /* @@ -206,17 +222,6 @@ std::string GetCtxModelPath(const std::string& ep_context_file_path, return ctx_model_path; } -/* - * Dump "EP context" model - * - */ -void DumpCtxModel(ONNX_NAMESPACE::ModelProto* model_proto, - const std::string& ctx_model_path) { - std::fstream dump(ctx_model_path, std::ios::out | std::ios::trunc | std::ios::binary); - model_proto->SerializeToOstream(dump); - LOGS_DEFAULT(VERBOSE) << "[NvTensorRTRTX EP] Dumped " + ctx_model_path; -} - bool IsAbsolutePath(const std::string& path_string) { #ifdef _WIN32 onnxruntime::PathString ort_path_string = onnxruntime::ToPathString(path_string); @@ -248,38 +253,12 @@ bool IsRelativePathToParentPath(const std::string& path_string) { #endif } -/* - * Get the weight-refitted engine cache path from a weight-stripped engine cache path - * - * Weight-stipped engine: - * An engine with weights stripped and its size is smaller than a regualr engine. - * The cache name of weight-stripped engine is NvExecutionProvider_TRTKernel_XXXXX.stripped.engine - * - * Weight-refitted engine: - * An engine that its weights have been refitted and it's simply a regular engine. - * The cache name of weight-refitted engine is NvExecutionProvider_TRTKernel_XXXXX.engine - */ -std::string GetWeightRefittedEnginePath(std::string stripped_engine_cache) { - std::filesystem::path stripped_engine_cache_path(stripped_engine_cache); - std::string refitted_engine_cache_path = stripped_engine_cache_path.stem().stem().string() + ".engine"; - return refitted_engine_cache_path; -} - -bool IsWeightStrippedEngineCache(std::filesystem::path& engine_cache_path) { - // The weight-stripped engine cache has the naming of xxx.stripped.engine - return engine_cache_path.stem().extension().string() == ".stripped"; -} - -Status TensorRTCacheModelHandler::GetEpContextFromGraph(const GraphViewer& graph_viewer) { - if (!ValidateEPCtxNode(graph_viewer)) { - return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "It's not a valid EP Context node"); - } - auto node = graph_viewer.GetNode(0); - auto& attrs = node->GetAttributes(); +Status TensorRTCacheModelHandler::GetEpContextFromGraph(const Node& node) { + auto& attrs = node.GetAttributes(); const int64_t embed_mode = attrs.at(EMBED_MODE).i(); // Only make path checks if model not provided as byte buffer - bool make_secure_path_checks = !GetModelPath(graph_viewer).empty(); + bool make_secure_path_checks = ep_context_model_path_.empty(); if (embed_mode) { // Get engine from byte stream. 
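Since the context model is now assembled from per-partition EPContext nodes (CreateCtxNode plus GetEpContextNodes) instead of being dumped directly, the user-facing workflow stays the standard EP-context session options. A rough sketch of producing and then reloading such a model, assuming the usual ep.context.* session config keys; file names are placeholders and the execution provider registration is omitted.

// Hypothetical two-step usage: first compile and dump an EP-context model,
// then reload the generated model so the precompiled engine is used directly.
#include <onnxruntime_cxx_api.h>

void DumpAndReloadEpContextModel(Ort::Env& env) {
  // Step 1: compile the original model and write "model_ctx.onnx" with EPContext nodes.
  // (Appending the NvTensorRTRTX execution provider to compile_options is omitted here.)
  Ort::SessionOptions compile_options;
  compile_options.AddConfigEntry("ep.context_enable", "1");
  compile_options.AddConfigEntry("ep.context_file_path", "model_ctx.onnx");
  compile_options.AddConfigEntry("ep.context_embed_mode", "1");  // embed engine bytes in the node
  Ort::Session compile_session(env, ORT_TSTR("model.onnx"), compile_options);

  // Step 2: later runs load the context model; GetCapability() above short-circuits
  // on the EPContext nodes and deserializes the embedded engines.
  Ort::SessionOptions run_options;
  Ort::Session run_session(env, ORT_TSTR("model_ctx.onnx"), run_options);
}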
@@ -294,15 +273,14 @@ Status TensorRTCacheModelHandler::GetEpContextFromGraph(const GraphViewer& graph if (weight_stripped_engine_refit_) { const std::string onnx_model_filename = attrs.at(ONNX_MODEL_FILENAME).s(); - std::string placeholder; auto status = NvExecutionProvider::RefitEngine(onnx_model_filename, onnx_model_folder_path_, - placeholder, make_secure_path_checks, onnx_model_bytestream_, onnx_model_bytestream_size_, + onnx_external_data_bytestream_, + onnx_external_data_bytestream_size_, (*trt_engine_).get(), - false /* serialize refitted engine to disk */, detailed_build_log_); if (status != Status::OK()) { return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, status.ErrorMessage()); @@ -327,21 +305,6 @@ Status TensorRTCacheModelHandler::GetEpContextFromGraph(const GraphViewer& graph auto engine_cache_path = ctx_model_dir.append(cache_path); LOGS_DEFAULT(VERBOSE) << "[NvTensorRTRTX EP] GetEpContextFromGraph engine_cache_path: " + engine_cache_path.string(); - // If it's a weight-stripped engine cache, it needs to be refitted even though the refit flag is not enabled - if (!weight_stripped_engine_refit_) { - weight_stripped_engine_refit_ = IsWeightStrippedEngineCache(engine_cache_path); - } - - // If the serialized refitted engine is present, use it directly without refitting the engine again - if (weight_stripped_engine_refit_) { - const std::filesystem::path refitted_engine_cache_path = GetWeightRefittedEnginePath(engine_cache_path.string()); - if (std::filesystem::exists(refitted_engine_cache_path)) { - LOGS_DEFAULT(VERBOSE) << "[NvTensorRTRTX EP] " + refitted_engine_cache_path.string() + " exists."; - engine_cache_path = refitted_engine_cache_path.string(); - weight_stripped_engine_refit_ = false; - } - } - if (!std::filesystem::exists(engine_cache_path)) { return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "Nv EP can't find engine cache: " + engine_cache_path.string() + @@ -366,12 +329,12 @@ Status TensorRTCacheModelHandler::GetEpContextFromGraph(const GraphViewer& graph std::string weight_stripped_engine_cache = engine_cache_path.string(); auto status = NvExecutionProvider::RefitEngine(onnx_model_filename, onnx_model_folder_path_, - weight_stripped_engine_cache, make_secure_path_checks, onnx_model_bytestream_, onnx_model_bytestream_size_, + onnx_external_data_bytestream_, + onnx_external_data_bytestream_size_, (*trt_engine_).get(), - true /* serialize refitted engine to disk */, detailed_build_log_); if (status != Status::OK()) { return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, status.ErrorMessage()); @@ -384,11 +347,8 @@ Status TensorRTCacheModelHandler::GetEpContextFromGraph(const GraphViewer& graph /* * The sanity check for EP context contrib op. 
*/ -bool TensorRTCacheModelHandler::ValidateEPCtxNode(const GraphViewer& graph_viewer) { - assert(graph_viewer.NumberOfNodes() == 1); - assert(graph_viewer.GetNode(0)->OpType() == EPCONTEXT_OP); - auto node = graph_viewer.GetNode(0); - auto& attrs = node->GetAttributes(); +bool TensorRTCacheModelHandler::ValidateEPCtxNode(const Node& node) { + auto& attrs = node.GetAttributes(); // Show the warning if compute capability is not matched if (attrs.count(COMPUTE_CAPABILITY) > 0) { @@ -413,7 +373,7 @@ bool TensorRTCacheModelHandler::ValidateEPCtxNode(const GraphViewer& graph_viewe const int64_t embed_mode = attrs.at(EMBED_MODE).i(); if (embed_mode == 1) { // engine binary data - LOGS_DEFAULT(WARNING) << EPCONTEXT_WARNING; + // LOGS_DEFAULT(WARNING) << EPCONTEXT_WARNING; } return true; diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/onnx_ctx_model_helper.h b/onnxruntime/core/providers/nv_tensorrt_rtx/onnx_ctx_model_helper.h index f0a05c42414e5..7c52f26cc9177 100644 --- a/onnxruntime/core/providers/nv_tensorrt_rtx/onnx_ctx_model_helper.h +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/onnx_ctx_model_helper.h @@ -7,6 +7,7 @@ #include #include #include +#include #include "core/providers/nv_tensorrt_rtx/nv_includes.h" #include "core/providers/shared_library/provider_api.h" @@ -14,33 +15,32 @@ namespace onnxruntime { static const std::string EPCONTEXT_OP = "EPContext"; +static const std::string MAIN_CONTEXT = "main_context"; static const std::string EMBED_MODE = "embed_mode"; static const std::string EP_CACHE_CONTEXT = "ep_cache_context"; static const std::string COMPUTE_CAPABILITY = "hardware_architecture"; static const std::string ONNX_MODEL_FILENAME = "onnx_model_filename"; +static const std::string PARTITION_NAME = "partition_name"; +static const std::string SDK_VERSION = "ep_sdk_version"; static const std::string EPCONTEXT_OP_DOMAIN = "com.microsoft"; -static const std::string EPCONTEXT_WARNING = - "It's suggested to set the ORT graph optimization level to 0 and \ - make \"embed_mode\" to 0 (\"ep_cache_context\" is the cache path)\ - for the best model loading time"; -bool GraphHasCtxNode(const GraphViewer& graph_viewer); +bool GraphHasCtxNode(const GraphViewer& graph_viewer, size_t& node_idx); const std::filesystem::path& GetModelPath(const GraphViewer& graph_viewer); std::filesystem::path GetPathOrParentPathOfCtxModel(const std::string& ep_context_file_path); -ONNX_NAMESPACE::ModelProto* CreateCtxModel(const GraphViewer& graph_viewer, - const std::string engine_cache_path, - char* engine_data, - size_t size, - const int64_t embed_mode, - const std::string compute_capability, - const std::string onnx_model_path, - const logging::Logger* logger); +Status CreateCtxNode(const GraphViewer& graph_viewer, + Graph& graph_build, + const std::string engine_cache_path, + char* engine_data, + size_t size, + const int64_t embed_mode, + const std::string compute_capability, + const std::string onnx_model_path, + const std::string& ep_context_node_name, + int trt_version); std::string GetCtxModelPath(const std::string& ep_context_file_path, const std::string& original_model_path); bool IsAbsolutePath(const std::string& path_string); bool IsRelativePathToParentPath(const std::string& path_string); -void DumpCtxModel(ONNX_NAMESPACE::ModelProto* model_proto, - const std::string& ctx_model_path); void UpdateCtxNodeModelEngineContext(ONNX_NAMESPACE::ModelProto* model_proto, char* engine_data, size_t size); @@ -55,6 +55,8 @@ class TensorRTCacheModelHandler { std::string onnx_model_folder_path, const void* 
onnx_model_bytestream, size_t onnx_model_bytestream_size, + const void* onnx_external_data_bytestream, + size_t onnx_external_data_bytestream_size, bool detailed_build_log) : trt_engine_(trt_engine), trt_runtime_(trt_runtime), @@ -64,13 +66,15 @@ class TensorRTCacheModelHandler { onnx_model_folder_path_(onnx_model_folder_path), onnx_model_bytestream_(onnx_model_bytestream), onnx_model_bytestream_size_(onnx_model_bytestream_size), + onnx_external_data_bytestream_(onnx_external_data_bytestream), + onnx_external_data_bytestream_size_(onnx_external_data_bytestream_size), detailed_build_log_(detailed_build_log) { } ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(TensorRTCacheModelHandler); - bool ValidateEPCtxNode(const GraphViewer& graph_viewer); + bool ValidateEPCtxNode(const Node& node); - Status GetEpContextFromGraph(const GraphViewer& graph_viewer); + Status GetEpContextFromGraph(const Node& node); private: std::unique_ptr* trt_engine_; @@ -81,6 +85,8 @@ class TensorRTCacheModelHandler { std::string onnx_model_folder_path_; const void* onnx_model_bytestream_; size_t onnx_model_bytestream_size_; + const void* onnx_external_data_bytestream_; + size_t onnx_external_data_bytestream_size_; bool detailed_build_log_; }; // TRTCacheModelHandler } // namespace onnxruntime diff --git a/onnxruntime/test/providers/nv_tensorrt_rtx/nv_basic_test.cc b/onnxruntime/test/providers/nv_tensorrt_rtx/nv_basic_test.cc index 19505da1bbe56..2327bc2094d1a 100644 --- a/onnxruntime/test/providers/nv_tensorrt_rtx/nv_basic_test.cc +++ b/onnxruntime/test/providers/nv_tensorrt_rtx/nv_basic_test.cc @@ -1,25 +1,16 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. // SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // Licensed under the MIT License. 
#include "core/graph/onnx_protobuf.h" #include "core/session/inference_session.h" #include "test/providers/provider_test_utils.h" #include "test/framework/test_utils.h" -#include "gtest/gtest.h" + #include "test/util/include/scoped_env_vars.h" #include "test/common/trt_op_test_utils.h" #include "test/common/random_generator.h" #include "test/providers/nv_tensorrt_rtx/test_nv_trt_rtx_ep_util.h" -#include "test/util/include/api_asserts.h" -#include "test/util/include/asserts.h" -#include -#include -#include -#include -#include #include -#include #include using namespace std; @@ -30,200 +21,6 @@ namespace onnxruntime { namespace test { -template -class NvExecutionProviderTest : public ::testing::Test { - protected: - std::string getTypeAsName() { - std::string dtype_name = ""; - if constexpr (std::is_same::value) { - dtype_name = "fp64"; - } else if constexpr (std::is_same::value) { - dtype_name = "fp32"; - } else if constexpr (std::is_same::value) { - dtype_name = "bf16"; - } else if constexpr (std::is_same::value) { - dtype_name = "fp16"; - } else if constexpr (std::is_same::value) { - dtype_name = "int8"; - } else if constexpr (std::is_same::value) { - dtype_name = "uint8"; - } else if constexpr (std::is_same::value) { - dtype_name = "int32"; - } else if constexpr (std::is_same::value) { - dtype_name = "int64"; - } - return dtype_name; - } -}; - -using NvExecutionProviderTestTypes = ::testing::Types; // double, -TYPED_TEST_SUITE(NvExecutionProviderTest, NvExecutionProviderTestTypes); - -std::string PathToUTF8(const PathString& path) { -#ifdef WIN32 - std::wstring_convert> converter; - return converter.to_bytes(path); -#else - return path.c_str(); -#endif -} - -void clearFileIfExists(PathString path) { - if (std::filesystem::exists(path)) { - std::filesystem::remove(path); - } -} - -template -void VerifyOutputs(const std::vector& fetches, const std::vector& expected_dims, - const std::vector& expected_values) { - ASSERT_EQ(1, fetches.size()); - auto& rtensor = fetches.front().Get(); - TensorShape expected_shape(expected_dims); - ASSERT_EQ(expected_shape, rtensor.Shape()); - const std::vector found(rtensor.Data(), rtensor.Data() + expected_values.size()); - ASSERT_EQ(expected_values, found); -} - -/** - * Create a simple model with dynamic or non-dynamic input shape. - * \param model_name - model name - * \param graph_name - graph name - * \param dims - input dimensions - * \param add_fast_gelu - add FastGelu node which makes the whole model partition into TRT EP and CUDA EP subgraphs. - * - * input: "X", "Y" and "Z" - * you can specify input dimensions, for example (1, 3, 2), (1, 2) or (1, -1, -1)). Note: -1 means the dimension is dynamic. - * All three inputs have the same dimensions. 
- * output: "M" - * - * "X" "Y" - * \ / - * "Z" Add - * \ / - * Add - * / - * Add (+ float scalar "S") - * / - * "O" - * - * or - * - * "X" "Y" - * \ / - * "Z" Add - * \ / - * Add - * / - * FastGelu (This node will be placed on CUDA EP) - * / - * * Add (+ float scalar "S") - * / - * "O" - */ -static void CreateBaseModel(const PathString& model_name, - std::string graph_name, - std::vector dims, - bool add_fast_gelu = false, - ONNX_NAMESPACE::TensorProto_DataType dtype = ONNX_NAMESPACE::TensorProto_DataType_FLOAT) { - onnxruntime::Model model(graph_name, false, DefaultLoggingManager().DefaultLogger()); - auto& graph = model.MainGraph(); - std::vector inputs; - std::vector outputs; - - // FLOAT tensor - ONNX_NAMESPACE::TypeProto float_tensor; - float_tensor.mutable_tensor_type()->set_elem_type(dtype); - - for (auto dim : dims) { - float_tensor.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(dim); - } - ONNX_NAMESPACE::TypeProto dyn_float_tensor; - dyn_float_tensor.mutable_tensor_type()->set_elem_type(dtype); - - auto& input_arg_1 = graph.GetOrCreateNodeArg("X", &float_tensor); - auto& input_arg_2 = graph.GetOrCreateNodeArg("Y", &float_tensor); - inputs.push_back(&input_arg_1); - inputs.push_back(&input_arg_2); - auto& output_arg = graph.GetOrCreateNodeArg("node_1_out_1", &float_tensor); - outputs.push_back(&output_arg); - graph.AddNode("node_1", "Add", "node 1.", inputs, outputs); - - auto& input_arg_3 = graph.GetOrCreateNodeArg("Z", &float_tensor); - inputs.clear(); - inputs.push_back(&output_arg); - inputs.push_back(&input_arg_3); - - auto& output_arg_2 = graph.GetOrCreateNodeArg("node_2_out_1", &float_tensor); - outputs.clear(); - outputs.push_back(&output_arg_2); - graph.AddNode("node_2", "Add", "node 2.", inputs, outputs); - - inputs.clear(); - inputs.push_back(&output_arg_2); - - if (add_fast_gelu) { - auto& output_arg_3 = graph.GetOrCreateNodeArg("node_3_out_1", &dyn_float_tensor); - outputs.clear(); - outputs.push_back(&output_arg_3); - - graph.AddNode("node_3", "FastGelu", "node 3.", inputs, outputs, - /* attributes */ nullptr, kMSDomain); - - inputs.clear(); - inputs.push_back(&output_arg_3); - } - - ONNX_NAMESPACE::TypeProto float_scalar; - float_scalar.mutable_tensor_type()->set_elem_type(dtype); - float_scalar.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(1); - auto& input_scalar = graph.GetOrCreateNodeArg("S", &float_scalar); - inputs.push_back(&input_scalar); - - auto& output_arg_4 = graph.GetOrCreateNodeArg("O", &dyn_float_tensor); - - outputs.clear(); - outputs.push_back(&output_arg_4); - graph.AddNode("node_5", "Add", "node 5.", inputs, outputs); - - auto status = graph.Resolve(); - ASSERT_TRUE(status.IsOK()); - status = onnxruntime::Model::Save(model, model_name); - ASSERT_TRUE(status.IsOK()); -} - -static Ort::IoBinding generate_io_binding(Ort::Session& session, std::map> shape_overwrites = {}) { - Ort::IoBinding binding(session); - auto allocator = Ort::AllocatorWithDefaultOptions(); - for (int input_idx = 0; input_idx < int(session.GetInputCount()); ++input_idx) { - auto input_name = session.GetInputNameAllocated(input_idx, Ort::AllocatorWithDefaultOptions()); - auto full_tensor_info = session.GetInputTypeInfo(input_idx); - auto tensor_info = full_tensor_info.GetTensorTypeAndShapeInfo(); - auto shape = tensor_info.GetShape(); - auto type = tensor_info.GetElementType(); - if (shape_overwrites.find(input_name.get()) == shape_overwrites.end()) { - for (auto& v : shape) { - if (v == -1) { - v = 1; - } - } - } else { - shape = 
shape_overwrites[input_name.get()]; - } - auto input_value = Ort::Value::CreateTensor(allocator, - shape.data(), - shape.size(), - type); - binding.BindInput(input_name.get(), input_value); - } - - for (int output_idx = 0; output_idx < int(session.GetOutputCount()); ++output_idx) { - auto output_name = session.GetOutputNameAllocated(output_idx, Ort::AllocatorWithDefaultOptions()); - binding.BindOutput(output_name.get(), allocator.GetInfo()); - } - return binding; -} - TEST(NvExecutionProviderTest, ContextEmbedAndReload) { PathString model_name = ORT_TSTR("nv_execution_provider_test.onnx"); PathString model_name_ctx = ORT_TSTR("nv_execution_provider_test_ctx.onnx"); @@ -233,11 +30,6 @@ TEST(NvExecutionProviderTest, ContextEmbedAndReload) { std::vector dims = {1, 3, 2}; CreateBaseModel(model_name, graph_name, dims); - - auto env = Ort::Env(); - auto logging_level = OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING; - env.UpdateEnvWithCustomLogLevel(logging_level); - // AOT time { auto start = std::chrono::high_resolution_clock::now(); @@ -246,7 +38,7 @@ TEST(NvExecutionProviderTest, ContextEmbedAndReload) { so.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1"); so.AddConfigEntry(kOrtSessionOptionEpContextFilePath, model_name_ctx_str.c_str()); so.AppendExecutionProvider(kNvTensorRTRTXExecutionProvider, {}); - Ort::Session session_object(env, model_name.c_str(), so); + Ort::Session session_object(*ort_env, model_name.c_str(), so); auto stop = std::chrono::high_resolution_clock::now(); std::cout << "Session creation AOT: " << std::chrono::duration_cast((stop - start)).count() << " ms" << std::endl; @@ -261,7 +53,7 @@ TEST(NvExecutionProviderTest, ContextEmbedAndReload) { Ort::RunOptions run_options; so.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1"); so.AppendExecutionProvider(kNvTensorRTRTXExecutionProvider, {}); - Ort::Session session_object(env, model_name_ctx.c_str(), so); + Ort::Session session_object(*ort_env, model_name_ctx.c_str(), so); auto stop = std::chrono::high_resolution_clock::now(); std::cout << "Session creation JIT: " << std::chrono::duration_cast((stop - start)).count() << " ms" << std::endl; @@ -280,10 +72,6 @@ TEST(NvExecutionProviderTest, ContextEmbedAndReloadDynamic) { CreateBaseModel(model_name, graph_name, dims); - auto env = Ort::Env(); - auto logging_level = OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING; - env.UpdateEnvWithCustomLogLevel(logging_level); - // AOT time { auto start = std::chrono::high_resolution_clock::now(); @@ -292,7 +80,7 @@ TEST(NvExecutionProviderTest, ContextEmbedAndReloadDynamic) { so.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1"); so.AddConfigEntry(kOrtSessionOptionEpContextFilePath, model_name_ctx_str.c_str()); so.AppendExecutionProvider(kNvTensorRTRTXExecutionProvider, {}); - Ort::Session session_object(env, model_name.c_str(), so); + Ort::Session session_object(*ort_env, model_name.c_str(), so); auto stop = std::chrono::high_resolution_clock::now(); std::cout << "Session creation AOT: " << std::chrono::duration_cast((stop - start)).count() << " ms" << std::endl; @@ -307,7 +95,7 @@ TEST(NvExecutionProviderTest, ContextEmbedAndReloadDynamic) { Ort::RunOptions run_options; so.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1"); so.AppendExecutionProvider(kNvTensorRTRTXExecutionProvider, {}); - Ort::Session session_object(env, model_name_ctx.c_str(), so); + Ort::Session session_object(*ort_env, model_name_ctx.c_str(), so); auto stop = std::chrono::high_resolution_clock::now(); std::cout << "Session creation JIT: " << 
std::chrono::duration_cast((stop - start)).count() << " ms" << std::endl; @@ -329,10 +117,6 @@ TEST(NvExecutionProviderTest, ContextEmbedAndReloadDataDynamic) { CreateBaseModel(model_name, graph_name, dims); - - auto env = Ort::Env(); - auto logging_level = OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING; - env.UpdateEnvWithCustomLogLevel(logging_level); - // AOT time { auto start = std::chrono::high_resolution_clock::now(); @@ -341,7 +125,7 @@ TEST(NvExecutionProviderTest, ContextEmbedAndReloadDataDynamic) { so.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1"); so.AddConfigEntry(kOrtSessionOptionEpContextFilePath, model_name_ctx_str.c_str()); so.AppendExecutionProvider(kNvTensorRTRTXExecutionProvider, {}); - Ort::Session session_object(env, model_name.c_str(), so); + Ort::Session session_object(*ort_env, model_name.c_str(), so); auto stop = std::chrono::high_resolution_clock::now(); std::cout << "Session creation AOT: " << std::chrono::duration_cast((stop - start)).count() << " ms" << std::endl; @@ -356,7 +140,7 @@ TEST(NvExecutionProviderTest, ContextEmbedAndReloadDataDynamic) { Ort::RunOptions run_options; so.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1"); so.AppendExecutionProvider(kNvTensorRTRTXExecutionProvider, {}); - Ort::Session session_object(env, model_name_ctx.c_str(), so); + Ort::Session session_object(*ort_env, model_name_ctx.c_str(), so); auto stop = std::chrono::high_resolution_clock::now(); std::cout << "Session creation JIT: " << std::chrono::duration_cast((stop - start)).count() << " ms" << std::endl; @@ -368,33 +152,71 @@ TEST(NvExecutionProviderTest, ContextEmbedAndReloadDataDynamic) { } } -TYPED_TEST(NvExecutionProviderTest, IOTypeTests) { - std::string dtype_name = this->getTypeAsName(); +std::string getTypeAsName(ONNX_NAMESPACE::TensorProto_DataType dtype) { + switch (dtype) { + case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE: + return "fp64"; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: + return "fp32"; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16: + return "fp16"; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16: + return "bf16"; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: + return "int64"; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: + return "int32"; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8: + return "int8"; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: + return "uint8"; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT4: + return "int4"; + default: + return "Unknown type"; + } +} + +class TypeTests : public ::testing::TestWithParam { + public: +}; + +TEST_P(TypeTests, IOTypes) { + const std::string dtype_name = getTypeAsName(GetParam()); ASSERT_FALSE(dtype_name.empty()); const std::string model_name_str = "nv_execution_provider_" + dtype_name + ".onnx"; const PathString model_name = ToPathString(model_name_str); - std::string graph_name = "test" + dtype_name; - std::vector dims = {1, -1, -1}; - - CreateBaseModel(model_name, graph_name, dims); + const std::string graph_name = "test" + dtype_name; + const std::vector dims = {1, 5, 10}; - auto env = Ort::Env(); - auto logging_level = OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING; - env.UpdateEnvWithCustomLogLevel(logging_level); + CreateBaseModel(model_name, graph_name, dims, false, GetParam()); // AOT time { Ort::SessionOptions so; Ort::RunOptions run_options; so.AppendExecutionProvider(kNvTensorRTRTXExecutionProvider, {}); - Ort::Session session_object(env, model_name.c_str(), so); + Ort::Session session_object(*ort_env, model_name.c_str(), so); auto io_binding = generate_io_binding(session_object); session_object.Run(run_options, 
io_binding); } } -#if defined(WIN32) +INSTANTIATE_TEST_SUITE_P(NvExecutionProviderTest, TypeTests, + ::testing::Values(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, + ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16, + ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16, + ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, + ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32 + // disabled low precision integer types since a specific quantize/dequantize model is required + // ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8, + // ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8, + // ONNX_TENSOR_ELEMENT_DATA_TYPE_INT4 + ), + [](const testing::TestParamInfo& info) { return getTypeAsName(info.param); }); + +#ifdef _WIN32 static bool SessionHasEp(Ort::Session& session, const char* ep_name) { // Access the underlying InferenceSession. const OrtSession* ort_session = session; @@ -420,20 +242,16 @@ TEST(NvExecutionProviderTest, AutoEp_PreferGpu) { CreateBaseModel(model_name, graph_name, dims); - auto env = Ort::Env(); - auto logging_level = OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING; - env.UpdateEnvWithCustomLogLevel(logging_level); - { - env.RegisterExecutionProviderLibrary(kNvTensorRTRTXExecutionProvider, ORT_TSTR("onnxruntime_providers_nv_tensorrt_rtx.dll")); + ort_env->RegisterExecutionProviderLibrary(kNvTensorRTRTXExecutionProvider, ORT_TSTR("onnxruntime_providers_nv_tensorrt_rtx.dll")); Ort::SessionOptions so; so.SetEpSelectionPolicy(OrtExecutionProviderDevicePolicy_PREFER_GPU); - Ort::Session session_object(env, model_name.c_str(), so); + Ort::Session session_object(*ort_env, model_name.c_str(), so); EXPECT_TRUE(SessionHasEp(session_object, kNvTensorRTRTXExecutionProvider)); } - env.UnregisterExecutionProviderLibrary(kNvTensorRTRTXExecutionProvider); + ort_env->UnregisterExecutionProviderLibrary(kNvTensorRTRTXExecutionProvider); } TEST(NvExecutionProviderTest, GetSharedAllocator) { @@ -580,7 +398,7 @@ TEST(NvExecutionProviderTest, DataTransfer) { device_tensor = Ort::Value(); } -#endif // defined(WIN32) +#endif } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/providers/nv_tensorrt_rtx/nv_ep_context_test.cc b/onnxruntime/test/providers/nv_tensorrt_rtx/nv_ep_context_test.cc new file mode 100644 index 0000000000000..ce49ae81c81c0 --- /dev/null +++ b/onnxruntime/test/providers/nv_tensorrt_rtx/nv_ep_context_test.cc @@ -0,0 +1,206 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// Licensed under the MIT License. +#include "core/common/path_utils.h" +#include "test/framework/test_utils.h" +#include "test/providers/nv_tensorrt_rtx/test_nv_trt_rtx_ep_util.h" + +#include + +extern std::unique_ptr ort_env; + +namespace onnxruntime { + +namespace test { + +RegisteredEpDeviceUniquePtr AppendTrtEtxEP(Ort::SessionOptions& session_options, std::unordered_map& option_map) { + RegisteredEpDeviceUniquePtr nv_tensorrt_rtx_ep; +#ifdef _WIN32 + /// Since this test runs after other tests that use registration interface this test has to use it as well + /// windows as otherwise the kernel registry inside the EP will not be populated. The legacy APis ony call the initialize once. 
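+ /// Pick the NvTensorRTRTX entry from the environment's EP devices and append it through AppendExecutionProvider_V2; on non-Windows builds the legacy AppendExecutionProvider call below is used instead.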
+ Utils::RegisterAndGetNvTensorRtRtxEp(*ort_env, nv_tensorrt_rtx_ep); + auto ep_devices = ort_env->GetEpDevices(); + Ort::ConstEpDevice selected_device; + for (auto& device : ep_devices) { + if (!std::strcmp(device.EpName(), kNvTensorRTRTXExecutionProvider)) { + selected_device = device; + } + } + session_options.AppendExecutionProvider_V2(*ort_env, {selected_device}, option_map); +#else + session_options.AppendExecutionProvider(onnxruntime::kNvTensorRTRTXExecutionProvider, option_map); +#endif + return nv_tensorrt_rtx_ep; +} + +std::vector readBinaryFile(const PathString& filename) { + std::ifstream file(filename, std::ios::binary); + if (!file.is_open()) { + throw std::runtime_error("Could not open file: " + PathToUTF8String(filename)); + } + + file.seekg(0, std::ios::end); + std::streamsize filesize = file.tellg(); + file.seekg(0, std::ios::beg); + + std::vector buffer(filesize); + if (!file.read(reinterpret_cast(buffer.data()), filesize)) { + throw std::runtime_error("Could not read file: " + PathToUTF8String(filename)); + } + + return buffer; +} + +struct CompileParam { + bool embed_mode; + bool bytestream_io; + bool external_initialzier_for_parser = false; + const std::string to_string() const { + return "embed_mode_" + std::to_string(embed_mode) + "_bytestream_io_" + std::to_string(bytestream_io) + "_ext_init_" + std::to_string(external_initialzier_for_parser); + ; + } +}; +class CompileApiTest + : public testing::TestWithParam { + public: + const CompileParam& GetCompileParam() const { + return GetParam(); + } +}; + +void SmallModelTest(CompileParam test_param, bool fully_supported_model) { + std::string test_name = test_param.to_string(); + if (!fully_supported_model) + test_name += "_fast_gelu"; + PathString model_name = path_utils::MakePathString("nv_execution_provider_compile_" + test_name + ".onnx"); + PathString model_name_ctx = path_utils::MakePathString("nv_execution_provider_compile_" + test_name + "_ctx.onnx"); + clearFileIfExists(model_name_ctx); + std::string graph_name = "test"; + std::vector dims = {1, 3, 2}; + + CreateBaseModel(model_name, graph_name, dims, !fully_supported_model); + + Ort::SessionOptions session_options; + std::unordered_map option_map{ + {onnxruntime::nv::provider_option_names::kUseExternalDataInitializer, std::to_string(test_param.external_initialzier_for_parser)}}; + auto ep = AppendTrtEtxEP(session_options, option_map); + + Ort::ModelCompilationOptions model_compile_options(*ort_env, session_options); + model_compile_options.SetEpContextEmbedMode(test_param.embed_mode); + + void* output_context = nullptr; + size_t output_context_size = 0; + std::vector input_onnx; + if (test_param.bytestream_io) { + input_onnx = readBinaryFile(model_name); + model_compile_options.SetInputModelFromBuffer(input_onnx.data(), input_onnx.size()); + model_compile_options.SetOutputModelBuffer(Ort::AllocatorWithDefaultOptions(), &output_context, &output_context_size); + } else { + model_compile_options.SetInputModelPath(model_name.c_str()); + model_compile_options.SetOutputModelPath(model_name_ctx.c_str()); + } + // AOT time + ASSERT_TRUE(Ort::CompileModel(*ort_env, model_compile_options).IsOK()); + + // JIT time + Ort::Session session_object{nullptr}; + if (test_param.bytestream_io) { + session_object = Ort::Session(*ort_env, output_context, output_context_size, session_options); + } else { + session_object = Ort::Session(*ort_env, model_name_ctx.c_str(), session_options); + } + auto io_binding = generate_io_binding(session_object); + Ort::RunOptions run_options; + 
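// Run the freshly created EP-context model once through the generated I/O binding to make sure the engine actually executes. +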
session_object.Run(run_options, io_binding); +} + +TEST_P(CompileApiTest, SmallModel) { + const auto& test_param = GetCompileParam(); + SmallModelTest(test_param, true); +} + +TEST_P(CompileApiTest, SmallSplitModel) { + const auto& test_param = GetCompileParam(); + SmallModelTest(test_param, false); +} + +TEST_P(CompileApiTest, LargeModel) { + const auto& test_param = GetCompileParam(); + // with embed mode == 1 the resulting file will be over the 2GB proto limit + if (test_param.embed_mode == 1) { + GTEST_SKIP(); + } + std::string test_name = test_param.to_string(); + PathString model_name = path_utils::MakePathString("nv_execution_provider_compile_large_" + test_name + ".onnx"); + PathString external_data_name = path_utils::MakePathString("nv_execution_provider_compile_large_" + test_name + ".onnx_data"); + PathString model_name_ctx = path_utils::MakePathString("nv_execution_provider_compile_large_" + test_name + "_ctx.onnx"); + PathString model_name_ctx_data = path_utils::MakePathString("nv_execution_provider_compile_large_" + test_name + "_ctx.onnx_data"); + clearFileIfExists(model_name_ctx); + clearFileIfExists(model_name_ctx_data); + // This accelerates test iterations if the large model was already generated + if (!std::filesystem::exists(model_name) || !std::filesystem::exists(external_data_name)) { + CreateLargeLLMModel(model_name, external_data_name); + } + + Ort::SessionOptions session_options; + std::unordered_map option_map{ + {onnxruntime::nv::provider_option_names::kUseExternalDataInitializer, + std::to_string(test_param.bytestream_io || test_param.external_initialzier_for_parser)}}; + auto ep = AppendTrtEtxEP(session_options, option_map); + + Ort::ModelCompilationOptions model_compile_options(*ort_env, session_options); + model_compile_options.SetEpContextEmbedMode(test_param.embed_mode); + + void* output_context = nullptr; + size_t output_context_size = 0; + std::vector input_onnx, input_data; + std::vector file_names; + std::vector file_buffers; + std::vector lengths; + if (test_param.bytestream_io) { + input_onnx = readBinaryFile(model_name); + input_data = readBinaryFile(external_data_name); + file_names = {external_data_name}; + file_buffers = {input_data.data()}; + lengths = {input_data.size()}; + session_options.AddExternalInitializersFromFilesInMemory(file_names, file_buffers, lengths); + + model_compile_options.SetInputModelFromBuffer(input_onnx.data(), input_onnx.size()); + model_compile_options.SetOutputModelBuffer(Ort::AllocatorWithDefaultOptions(), &output_context, &output_context_size); + } else { + model_compile_options.SetInputModelPath(model_name.c_str()); + model_compile_options.SetOutputModelPath(model_name_ctx.c_str()); + model_compile_options.SetOutputModelExternalInitializersFile(model_name_ctx_data.c_str(), 1024); + } + + // AOT time + ASSERT_TRUE(Ort::CompileModel(*ort_env, model_compile_options).IsOK()); + + // JIT time + std::unique_ptr session; + if (test_param.bytestream_io) { + session = std::make_unique(*ort_env, output_context, output_context_size, session_options); + } else { + session = std::make_unique(*ort_env, model_name_ctx.c_str(), session_options); + } + + auto io_binding = generate_io_binding(*session); + Ort::RunOptions run_options; + session->Run(run_options, io_binding); +} + +INSTANTIATE_TEST_SUITE_P( + NvExecutionProviderTest, CompileApiTest, + ::testing::Values( + CompileParam{true, false}, + CompileParam{false, false}, + CompileParam{true, true}, + CompileParam{false, true}, + // test with external initializers for parser + 
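// parameter order: {embed_mode, bytestream_io, external_initialzier_for_parser} +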
CompileParam{true, true, true}, + CompileParam{true, false, true}), + [](const testing::TestParamInfo& info) { + return info.param.to_string(); + }); + +} // namespace test +} // namespace onnxruntime diff --git a/onnxruntime/test/providers/nv_tensorrt_rtx/test_nv_trt_rtx_ep_util.cc b/onnxruntime/test/providers/nv_tensorrt_rtx/test_nv_trt_rtx_ep_util.cc index f0ce5c0b296ca..17182ab032f7a 100644 --- a/onnxruntime/test/providers/nv_tensorrt_rtx/test_nv_trt_rtx_ep_util.cc +++ b/onnxruntime/test/providers/nv_tensorrt_rtx/test_nv_trt_rtx_ep_util.cc @@ -3,18 +3,26 @@ // Licensed under the MIT License. // registration/selection is only supported on windows as there's no device discovery on other platforms -#ifdef _WIN32 #include "test/providers/nv_tensorrt_rtx/test_nv_trt_rtx_ep_util.h" #include +#include #include #include "core/session/onnxruntime_cxx_api.h" #include "test/util/include/api_asserts.h" +#include "core/graph/basic_types.h" +#include "core/graph/onnx_protobuf.h" +#include "core/graph/model_saving_options.h" +#include "test/util/include/scoped_env_vars.h" +#include "test/common/trt_op_test_utils.h" +#include "test/providers/provider_test_utils.h" +#include "test/framework/test_utils.h" namespace onnxruntime { namespace test { +#ifdef _WIN32 Utils::NvTensorRtRtxEpInfo Utils::nv_tensorrt_rtx_ep_info; @@ -51,8 +59,410 @@ void Utils::RegisterAndGetNvTensorRtRtxEp(Ort::Env& env, RegisteredEpDeviceUniqu c_api.UnregisterExecutionProviderLibrary(env, nv_tensorrt_rtx_ep_info.registration_name.c_str()); }); } +#endif // _WIN32 + +void CreateBaseModel(const PathString& model_name, + std::string graph_name, + std::vector dims, + bool add_fast_gelu, + ONNX_NAMESPACE::TensorProto_DataType dtype, + const PathString& external_initializer_file) { + onnxruntime::Model model(graph_name, false, DefaultLoggingManager().DefaultLogger()); + auto& graph = model.MainGraph(); + std::vector inputs; + std::vector outputs; + + // FLOAT tensor + ONNX_NAMESPACE::TypeProto float_tensor; + float_tensor.mutable_tensor_type()->set_elem_type(dtype); + + for (auto dim : dims) { + float_tensor.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(dim); + } + ONNX_NAMESPACE::TypeProto dyn_float_tensor; + dyn_float_tensor.mutable_tensor_type()->set_elem_type(dtype); + + auto& input_arg_1 = graph.GetOrCreateNodeArg("X", &float_tensor); + auto& input_arg_2 = graph.GetOrCreateNodeArg("Y", &float_tensor); + inputs.push_back(&input_arg_1); + inputs.push_back(&input_arg_2); + auto& output_arg = graph.GetOrCreateNodeArg("node_1_out_1", &float_tensor); + outputs.push_back(&output_arg); + graph.AddNode("node_1", "Add", "node 1.", inputs, outputs); + + auto& input_arg_3 = graph.GetOrCreateNodeArg("Z", &float_tensor); + inputs.clear(); + inputs.push_back(&output_arg); + inputs.push_back(&input_arg_3); + + auto& output_arg_2 = graph.GetOrCreateNodeArg("node_2_out_1", &float_tensor); + outputs.clear(); + outputs.push_back(&output_arg_2); + graph.AddNode("node_2", "Add", "node 2.", inputs, outputs); + + inputs.clear(); + inputs.push_back(&output_arg_2); + + if (add_fast_gelu) { + auto& output_arg_3 = graph.GetOrCreateNodeArg("node_3_out_1", &dyn_float_tensor); + outputs.clear(); + outputs.push_back(&output_arg_3); + + graph.AddNode("node_3", "FastGelu", "node 3.", inputs, outputs, + /* attributes */ nullptr, kMSDomain); + + inputs.clear(); + inputs.push_back(&output_arg_3); + } + + ONNX_NAMESPACE::TypeProto float_scalar; + float_scalar.mutable_tensor_type()->set_elem_type(dtype); + 
float_scalar.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(1); + auto& input_scalar = graph.GetOrCreateNodeArg("S", &float_scalar); + inputs.push_back(&input_scalar); + + auto& output_arg_4 = graph.GetOrCreateNodeArg("O", &dyn_float_tensor); + + outputs.clear(); + outputs.push_back(&output_arg_4); + graph.AddNode("node_5", "Add", "node 5.", inputs, outputs); + + auto status = graph.Resolve(); + ASSERT_TRUE(status.IsOK()); + if (!external_initializer_file.empty()) { + ModelSavingOptions save_options(128); + status = Model::SaveWithExternalInitializers(model, model_name, external_initializer_file, save_options); + } else { + status = Model::Save(model, model_name); + } + ASSERT_TRUE(status.IsOK()); +} + +// Helper to create large initializers +ONNX_NAMESPACE::TensorProto CreateLargeWeight( + const std::string& name, + ONNX_NAMESPACE::TensorProto_DataType dtype, + const std::vector& shape, + float scale = 0.02f) { + ONNX_NAMESPACE::TensorProto tensor; + tensor.set_name(name); + tensor.set_data_type(dtype); + for (auto d : shape) tensor.add_dims(d); + // Here we fill with random floats, but for real data, use your trained weights. + size_t total_size = 1; + for (int64_t d : shape) total_size *= d; + std::random_device rd; + std::default_random_engine rng(rd()); + if (dtype == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) { + std::vector data(total_size); + std::normal_distribution dist(0.0f, scale); + for (auto& v : data) v = dist(rng); + tensor.set_raw_data(data.data(), total_size * sizeof(float)); + } else if (dtype == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) { + std::vector data(total_size); + std::normal_distribution dist(0.0f, scale); + for (auto& v : data) v = MLFloat16(dist(rng)); + tensor.set_raw_data(data.data(), total_size * sizeof(MLFloat16)); + } else { + throw std::runtime_error("Unsupported data type for large weight"); + } + return tensor; +} + +// Helper to add a GroupQueryAttention node +onnxruntime::NodeArg& AddGroupQueryAttention( + onnxruntime::Graph& graph, + onnxruntime::NodeArg& query, + onnxruntime::NodeArg& key, + onnxruntime::NodeArg& value, + int batch_size, + int head_dim, + int seq_len, + int num_heads, + int kv_num_heads, + float scale, + ONNX_NAMESPACE::TensorProto_DataType dtype, + const std::string& node_name) { + // KV cache + ONNX_NAMESPACE::TypeProto key_type; + key_type.mutable_tensor_type()->set_elem_type(dtype); + key_type.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(batch_size); + key_type.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(kv_num_heads); + key_type.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(seq_len); + key_type.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(head_dim); + auto& past_key = graph.GetOrCreateNodeArg(node_name + "_past_key", &key_type); + + ONNX_NAMESPACE::TypeProto value_type; + value_type.mutable_tensor_type()->set_elem_type(dtype); + value_type.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(batch_size); + value_type.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(kv_num_heads); + value_type.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(seq_len); + value_type.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(head_dim); + auto& past_value = graph.GetOrCreateNodeArg(node_name + "_past_value", &value_type); + + // Output + auto& output = graph.GetOrCreateNodeArg(node_name + "_output", nullptr); + + // Create required initializers for GroupQueryAttention + 
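// seqlens_k (input 5) and total_sequence_length (input 6) are required INT32 inputs of the contrib op; they are added as initializers here and looked up again as NodeArgs below. +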
ONNX_NAMESPACE::TensorProto seqlens_k_tensor; + seqlens_k_tensor.set_name(node_name + "_seqlens_k"); + seqlens_k_tensor.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT32); + seqlens_k_tensor.add_dims(2); + seqlens_k_tensor.set_dims(0, batch_size); + seqlens_k_tensor.set_dims(0, 1); + seqlens_k_tensor.add_int32_data(seq_len - 1); // seqlens_k = total_sequence_length - 1 + graph.AddInitializedTensor(seqlens_k_tensor); + + ONNX_NAMESPACE::TensorProto total_seq_len_tensor; + total_seq_len_tensor.set_name(node_name + "_total_sequence_length"); + total_seq_len_tensor.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT32); + total_seq_len_tensor.add_int32_data(seq_len); + graph.AddInitializedTensor(total_seq_len_tensor); + + // Get the initializers that were created for this node + auto* seqlens_k = graph.GetNodeArg(node_name + "_seqlens_k"); + auto* total_sequence_length = graph.GetNodeArg(node_name + "_total_sequence_length"); + + auto& present_value = graph.GetOrCreateNodeArg(node_name + "_present_value", nullptr); + auto& present_key = graph.GetOrCreateNodeArg(node_name + "_present_key", nullptr); + + // Inputs - GroupQueryAttention requires at least 7 inputs (query, key, value, past_key, past_value, seqlens_k, total_sequence_length) + std::vector inputs = { + &query, // 0: query + &key, // 1: key + &value, // 2: value + &past_key, // 3: past_key (optional) + &past_value, // 4: past_value (optional) + seqlens_k, // 5: seqlens_k (required) + total_sequence_length, // 6: total_sequence_length (required) + // nullptr, // 7: cos_cache (optional) + // nullptr, // 8: sin_cache (optional) + // nullptr, // 9: position_ids (optional) + // nullptr, // 10: attention_bias (optional) + // nullptr // 11: head_sink (optional) + }; + + // Attributes + NodeAttributes attrs; + ONNX_NAMESPACE::AttributeProto attr_heads; + attr_heads.set_name("num_heads"); + attr_heads.set_type(onnx::AttributeProto_AttributeType_INT); + attr_heads.set_i(num_heads); + attrs["num_heads"] = attr_heads; + ONNX_NAMESPACE::AttributeProto attr_kv_num_heads; + attr_kv_num_heads.set_name("kv_num_heads"); + attr_kv_num_heads.set_type(onnx::AttributeProto_AttributeType_INT); + attr_kv_num_heads.set_i(kv_num_heads); + attrs["kv_num_heads"] = attr_kv_num_heads; + ONNX_NAMESPACE::AttributeProto attr_scale; + attr_scale.set_name("scale"); + attr_scale.set_type(onnx::AttributeProto_AttributeType_FLOAT); + attr_scale.set_f(scale); + attrs["scale"] = attr_scale; + + // Register node + graph.AddNode( + node_name, + "GroupQueryAttention", + "GroupQueryAttention Node", + inputs, + {&output, &present_key, &present_value}, + &attrs, + "com.microsoft"); + + return output; +} + +void CreateLargeLLMModel(const PathString& model_path, const PathString& external_data_path) { + // Model parameters (example: 24 layers, 4096 hidden dim, 32 attention heads, 8 kv heads => GQA) + int batch_size = 1; + int num_layers = 32; + int hidden_dim = 2048; + int q_num_heads = 8; + int kv_num_heads = 1; // GQA: q_num_heads > kv_num_heads, and divisible. + int seq_length = 128; // Short, for demonstration. 
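+ // Note: at these sizes the fp16 MatMul weights alone add up to roughly 2.5 GB, beyond the 2 GB protobuf limit, which is why the model is saved with external data and the LargeModel test skips embed mode 1.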
+ int vocab_size = 32000; + auto dtype = ONNX_NAMESPACE::TensorProto_DataType_FLOAT16; + + // Set up model/graph + onnxruntime::Model model("LLM_With_GQA", false, DefaultLoggingManager().DefaultLogger()); + auto& graph = model.MainGraph(); + + // Input + ONNX_NAMESPACE::TypeProto input_type; + input_type.mutable_tensor_type()->set_elem_type(dtype); + input_type.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(batch_size); + input_type.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(seq_length); + input_type.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(hidden_dim); + auto& input = graph.GetOrCreateNodeArg("input", &input_type); + + auto* current_arg = &input; + + // Repeated layers: [Attention + MLP] + for (int l = 0; l < num_layers; ++l) { + // KV cache - initialize with zeros for the first forward pass + int head_dim = hidden_dim / q_num_heads; + + // Split Q, K, V + auto& q_split = graph.GetOrCreateNodeArg("q_split_" + std::to_string(l), nullptr); + auto& k_split = graph.GetOrCreateNodeArg("k_split_" + std::to_string(l), nullptr); + auto& v_split = graph.GetOrCreateNodeArg("v_split_" + std::to_string(l), nullptr); + constexpr bool split = false; + if constexpr (split) { + // Attention weights (Q, K, V projections) + auto wqkv = CreateLargeWeight("wqkv_" + std::to_string(l), + dtype, {hidden_dim, hidden_dim * 3}); + graph.AddInitializedTensor(wqkv); + + // Q = input @ wq, K = input @ wk, V = input @ wv + auto& qkv_arg = graph.GetOrCreateNodeArg("qkv_" + std::to_string(l), nullptr); + graph.AddNode("QKV_Linear_" + std::to_string(l), "MatMul", "", {current_arg, graph.GetNodeArg(wqkv.name())}, {&qkv_arg}); + + NodeAttributes attrs_split; + ONNX_NAMESPACE::AttributeProto attr_split_axis; + attr_split_axis.set_name("axis"); + attr_split_axis.set_type(onnx::AttributeProto_AttributeType_INT); + attr_split_axis.set_i(-1); + attrs_split["axis"] = attr_split_axis; + ONNX_NAMESPACE::AttributeProto attr_split_num_outputs; + attr_split_num_outputs.set_name("num_outputs"); + attr_split_num_outputs.set_type(onnx::AttributeProto_AttributeType_INT); + attr_split_num_outputs.set_i(3); + attrs_split["num_outputs"] = attr_split_num_outputs; + graph.AddNode("Q_Split_" + std::to_string(l), "Split", "", {&qkv_arg}, {&q_split, &k_split, &v_split}, &attrs_split); + } else { + // Attention weights (Q, K, V projections) + auto wq = CreateLargeWeight("wq_" + std::to_string(l), + dtype, {hidden_dim, hidden_dim}); + graph.AddInitializedTensor(wq); + auto wk = CreateLargeWeight("wk_" + std::to_string(l), + dtype, {hidden_dim, head_dim * kv_num_heads}); + graph.AddInitializedTensor(wk); + auto wv = CreateLargeWeight("wv_" + std::to_string(l), + dtype, {hidden_dim, head_dim * kv_num_heads}); + graph.AddInitializedTensor(wv); + + // Q = input @ wq, K = input @ wk, V = input @ wv + graph.AddNode("Q_Linear_" + std::to_string(l), "MatMul", "", {current_arg, graph.GetNodeArg(wq.name())}, {&q_split}); + graph.AddNode("K_Linear_" + std::to_string(l), "MatMul", "", {current_arg, graph.GetNodeArg(wk.name())}, {&k_split}); + graph.AddNode("V_Linear_" + std::to_string(l), "MatMul", "", {current_arg, graph.GetNodeArg(wv.name())}, {&v_split}); + } + // Reshape Q, K, V + auto& q_reshaped = graph.GetOrCreateNodeArg("q_reshaped_" + std::to_string(l), nullptr); + auto& k_reshaped = graph.GetOrCreateNodeArg("k_reshaped_" + std::to_string(l), nullptr); + auto& v_reshaped = graph.GetOrCreateNodeArg("v_reshaped_" + std::to_string(l), nullptr); + + ONNX_NAMESPACE::TensorProto 
q_shape_tensor; + q_shape_tensor.set_name("q_shape_" + std::to_string(l)); + q_shape_tensor.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64); + q_shape_tensor.add_dims(3); + q_shape_tensor.add_int64_data(batch_size); + q_shape_tensor.add_int64_data(seq_length); + q_shape_tensor.add_int64_data(head_dim * q_num_heads); + graph.AddInitializedTensor(q_shape_tensor); + + ONNX_NAMESPACE::TensorProto k_shape_tensor; + k_shape_tensor.set_name("k_shape_" + std::to_string(l)); + k_shape_tensor.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64); + k_shape_tensor.add_dims(3); + k_shape_tensor.add_int64_data(batch_size); + k_shape_tensor.add_int64_data(seq_length); + k_shape_tensor.add_int64_data(head_dim * kv_num_heads); + graph.AddInitializedTensor(k_shape_tensor); + + ONNX_NAMESPACE::TensorProto v_shape_tensor; + v_shape_tensor.set_name("v_shape_" + std::to_string(l)); + v_shape_tensor.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64); + v_shape_tensor.add_dims(3); + v_shape_tensor.add_int64_data(batch_size); + v_shape_tensor.add_int64_data(seq_length); + v_shape_tensor.add_int64_data(head_dim * kv_num_heads); + graph.AddInitializedTensor(v_shape_tensor); + + graph.AddNode("Q_Reshape_" + std::to_string(l), "Reshape", "", {&q_split, graph.GetNodeArg(q_shape_tensor.name())}, {&q_reshaped}); + graph.AddNode("K_Reshape_" + std::to_string(l), "Reshape", "", {&k_split, graph.GetNodeArg(k_shape_tensor.name())}, {&k_reshaped}); + graph.AddNode("V_Reshape_" + std::to_string(l), "Reshape", "", {&v_split, graph.GetNodeArg(v_shape_tensor.name())}, {&v_reshaped}); + + // Replace standard attention with GQA + auto& attn_out = AddGroupQueryAttention( + graph, q_reshaped, k_reshaped, v_reshaped, + batch_size, head_dim, seq_length, q_num_heads, kv_num_heads, + 1.0f, dtype, + "GQA_" + std::to_string(l)); + + // Add an MLP block: (Linear + Activation + Linear) + auto w1 = CreateLargeWeight("mlp_w1_" + std::to_string(l), dtype, {hidden_dim, hidden_dim * 4}); + auto w2 = CreateLargeWeight("mlp_w2_" + std::to_string(l), dtype, {hidden_dim * 4, hidden_dim}); + graph.AddInitializedTensor(w1); + graph.AddInitializedTensor(w2); + + auto& mlp_hidden = graph.GetOrCreateNodeArg("mlp_hidden_" + std::to_string(l), nullptr); + graph.AddNode("MLP_1_" + std::to_string(l), "MatMul", "", {&attn_out, graph.GetNodeArg(w1.name())}, {&mlp_hidden}); + auto& relu_out = graph.GetOrCreateNodeArg("relu_" + std::to_string(l), nullptr); + graph.AddNode("Relu_" + std::to_string(l), "Relu", "", {&mlp_hidden}, {&relu_out}); + auto& mlp_out = graph.GetOrCreateNodeArg("mlp_out_" + std::to_string(l), nullptr); + graph.AddNode("MLP_2_" + std::to_string(l), "MatMul", "", {&relu_out, graph.GetNodeArg(w2.name())}, {&mlp_out}); + current_arg = &mlp_out; // For next layer. 
+ } + + // Final projection to vocab + auto w_logits = CreateLargeWeight("w_logits", + dtype, {hidden_dim, vocab_size}); + graph.AddInitializedTensor(w_logits); + auto& output = graph.GetOrCreateNodeArg("logits", nullptr); + graph.AddNode("Output_Linear", "MatMul", "", {current_arg, graph.GetNodeArg(w_logits.name())}, {&output}); + + // Validate, Write as large model with external data + auto status = graph.Resolve(); + if (!status.IsOK()) throw std::runtime_error(status.ErrorMessage()); + + onnxruntime::ModelSavingOptions save_options(128); + status = onnxruntime::Model::SaveWithExternalInitializers( + model, model_path, external_data_path, save_options); + if (!status.IsOK()) throw std::runtime_error(status.ErrorMessage()); +} + +Ort::IoBinding generate_io_binding( + Ort::Session& session, + std::map> shape_overwrites, + OrtAllocator* allocator) { + Ort::IoBinding binding(session); + auto default_allocator = Ort::AllocatorWithDefaultOptions(); + if (allocator == nullptr) { + allocator = default_allocator; + } + const OrtMemoryInfo* info; + Ort::ThrowOnError(Ort::GetApi().AllocatorGetInfo(allocator, &info)); + Ort::MemoryInfo mem_info(info->name, info->alloc_type, info->device.Id(), info->mem_type); + + for (int input_idx = 0; input_idx < int(session.GetInputCount()); ++input_idx) { + auto input_name = session.GetInputNameAllocated(input_idx, Ort::AllocatorWithDefaultOptions()); + auto full_tensor_info = session.GetInputTypeInfo(input_idx); + auto tensor_info = full_tensor_info.GetTensorTypeAndShapeInfo(); + auto shape = tensor_info.GetShape(); + auto type = tensor_info.GetElementType(); + if (shape_overwrites.find(input_name.get()) == shape_overwrites.end()) { + for (auto& v : shape) { + if (v == -1) { + v = 1; + } + } + } else { + shape = shape_overwrites[input_name.get()]; + } + auto input_value = Ort::Value::CreateTensor(allocator, + shape.data(), + shape.size(), + type); + binding.BindInput(input_name.get(), input_value); + } + + for (int output_idx = 0; output_idx < int(session.GetOutputCount()); ++output_idx) { + auto output_name = session.GetOutputNameAllocated(output_idx, Ort::AllocatorWithDefaultOptions()); + binding.BindOutput(output_name.get(), mem_info); + } + return binding; +} } // namespace test } // namespace onnxruntime - -#endif // _WIN32 diff --git a/onnxruntime/test/providers/nv_tensorrt_rtx/test_nv_trt_rtx_ep_util.h b/onnxruntime/test/providers/nv_tensorrt_rtx/test_nv_trt_rtx_ep_util.h index ef14d3cb382c0..0f011af8211ca 100644 --- a/onnxruntime/test/providers/nv_tensorrt_rtx/test_nv_trt_rtx_ep_util.h +++ b/onnxruntime/test/providers/nv_tensorrt_rtx/test_nv_trt_rtx_ep_util.h @@ -5,9 +5,21 @@ #include #include +#include +#include + +#include +#include +#include +#include +#include +#include -#include "core/session/onnxruntime_cxx_api.h" #include "core/graph/constants.h" +#include "core/common/path_string.h" +#include "core/framework/tensor.h" +#include "core/framework/ort_value.h" +#include "test/util/include/api_asserts.h" namespace onnxruntime { namespace test { @@ -17,7 +29,7 @@ using RegisteredEpDeviceUniquePtr = std::unique_ptr> converter; + return converter.to_bytes(path); +#else + return path.c_str(); +#endif +} + +[[maybe_unused]] static void clearFileIfExists(PathString path) { + if (std::filesystem::exists(path)) { + std::filesystem::remove(path); + } +} + +template +static void VerifyOutputs(const std::vector& fetches, const std::vector& expected_dims, + const std::vector& expected_values) { + ASSERT_EQ(1, fetches.size()); + auto& rtensor = 
fetches.front().Get(); + TensorShape expected_shape(expected_dims); + ASSERT_EQ(expected_shape, rtensor.Shape()); + const std::vector found(rtensor.Data(), rtensor.Data() + expected_values.size()); + ASSERT_EQ(expected_values, found); +} + +/** + * Create a simple model with dynamic or non-dynamic input shape. + * \param model_name - model name + * \param graph_name - graph name + * \param dims - input dimensions + * \param add_fast_gelu - add FastGelu node which makes the whole model partition into TRT EP and CUDA EP subgraphs. + * \param external_initializer_file - file name to save external initializers to + * + * input: "X", "Y" and "Z" + * you can specify input dimensions, for example (1, 3, 2), (1, 2) or (1, -1, -1)). Note: -1 means the dimension is dynamic. + * All three inputs have the same dimensions. + * output: "M" + * + * "X" "Y" + * \ / + * "Z" Add + * \ / + * Add + * / + * Add (+ float scalar "S") + * / + * "O" + * + * or + * + * "X" "Y" + * \ / + * "Z" Add + * \ / + * Add + * / + * FastGelu (This node will be placed on CUDA EP) + * / + * * Add (+ float scalar "S") + * / + * "O" + */ +void CreateBaseModel(const PathString& model_name, + std::string graph_name, + std::vector dims, + bool add_fast_gelu = false, + ONNX_NAMESPACE::TensorProto_DataType dtype = ONNX_NAMESPACE::TensorProto_DataType_FLOAT, + const PathString& external_initializer_file = {}); + +void CreateLargeLLMModel(const PathString& model_path, const PathString& external_data_path); + +Ort::IoBinding generate_io_binding( + Ort::Session& session, + std::map> shape_overwrites = {}, + OrtAllocator* allocator = nullptr); + } // namespace test } // namespace onnxruntime diff --git a/tools/ci_build/build.py b/tools/ci_build/build.py index dd3e096c0334b..bf89ff2010ec5 100644 --- a/tools/ci_build/build.py +++ b/tools/ci_build/build.py @@ -1708,8 +1708,10 @@ def run_onnxruntime_tests(args, source_dir, ctest_path, build_dir, configs): run_ios_tests(args, source_dir, config, cwd) continue dll_path_list = [] - if args.use_tensorrt or args.use_nv_tensorrt_rtx: + if args.use_tensorrt: dll_path_list.append(os.path.join(args.tensorrt_home, "lib")) + if args.use_nv_tensorrt_rtx: + dll_path_list.append(os.path.join(args.tensorrt_rtx_home, "lib")) dll_path = None if len(dll_path_list) > 0: