Merged

25 commits
d9878a3
implement GetEPContextNodes()
thevishalagarwal May 28, 2025
bde3ce5
clean up
thevishalagarwal May 28, 2025
7165dfe
rebase to latest
thevishalagarwal May 29, 2025
7b1f5bc
remove ctx model to just add node
thevishalagarwal Jun 4, 2025
03c42fd
update GetCapabilities for multiple EP Context Nodes
thevishalagarwal Jun 9, 2025
43ac5d5
fix lint
thevishalagarwal Jun 11, 2025
313f4ce
add support for TRT external weights API
thevishalagarwal Jun 18, 2025
a7eadab
add new changes
thevishalagarwal Jul 14, 2025
23c6393
update external initializer fix
thevishalagarwal Jul 23, 2025
7b1320e
fix EP name
thevishalagarwal Jul 24, 2025
3b039fa
reorganize unittest helpers
gedoensmax Jul 31, 2025
28d211e
fix type tests
gedoensmax Jul 31, 2025
bd3d4ed
basic EP context support
gedoensmax Jul 31, 2025
d5151d7
large model test
gedoensmax Aug 1, 2025
8de13f9
remove support for weightless
gedoensmax Aug 1, 2025
e2b67a4
reduce header usages, cleanup and unify usage of windows ifdef
gedoensmax Aug 5, 2025
3728ce0
address review comments
gedoensmax Aug 8, 2025
80574d2
fix engine cache path with EP context
thevishalagarwal Aug 12, 2025
3996d9b
fix unit test to add seed for random tensors
thevishalagarwal Aug 12, 2025
a4f8c45
support sm86 and onwards RTX devices
thevishalagarwal Aug 13, 2025
b0bab1c
update cc check
thevishalagarwal Aug 13, 2025
f762d86
fix lint
thevishalagarwal Aug 15, 2025
b81a6d5
do not copy memory to EP owned memory for raw initializers
gedoensmax Aug 17, 2025
d0926f8
use ort values if data is already loaded in memory
gedoensmax Aug 20, 2025
f640430
remove unused var
gedoensmax Aug 20, 2025
@@ -32,9 +32,8 @@ constexpr const char* kProfilesMinShapes = "nv_profile_min_shapes";
constexpr const char* kProfilesMaxShapes = "nv_profile_max_shapes";
constexpr const char* kProfilesOptShapes = "nv_profile_opt_shapes";
constexpr const char* kCudaGraphEnable = "nv_cuda_graph_enable";
constexpr const char* kONNXBytestream = "nv_onnx_bytestream";
constexpr const char* kONNXBytestreamSize = "nv_onnx_bytestream_size";
constexpr const char* kMultiProfileEnable = "nv_multi_profile_enable";
constexpr const char* kUseExternalDataInitializer = "nv_use_external_data_initializer";

} // namespace provider_option_names
namespace run_option_names {
442 changes: 338 additions & 104 deletions onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc

Large diffs are not rendered by default.
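The diff hidden above carries the bulk of this PR: compiling fused subgraphs into TensorRT engines, wrapping each engine in an EPContext node, and returning those nodes through the new GetEpContextNodes() override. For orientation, a minimal sketch of ONNX Runtime's EPContext node convention, written against the raw onnx protobufs rather than the Graph API the provider actually uses; serialized_engine is a hypothetical std::string holding the engine blob.

#include <string>
#include "onnx/onnx_pb.h"

onnx::NodeProto MakeEpContextNode(const std::string& serialized_engine) {
  onnx::NodeProto node;
  node.set_op_type("EPContext");   // ORT's EP-context op
  node.set_domain("com.microsoft");
  node.add_input("input");         // illustrative graph I/O names
  node.add_output("output");

  // embed_mode: 1 = engine bytes stored inline, 0 = ep_cache_context holds a path
  auto* embed = node.add_attribute();
  embed->set_name("embed_mode");
  embed->set_type(onnx::AttributeProto::INT);
  embed->set_i(1);

  // ep_cache_context: the serialized engine (or its cache path when embed_mode=0)
  auto* blob = node.add_attribute();
  blob->set_name("ep_cache_context");
  blob->set_type(onnx::AttributeProto::STRING);
  blob->set_s(serialized_engine);
  return node;
}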

50 changes: 46 additions & 4 deletions onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.h
@@ -153,6 +153,41 @@
}
};

// Data structure to hold user weights when ModelProtos are serialized with external data
class TensorrtUserWeights {
public:
TensorrtUserWeights(const std::string& name, const std::string& data) : name_(name),
data_cpy_(data) {
};

TensorrtUserWeights(const std::string& name, const void* data, size_t size) : name_(name), data_(data), size_(size) {
};

const char* Name() const {
return name_.c_str();
};

const void* Data() const {
if (!data_cpy_.empty()) {
return data_cpy_.data();
}
return data_;
}

int64_t Size() const {
if (!data_cpy_.empty()) {
return static_cast<int64_t>(data_cpy_.size());
}
return static_cast<int64_t>(size_);
}

private:
std::string name_{};
std::string data_cpy_{};
void const* data_;
size_t size_;
};

// Information to construct kernel function state.
struct TensorrtFuncState {
AllocateFunc test_allocate_func = nullptr;
@@ -168,7 +203,6 @@
std::vector<std::unordered_map<std::string, size_t>> output_info;
std::unordered_map<std::string, std::unordered_map<size_t, std::vector<std::vector<int64_t>>>> input_shape_ranges;
std::mutex* tensorrt_mu_ptr = nullptr;
std::string trt_node_name_with_precision;
bool engine_cache_enable = false;
std::string engine_cache_path;
nvinfer1::IRuntime* runtime = nullptr;
@@ -183,6 +217,7 @@
bool is_dynamic_shape = false;
std::string cache_prefix;
std::string cache_suffix;
// runtime parameters
std::vector<IAllocatorUniquePtr<void>> scratch_buffers;
std::vector<TensorParams> input_tensors;
std::vector<TensorParams> output_tensors;
@@ -204,6 +239,7 @@
std::vector<std::unordered_map<std::string, size_t>> output_info;
std::mutex* tensorrt_mu_ptr = nullptr;
bool is_dynamic_shape = false;
// runtime parameters
std::vector<IAllocatorUniquePtr<void>> scratch_buffers;
std::vector<TensorParams> input_tensors;
std::vector<TensorParams> output_tensors;
@@ -275,14 +311,16 @@

static common::Status RefitEngine(std::string onnx_model_filename,
std::string& onnx_model_folder_path,
std::string& weight_stripped_engine_cath_path,
bool path_check,
const void* onnx_model_bytestream,
size_t onnx_model_bytestream_size,
const void* onnx_external_data_bytestream,
size_t onnx_external_data_bytestream_size,
nvinfer1::ICudaEngine* trt_engine,
bool serialize_refitted_engine,
bool detailed_build_log);

const InlinedVector<const Node*> GetEpContextNodes() const override;

private:
mutable NvExecutionProviderInfo info_;
bool external_stream_ = false;
@@ -299,6 +337,9 @@
std::string onnx_model_folder_path_;
const void* onnx_model_bytestream_;
size_t onnx_model_bytestream_size_;
bool use_external_data_initializer_ = false;
const void* onnx_external_data_bytestream_ = nullptr;
size_t onnx_external_data_bytestream_size_ = 0;
bool sparsity_enable_ = false;
int auxiliary_streams_ = -1;
std::string cache_path_, engine_decryption_lib_path_;
@@ -317,6 +358,7 @@
std::string cache_prefix_;
std::string op_types_to_exclude_;
int nv_profile_index_ = 0;
std::unique_ptr<onnxruntime::Model> ep_context_model_;

// The format is as for TENSORRT_VERSION: (MAJOR * 100 + MINOR) * 100 + PATCH
int32_t trt_version_;
@@ -331,7 +373,6 @@
std::string ep_context_file_path_;
int ep_context_embed_mode_ = 0;
std::string ctx_model_path_;
std::string ep_cache_context_attr_;
std::string engine_cache_relative_path_to_context_model_dir;

std::unordered_set<std::string> control_flow_op_set_ = {"If", "Loop", "Scan"};
@@ -550,6 +591,7 @@
* going through the time-consuming processes of model parsing and engine building.
*/
Status CreateNodeComputeInfoFromPrecompiledEngine(const GraphViewer& graph_body_viewer,
size_t node_idx,
const Node& fused_node,
std::unordered_map<std::string, size_t>& input_map,
std::unordered_map<std::string, size_t>& output_map,
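Two pieces of this header work together: TensorrtUserWeights keeps each named weight alive either as an owning std::string copy or as a non-owning {pointer, size} view the caller must keep valid, and RefitEngine pushes those weights into an engine. A rough sketch of that flow using the plain nvinfer1 refitter API, not this provider's exact internals; buf, mapped_ptr, mapped_size, engine, and logger are assumed to exist, and FP32 weights are assumed throughout.

#include <NvInfer.h>
#include <memory>
#include <string>

// Owning copy: the source buffer may be freed afterwards.
TensorrtUserWeights owned("fc1.weight", std::string(buf.data(), buf.size()));
// Non-owning view: mapped_ptr must outlive the refit.
TensorrtUserWeights view("fc2.weight", mapped_ptr, mapped_size);

std::unique_ptr<nvinfer1::IRefitter> refitter{
    nvinfer1::createInferRefitter(*engine, logger)};

// nvinfer1::Weights::count is an element count, so convert Size() from bytes.
nvinfer1::Weights w{nvinfer1::DataType::kFLOAT, view.Data(),
                    view.Size() / static_cast<int64_t>(sizeof(float))};
refitter->setNamedWeights(view.Name(), w);

if (!refitter->refitCudaEngine()) {
  // Refit failed; a real implementation would fall back to a full build.
}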
@@ -17,6 +17,7 @@ NvExecutionProviderInfo NvExecutionProviderInfo::FromProviderOptions(const Provi
NvExecutionProviderInfo info{};
void* user_compute_stream = nullptr;
void* onnx_bytestream = nullptr;
void* external_data_bytestream = nullptr;
ORT_THROW_IF_ERROR(
ProviderOptionsParser{}
.AddValueParser(
@@ -48,21 +49,14 @@
.AddAssignmentToReference(nv::provider_option_names::kProfilesMaxShapes, info.profile_max_shapes)
.AddAssignmentToReference(nv::provider_option_names::kProfilesOptShapes, info.profile_opt_shapes)
.AddAssignmentToReference(nv::provider_option_names::kCudaGraphEnable, info.cuda_graph_enable)
.AddAssignmentToReference(nv::provider_option_names::kUseExternalDataInitializer, info.use_external_data_initializer)
.AddAssignmentToReference(nv::provider_option_names::kMultiProfileEnable, info.multi_profile_enable)
.AddValueParser(
nv::provider_option_names::kONNXBytestream,
[&onnx_bytestream](const std::string& value_str) -> Status {
size_t address;
ORT_RETURN_IF_ERROR(ParseStringWithClassicLocale(value_str, address));
onnx_bytestream = reinterpret_cast<void*>(address);
return Status::OK();
})
.AddAssignmentToReference(nv::provider_option_names::kONNXBytestreamSize, info.onnx_bytestream_size)
.Parse(options)); // add new provider option here.

info.user_compute_stream = user_compute_stream;
info.has_user_compute_stream = (user_compute_stream != nullptr);
info.onnx_bytestream = onnx_bytestream;
info.external_data_bytestream = external_data_bytestream;

// EP context settings
// when EP context is enabled, default is to embed the engine in the context model
@@ -73,7 +67,8 @@ NvExecutionProviderInfo NvExecutionProviderInfo::FromProviderOptions(const Provi
info.dump_ep_context_model = false;
} else if (ep_context_enable == "1") {
info.dump_ep_context_model = true;
info.weight_stripped_engine_enable = true;
// We want to reenable weightless engines as soon as constant initializers are supported as inputs
info.weight_stripped_engine_enable = false;
} else {
ORT_THROW("Invalid ", kOrtSessionOptionEpContextEnable, " must 0 or 1");
}
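The ep_context_enable flag parsed above arrives through session options rather than provider options. A minimal sketch of how an application would request an EP-context dump; the key strings are ORT's standard kOrtSessionOptionEpContext* config keys, and the file path is illustrative.

Ort::SessionOptions so;
so.AddConfigEntry("ep.context_enable", "1");      // dump_ep_context_model = true
so.AddConfigEntry("ep.context_embed_mode", "1");  // embed engine bytes in the model
so.AddConfigEntry("ep.context_file_path", "model_ctx.onnx");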
@@ -110,9 +105,7 @@ ProviderOptions NvExecutionProviderInfo::ToProviderOptions(const NvExecutionProv
{nv::provider_option_names::kProfilesMaxShapes, MakeStringWithClassicLocale(info.profile_max_shapes)},
{nv::provider_option_names::kProfilesOptShapes, MakeStringWithClassicLocale(info.profile_opt_shapes)},
{nv::provider_option_names::kCudaGraphEnable, MakeStringWithClassicLocale(info.cuda_graph_enable)},
{nv::provider_option_names::kONNXBytestream, MakeStringWithClassicLocale(info.onnx_bytestream)},
{nv::provider_option_names::kONNXBytestreamSize, MakeStringWithClassicLocale(info.onnx_bytestream_size)},
};
{nv::provider_option_names::kUseExternalDataInitializer, MakeStringWithClassicLocale(info.use_external_data_initializer)}};
return options;
}
} // namespace onnxruntime
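Note the convention the kONNXBytestream parser implemented: the option value was a raw pointer address serialized as a decimal string and reinterpret_cast back on the provider side. With this PR, the ToProviderOptions round-trip drops that pair in favor of the plain boolean nv_use_external_data_initializer. A sketch of the new-style options from the application side; the provider name string and surrounding setup are assumptions, not taken from this diff.

#include <onnxruntime_cxx_api.h>

Ort::SessionOptions so;
so.AppendExecutionProvider("NvTensorRtRtx", {
    {"nv_use_external_data_initializer", "1"},
    {"nv_cuda_graph_enable", "1"},
});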
@@ -31,6 +31,9 @@ struct NvExecutionProviderInfo {
std::string onnx_model_folder_path{""};
const void* onnx_bytestream{nullptr};
size_t onnx_bytestream_size{0};
bool use_external_data_initializer{false};
const void* external_data_bytestream{nullptr};
size_t external_data_bytestream_size{0};
bool engine_decryption_enable{false};
std::string engine_decryption_lib_path{""};
bool force_sequential_engine_build{false};
@@ -386,22 +386,11 @@ std::string GetCachePath(const std::string& root, const std::string& name) {
* Get compute capability
*
*/
std::string GetComputeCapacity(const cudaDeviceProp& prop) {
std::string GetComputeCapability(const cudaDeviceProp& prop) {
const std::string compute_capability = std::to_string(prop.major * 10 + prop.minor);
return compute_capability;
}

/*
* Get Timing by compute capability
*
*/
std::string GetTimingCachePath(const std::string& root, std::string& compute_cap) {
// append compute capability of the GPU as this invalidates the cache and TRT will throw when loading the cache
const std::string timing_cache_name = "NvExecutionProvider_cache_sm" +
compute_cap + ".timing";
return GetCachePath(root, timing_cache_name);
}

/*
* Get cache by type
*
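The renamed GetComputeCapability helper folds major and minor into the usual SM string, which the provider then uses for its device check (per the commit "support sm86 and onwards RTX devices"). A quick usage sketch; the device index and the example value are illustrative.

#include <cuda_runtime.h>
#include <string>

cudaDeviceProp prop{};
cudaGetDeviceProperties(&prop, /*device=*/0);
std::string cc = GetComputeCapability(prop);  // e.g. "89" for an SM 8.9 RTX GPU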