diff --git a/onnxruntime/test/perftest/command_args_parser.cc b/onnxruntime/test/perftest/command_args_parser.cc index 5c81696d5c57e..a22375320edae 100644 --- a/onnxruntime/test/perftest/command_args_parser.cc +++ b/onnxruntime/test/perftest/command_args_parser.cc @@ -30,8 +30,14 @@ static const onnxruntime::perftest::PerformanceTestConfig& DefaultPerformanceTes return default_config; } -ABSL_FLAG(std::string, f, "", "Specifies a free dimension by name to override to a specific value for performance optimization."); -ABSL_FLAG(std::string, F, "", "Specifies a free dimension by denotation to override to a specific value for performance optimization."); +ABSL_FLAG(std::string, f, "", + "Specifies a free dimension by name to override to a specific value for performance optimization.\n" + "[Usage]: -f \"dimension_name1:override_value1\" -f \"dimension_name2:override_value2\" ... or" + " -f \"dimension_name1:override_value1 dimension_name2:override_value2 ... \". Override value must > 0."); +ABSL_FLAG(std::string, F, "", + "Specifies a free dimension by denotation to override to a specific value for performance optimization.\n" + "[Usage]: -f \"dimension_denotation1:override_value1\" -f \"dimension_denotation2:override_value2\" ... or" + " -f \"dimension_denotation1:override_value1 dimension_denotation2 : override_value2... \". Override value must > 0."); ABSL_FLAG(std::string, m, "duration", "Specifies the test mode. Value could be 'duration' or 'times'."); ABSL_FLAG(std::string, e, "cpu", "Specifies the provider 'cpu','cuda','dnnl','tensorrt', 'nvtensorrtrtx', 'openvino', 'dml', 'acl', 'nnapi', 'coreml', 'qnn', 'snpe', 'rocm', 'migraphx', 'xnnpack', 'vitisai' or 'webgpu'."); ABSL_FLAG(size_t, r, DefaultPerformanceTestConfig().run_config.repeated_times, "Specifies the repeated times if running in 'times' test mode."); @@ -168,26 +174,6 @@ ABSL_FLAG(bool, h, false, "Print program usage."); namespace onnxruntime { namespace perftest { -static bool ParseDimensionOverride(std::string& dim_identifier, int64_t& override_val, const char* option) { - std::basic_string free_dim_str(option); - size_t delimiter_location = free_dim_str.find(":"); - if (delimiter_location >= free_dim_str.size() - 1) { - return false; - } - dim_identifier = free_dim_str.substr(0, delimiter_location); - std::string override_val_str = free_dim_str.substr(delimiter_location + 1, std::string::npos); - ORT_TRY { - override_val = std::stoll(override_val_str.c_str()); - if (override_val <= 0) { - return false; - } - } - ORT_CATCH(...) { - return false; - } - return true; -} - std::string CustomUsageMessage() { std::ostringstream oss; oss << "onnxruntime_perf_test [options...] model_path [result_file]\n\n"; @@ -212,20 +198,21 @@ bool CommandLineParser::ParseArguments(PerformanceTestConfig& test_config, int a absl::SetFlagsUsageConfig(config); absl::SetProgramUsageMessage(CustomUsageMessage()); - auto utf8_strings = utils::ConvertArgvToUtf8Strings(argc, argv); - auto utf8_argv = utils::CStringsFromStrings(utf8_strings); + auto utf8_argv_strings = utils::ConvertArgvToUtf8Strings(argc, argv); + auto utf8_argv = utils::CStringsFromStrings(utf8_argv_strings); auto positional = absl::ParseCommandLine(static_cast(utf8_argv.size()), utf8_argv.data()); // -f { const auto& dim_override_str = absl::GetFlag(FLAGS_f); if (!dim_override_str.empty()) { - std::string dim_name; - int64_t override_val; - if (!ParseDimensionOverride(dim_name, override_val, dim_override_str.c_str())) { + // Abseil doesn't support the same option being provided multiple times - only the last occurrence is applied. + // To preserve the previous usage of '-f', where users may specify it multiple times to override different dimension names, + // we need to manually parse argv. + std::string option = "f"; + if (!ParseDimensionOverrideFromArgv(argc, utf8_argv_strings, option, test_config.run_config.free_dim_name_overrides)) { return false; } - test_config.run_config.free_dim_name_overrides[dim_name] = override_val; } } @@ -233,12 +220,11 @@ bool CommandLineParser::ParseArguments(PerformanceTestConfig& test_config, int a { const auto& dim_override_str = absl::GetFlag(FLAGS_F); if (!dim_override_str.empty()) { - std::string dim_denotation; - int64_t override_val; - if (!ParseDimensionOverride(dim_denotation, override_val, dim_override_str.c_str())) { + // Same reason as '-f' above to manully parse argv. + std::string option = "F"; + if (!ParseDimensionOverrideFromArgv(argc, utf8_argv_strings, option, test_config.run_config.free_dim_denotation_overrides)) { return false; } - test_config.run_config.free_dim_denotation_overrides[dim_denotation] = override_val; } } diff --git a/onnxruntime/test/perftest/main.cc b/onnxruntime/test/perftest/main.cc index 973baf774b024..513122609bb01 100644 --- a/onnxruntime/test/perftest/main.cc +++ b/onnxruntime/test/perftest/main.cc @@ -35,7 +35,7 @@ int real_main(int argc, char* argv[]) { } ORT_CATCH(const Ort::Exception& e) { ORT_HANDLE_EXCEPTION([&]() { - fprintf(stderr, "Error creating environment: %s \n", e.what()); + std::cerr << "Error creating environment: " << e.what() << std::endl; failed = true; }); } @@ -98,7 +98,7 @@ int main(int argc, char* argv[]) { } ORT_CATCH(const std::exception& ex) { ORT_HANDLE_EXCEPTION([&]() { - fprintf(stderr, "%s\n", ex.what()); + std::cerr << ex.what() << std::endl; retval = -1; }); } diff --git a/onnxruntime/test/perftest/strings_helper.cc b/onnxruntime/test/perftest/strings_helper.cc index f4860b35c79da..5743346f8edf1 100644 --- a/onnxruntime/test/perftest/strings_helper.cc +++ b/onnxruntime/test/perftest/strings_helper.cc @@ -56,6 +56,53 @@ void ParseSessionConfigs(const std::string& configs_string, } } +bool ParseDimensionOverride(const std::string& input, std::map& free_dim_override_map) { + std::stringstream ss(input); + std::string free_dim_str; + + while (std::getline(ss, free_dim_str, ' ')) { + if (!free_dim_str.empty()) { + size_t delimiter_location = free_dim_str.find(":"); + if (delimiter_location >= free_dim_str.size() - 1) { + return false; + } + std::string dim_identifier = free_dim_str.substr(0, delimiter_location); + std::string override_val_str = free_dim_str.substr(delimiter_location + 1, std::string::npos); + ORT_TRY { + int64_t override_val = std::stoll(override_val_str.c_str()); + if (override_val <= 0) { + return false; + } + free_dim_override_map[dim_identifier] = override_val; + } + ORT_CATCH(const std::exception& ex) { + ORT_HANDLE_EXCEPTION([&]() { + std::cerr << "Error parsing free dimension override value: " << override_val_str.c_str() << ", " << ex.what() << std::endl; + }); + return false; + } + } + } + + return true; +} + +bool ParseDimensionOverrideFromArgv(int argc, std::vector& argv, std::string& option, std::map& free_dim_override_map) { + for (int i = 1; i < argc; ++i) { + auto utf8_arg = argv[i]; + if (utf8_arg == ("-" + option) || utf8_arg == ("--" + option)) { + auto value_idx = i + 1; + if (value_idx >= argc || argv[value_idx][0] == '-') { + std::cerr << utf8_arg << " should be followed by a key-value pair." << std::endl; + return false; + } + + if (!ParseDimensionOverride(argv[value_idx], free_dim_override_map)) return false; + } + } + return true; +} + void ParseEpOptions(const std::string& input, std::vector>& result) { auto tokens = utils::SplitString(input, ";", true); diff --git a/onnxruntime/test/perftest/strings_helper.h b/onnxruntime/test/perftest/strings_helper.h index 621ab746273bd..a33b3d5089c9b 100644 --- a/onnxruntime/test/perftest/strings_helper.h +++ b/onnxruntime/test/perftest/strings_helper.h @@ -3,6 +3,7 @@ // SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates // Licensed under the MIT License. #include +#include #include #include #include @@ -14,6 +15,10 @@ void ParseSessionConfigs(const std::string& configs_string, std::unordered_map& session_configs, const std::unordered_set& available_keys = {}); +bool ParseDimensionOverride(const std::string& input, std::map& free_dim_override_map); + +bool ParseDimensionOverrideFromArgv(int argc, std::vector& argv, std::string& option, std::map& free_dim_override_map); + void ParseEpList(const std::string& input, std::vector& result); void ParseEpOptions(const std::string& input, std::vector>& result);