Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 8 additions & 30 deletions onnxruntime/test/perftest/command_args_parser.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,12 @@ static const onnxruntime::perftest::PerformanceTestConfig& DefaultPerformanceTes
return default_config;
}

ABSL_FLAG(std::string, f, "", "Specifies a free dimension by name to override to a specific value for performance optimization.");
ABSL_FLAG(std::string, F, "", "Specifies a free dimension by denotation to override to a specific value for performance optimization.");
ABSL_FLAG(std::string, f, "",
"Specifies free dimensions by name to override with the specific values for performance optimization. Key-value pairs are separated by space.\n"
"[Usage]: -f \"dimension_name_1:override_value_1 dimension_name_2:override_value_2 ... \". Override value must > 0.");
ABSL_FLAG(std::string, F, "",
"Specifies free dimensions by denotation to override with the specific values for performance optimization. Key-value pairs are separated by space.\n"
"[Usage]: -F \"denotation_name_1:override_value_1 denotation_name_2:override_value_2 ... \". Override value must > 0.");
ABSL_FLAG(std::string, m, "duration", "Specifies the test mode. Value could be 'duration' or 'times'.");
ABSL_FLAG(std::string, e, "cpu", "Specifies the provider 'cpu','cuda','dnnl','tensorrt', 'nvtensorrtrtx', 'openvino', 'dml', 'acl', 'nnapi', 'coreml', 'qnn', 'snpe', 'rocm', 'migraphx', 'xnnpack', 'vitisai' or 'webgpu'.");
ABSL_FLAG(size_t, r, DefaultPerformanceTestConfig().run_config.repeated_times, "Specifies the repeated times if running in 'times' test mode.");
Expand Down Expand Up @@ -168,26 +172,6 @@ ABSL_FLAG(bool, h, false, "Print program usage.");
namespace onnxruntime {
namespace perftest {

static bool ParseDimensionOverride(std::string& dim_identifier, int64_t& override_val, const char* option) {
std::basic_string<char> free_dim_str(option);
size_t delimiter_location = free_dim_str.find(":");
if (delimiter_location >= free_dim_str.size() - 1) {
return false;
}
dim_identifier = free_dim_str.substr(0, delimiter_location);
std::string override_val_str = free_dim_str.substr(delimiter_location + 1, std::string::npos);
ORT_TRY {
override_val = std::stoll(override_val_str.c_str());
if (override_val <= 0) {
return false;
}
}
ORT_CATCH(...) {
return false;
}
return true;
}

std::string CustomUsageMessage() {
std::ostringstream oss;
oss << "onnxruntime_perf_test [options...] model_path [result_file]\n\n";
Expand Down Expand Up @@ -220,25 +204,19 @@ bool CommandLineParser::ParseArguments(PerformanceTestConfig& test_config, int a
{
const auto& dim_override_str = absl::GetFlag(FLAGS_f);
if (!dim_override_str.empty()) {
std::string dim_name;
int64_t override_val;
if (!ParseDimensionOverride(dim_name, override_val, dim_override_str.c_str())) {
if (!ParseDimensionOverride(dim_override_str, test_config.run_config.free_dim_name_overrides)) {
return false;
}
test_config.run_config.free_dim_name_overrides[dim_name] = override_val;
}
}

// -F
{
const auto& dim_override_str = absl::GetFlag(FLAGS_F);
if (!dim_override_str.empty()) {
std::string dim_denotation;
int64_t override_val;
if (!ParseDimensionOverride(dim_denotation, override_val, dim_override_str.c_str())) {
if (!ParseDimensionOverride(dim_override_str, test_config.run_config.free_dim_denotation_overrides)) {
return false;
}
test_config.run_config.free_dim_denotation_overrides[dim_denotation] = override_val;
}
}

Expand Down
28 changes: 28 additions & 0 deletions onnxruntime/test/perftest/strings_helper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,34 @@ void ParseSessionConfigs(const std::string& configs_string,
}
}

bool ParseDimensionOverride(const std::string& input, std::map<std::string, int64_t>& free_dim_override_map) {
std::stringstream ss(input);
std::string free_dim_str;

while (std::getline(ss, free_dim_str, ' ')) {
if (!free_dim_str.empty()) {
size_t delimiter_location = free_dim_str.find(":");
if (delimiter_location >= free_dim_str.size() - 1) {
return false;
}
std::string dim_identifier = free_dim_str.substr(0, delimiter_location);
std::string override_val_str = free_dim_str.substr(delimiter_location + 1, std::string::npos);
ORT_TRY {
int64_t override_val = std::stoll(override_val_str.c_str());
if (override_val <= 0) {
return false;
}
free_dim_override_map[dim_identifier] = override_val;
}
ORT_CATCH(...) {
return false;
}
}
}

return true;
}

void ParseEpOptions(const std::string& input, std::vector<std::unordered_map<std::string, std::string>>& result) {
auto tokens = utils::SplitString(input, ";", true);

Expand Down
3 changes: 3 additions & 0 deletions onnxruntime/test/perftest/strings_helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates <[email protected]>
// Licensed under the MIT License.
#include <string_view>
#include <map>
#include <unordered_map>
#include <unordered_set>
#include <vector>
Expand All @@ -14,6 +15,8 @@ void ParseSessionConfigs(const std::string& configs_string,
std::unordered_map<std::string, std::string>& session_configs,
const std::unordered_set<std::string>& available_keys = {});

bool ParseDimensionOverride(const std::string& input, std::map<std::string, int64_t>& free_dim_override_map);

void ParseEpList(const std::string& input, std::vector<std::string>& result);

void ParseEpOptions(const std::string& input, std::vector<std::unordered_map<std::string, std::string>>& result);
Expand Down
Loading