diff --git a/src/vswhere.lib/CommandArgs.cpp b/src/vswhere.lib/CommandArgs.cpp index 05ce169..e9c542b 100644 --- a/src/vswhere.lib/CommandArgs.cpp +++ b/src/vswhere.lib/CommandArgs.cpp @@ -15,6 +15,9 @@ static wstring ParseArgument(IteratorType& it, const IteratorType& end, const Co template static void ParseArgumentArray(IteratorType& it, const IteratorType& end, const CommandParser::Token& arg, vector& arr); +template +static void ParseRequiresArray(IteratorType& it, const IteratorType& end, const CommandParser::Token& arg, vector& literals, vector& patterns); + const vector CommandArgs::s_Products { L"Microsoft.VisualStudio.Product.Enterprise", @@ -71,7 +74,7 @@ void CommandArgs::Parse(_In_ vector args) } else if (ArgumentEquals(arg.Value, L"requires")) { - ParseArgumentArray(it, args.end(), arg, m_requires); + ParseRequiresArray(it, args.end(), arg, m_requires, m_requiresPattern); hasSelection = true; } else if (ArgumentEquals(arg.Value, L"requiresAny")) @@ -218,7 +221,7 @@ void CommandArgs::Parse(_In_ vector args) void CommandArgs::Usage(_In_ Console& console) const { auto pos = m_path.find_last_of(L"\\"); - auto path = ++pos != wstring::npos ? m_path.substr(pos) : m_path; + auto& path = ++pos != wstring::npos ? m_path.substr(pos) : m_path; console.WriteLine(ResourceManager::FormatString(IDS_USAGE, path.c_str())); @@ -231,6 +234,37 @@ void CommandArgs::Usage(_In_ Console& console) const } } +std::wregex CommandArgs::ParseRegex(_In_ const std::wstring& pattern) noexcept +{ + // Reserve ~125% of the incoming pattern to hold any changes. + wstring accumulator; + accumulator.reserve(pattern.size() * 1.25); + + for (auto it = pattern.begin(); it != pattern.end(); ++it) + { + switch (*it) + { + case L'.': + accumulator += L"\\."; + break; + + case L'*': + accumulator += L".*"; + break; + + case L'?': + accumulator += L"."; + break; + + default: + accumulator += *it; + break; + } + } + + return std::move(wregex(accumulator, wregex::basic | wregex::icase | wregex::nosubs)); +} + static bool ArgumentEquals(_In_ const wstring& name, _In_ LPCWSTR expect) { _ASSERT(expect && *expect); @@ -281,3 +315,38 @@ static void ParseArgumentArray(IteratorType& it, const IteratorType& end, const arr.push_back(it->Value); } } + +template +static void ParseRequiresArray(IteratorType& it, const IteratorType& end, const CommandParser::Token& arg, vector& literals, vector& patterns) +{ + wstring& param = it->Value; + auto nit = next(it); + + // Require arguments if the parameter is specified. + if (nit == end || CommandParser::Token::eArgument != nit->Type) + { + auto message = ResourceManager::FormatString(IDS_E_ARGREQUIRED, param.c_str()); + throw win32_error(ERROR_INVALID_PARAMETER, message); + } + + while (nit != end) + { + if (CommandParser::Token::eParameter == nit->Type) + { + break; + } + + ++it; + ++nit; + + if (it->Value.find(L'*', 0) == wstring::npos && it->Value.find(L'?', 0) == wstring::npos) + { + literals.push_back(it->Value); + } + else + { + auto pattern = CommandArgs::ParseRegex(it->Value); + patterns.push_back(std::move(pattern)); + } + } +} diff --git a/src/vswhere.lib/CommandArgs.h b/src/vswhere.lib/CommandArgs.h index e9e3ddc..1e48b71 100644 --- a/src/vswhere.lib/CommandArgs.h +++ b/src/vswhere.lib/CommandArgs.h @@ -30,6 +30,7 @@ class CommandArgs m_productsAll(obj.m_productsAll), m_products(obj.m_products), m_requires(obj.m_requires), + m_requiresPattern(obj.m_requiresPattern), m_version(obj.m_version), m_latest(obj.m_latest), m_legacy(obj.m_legacy), @@ -72,6 +73,11 @@ class CommandArgs return m_requires; } + const std::vector& get_RequiresPattern() const noexcept + { + return m_requiresPattern; + } + const bool get_RequiresAny() const noexcept { return m_requiresAny; @@ -157,6 +163,8 @@ class CommandArgs void Parse(_In_ int argc, _In_ LPCWSTR argv[]); void Usage(_In_ Console& console) const; + static std::wregex ParseRegex(_In_ const std::wstring& pattern) noexcept; + private: static const std::vector s_Products; static const std::wstring s_Format; @@ -168,6 +176,7 @@ class CommandArgs bool m_productsAll; std::vector m_products; std::vector m_requires; + std::vector m_requiresPattern; bool m_requiresAny; std::wstring m_version; bool m_latest; diff --git a/src/vswhere.lib/Formatter.cpp b/src/vswhere.lib/Formatter.cpp index 75f9ebf..1c91e8f 100644 --- a/src/vswhere.lib/Formatter.cpp +++ b/src/vswhere.lib/Formatter.cpp @@ -342,7 +342,7 @@ void Formatter::WritePackages(_In_ ISetupInstance* pInstance) StartArray(L"packages"); SafeArray saPackages(psaPackages); - const auto packages = saPackages.Elements(); + const auto& packages = saPackages.Elements(); for (const auto& package : packages) { @@ -431,6 +431,7 @@ bool Formatter::WriteProperties(_In_ ISetupPropertyStore* pProperties, _In_opt_ SafeArray saNames(psaNames); + // Copy the elements so we can sort them. auto elems = saNames.Elements(); sort(elems.begin(), elems.end(), less); diff --git a/src/vswhere.lib/InstanceSelector.cpp b/src/vswhere.lib/InstanceSelector.cpp index 6c76158..d0b9006 100644 --- a/src/vswhere.lib/InstanceSelector.cpp +++ b/src/vswhere.lib/InstanceSelector.cpp @@ -8,6 +8,8 @@ using namespace std; using std::placeholders::_1; +ci_equal InstanceSelector::s_comparer; + InstanceSelector::InstanceSelector(_In_ const CommandArgs& args, _In_ ILegacyProvider& provider, _In_opt_ ISetupHelper* pHelper) : m_args(args), m_provider(provider), @@ -17,7 +19,7 @@ InstanceSelector::InstanceSelector(_In_ const CommandArgs& args, _In_ ILegacyPro m_helper = pHelper; if (m_helper) { - auto version = args.get_Version(); + auto& version = args.get_Version(); if (!version.empty()) { auto hr = m_helper->ParseVersionRange(version.c_str(), &m_ullMinimumVersion, &m_ullMaximumVersion); @@ -224,7 +226,7 @@ bool InstanceSelector::IsProductMatch(_In_ ISetupInstance2* pInstance) const } // Asterisk on command line will clear the array to find any products. - const auto products = m_args.get_Products(); + const auto& products = m_args.get_Products(); if (products.empty()) { return true; @@ -250,21 +252,19 @@ bool InstanceSelector::IsWorkloadMatch(_In_ ISetupInstance2* pInstance) const { _ASSERT(pInstance); - const auto requires = m_args.get_Requires(); - if (requires.empty()) + // Create copies and erase elements as found. + auto literals = m_args.get_Requires(); + auto literals_count = literals.size(); + + auto patterns = m_args.get_RequiresPattern(); + auto patterns_count = patterns.size(); + + if (literals.empty() && patterns.empty()) { // No workloads required matches every instance. return true; } - // Keep track of which requirements we matched. - typedef map MapType; - MapType found; - for (const auto& require : requires) - { - found.emplace(make_pair(require, false)); - } - LPSAFEARRAY psa = NULL; auto hr = pInstance->GetPackages(&psa); if (FAILED(hr)) @@ -277,25 +277,34 @@ bool InstanceSelector::IsWorkloadMatch(_In_ ISetupInstance2* pInstance) const { auto id = GetId(package); - auto it = found.find(id); - if (it != found.end()) + for (auto it = literals.cbegin(); it != literals.cend(); ++it) + { + if (s_comparer(id, *it)) + { + literals.erase(it); + goto next; + } + } + + for (auto it = patterns.cbegin(); it != patterns.cend(); ++it) { - it->second = true; + if (regex_match(id, *it)) + { + patterns.erase(it); + goto next; + } } + + next: continue; } if (m_args.get_RequiresAny()) { - return any_of(found.begin(), found.end(), [](MapType::const_reference pair) -> bool - { - return pair.second; - }); + return literals.size() < literals_count + || patterns.size() < patterns_count; } - return all_of(found.begin(), found.end(), [](MapType::const_reference pair) -> bool - { - return pair.second; - }); + return literals.empty() && patterns.empty(); } bool InstanceSelector::IsVersionMatch(_In_ ISetupInstance* pInstance) const diff --git a/src/vswhere.lib/InstanceSelector.h b/src/vswhere.lib/InstanceSelector.h index 3e0950a..0f4c0a6 100644 --- a/src/vswhere.lib/InstanceSelector.h +++ b/src/vswhere.lib/InstanceSelector.h @@ -27,6 +27,8 @@ class InstanceSelector std::vector Select(_In_opt_ IEnumSetupInstances* pEnum) const; private: + static ci_equal s_comparer; + static std::wstring GetId(_In_ ISetupPackageReference* pPackageReference); bool IsMatch(_In_ ISetupInstance* pInstance) const; bool IsProductMatch(_In_ ISetupInstance2* pInstance) const; diff --git a/src/vswhere.lib/Module.cpp b/src/vswhere.lib/Module.cpp index dd24252..b0155c6 100644 --- a/src/vswhere.lib/Module.cpp +++ b/src/vswhere.lib/Module.cpp @@ -36,7 +36,7 @@ const wstring& Module::get_Path() noexcept const wstring& Module::get_FileVersion() noexcept { - auto path = get_Path(); + auto& path = get_Path(); if (path.empty()) { return m_fileVersion; diff --git a/src/vswhere.lib/vswhere.lib.rc b/src/vswhere.lib/vswhere.lib.rc index 0a4135e..ec00776 100644 --- a/src/vswhere.lib/vswhere.lib.rc +++ b/src/vswhere.lib/vswhere.lib.rc @@ -83,9 +83,12 @@ BEGIN \n See https://aka.ms/vs/workloads for a list of product IDs.\ \n -requires arg One or more workload or component IDs required when finding instances.\ \n All specified IDs must be installed unless -requiresAny is specified.\ +\n You can specify wildcards including ""?"" to match any one character,\ +\n or ""*"" to match zero or more of any characters.\ \n See https://aka.ms/vs/workloads for a list of workload and component IDs.\ \n -requiresAny Find instances with any one or more workload or components IDs passed to -requires.\ \n -version arg A version range for instances to find. Example: [15.0,16.0) will find versions 15.*.\ +\n See https://aka.ms/vswhere/versions for more information about versions.\ \n -latest Return only the newest version and last installed.\ \n -sort Sorts the instances from newest version and last installed to oldest.\ \n When used with ""find"", first instances are sorted then files are sorted lexigraphically.\ diff --git a/src/vswhere/Program.cpp b/src/vswhere/Program.cpp index 1e997a1..049fa1f 100644 --- a/src/vswhere/Program.cpp +++ b/src/vswhere/Program.cpp @@ -198,7 +198,7 @@ void WriteLogo(_In_ const CommandArgs& args, _In_ Console& console, _In_ Module& { if (args.get_Logo()) { - const auto version = module.get_FileVersion(); + const auto& version = module.get_FileVersion(); const auto nID = version.empty() ? IDS_PROGRAMINFO : IDS_PROGRAMINFOEX; console.WriteLine(ResourceManager::FormatString(nID, NBGV_INFORMATIONAL_VERSION, version.c_str())); diff --git a/test/vswhere.test/CommandArgsTests.cpp b/test/vswhere.test/CommandArgsTests.cpp index 49352ec..e2b1023 100644 --- a/test/vswhere.test/CommandArgsTests.cpp +++ b/test/vswhere.test/CommandArgsTests.cpp @@ -401,4 +401,56 @@ TEST_CLASS(CommandArgsTests) Assert::IsFalse(args.get_Logo()); Assert::IsTrue(args.get_UTF8()); } + + BEGIN_TEST_METHOD_ATTRIBUTE(Parse_Requires_Patterns) + TEST_WORKITEM(276) + END_TEST_METHOD_ATTRIBUTE() + TEST_METHOD(Parse_Requires_Patterns) + { + CommandArgs args; + args.Parse(L"vswhere.exe -requires foo ba* qux"); + + const auto& literals = args.get_Requires(); + const auto& patterns = args.get_RequiresPattern(); + + Assert::AreEqual(1, count(literals.cbegin(), literals.cend(), wstring(L"foo"))); + Assert::AreEqual(1, count(literals.cbegin(), literals.cend(), wstring(L"qux"))); + Assert::AreEqual(1, patterns.size()); + } + + BEGIN_TEST_METHOD_ATTRIBUTE(ParseRegex_Theory) + TEST_WORKITEM(276) + END_TEST_METHOD_ATTRIBUTE() + TEST_METHOD(ParseRegex_Theory) + { + const wstring id = L"Foo.Bar"; + vector> data = + { + { L"Foo.Bar", true }, + { L"Foo.*", true }, + { L"*.Bar", true }, + { L"F*R", true }, + { L"foo?bar", true }, + { L"f??", false }, + { L"f??.??r", true }, + { L"*", true }, + { L".*", false }, + { L"?", false }, + { L"Baz", false }, + { L"*baz", false }, + { L"foo.bar*", true }, + }; + + for (const auto& item : data) + { + wstring pattern; + bool expected; + + tie(pattern, expected) = item; + auto re = CommandArgs::ParseRegex(pattern); + bool actual = regex_match(id, re); + + Assert::AreEqual(expected, actual, format(L"\"%ls\" =~ /%ls/", id.c_str(), pattern.c_str()).c_str()); + } + } }; diff --git a/test/vswhere.test/InstanceSelectorTests.cpp b/test/vswhere.test/InstanceSelectorTests.cpp index ff24a23..7c80015 100644 --- a/test/vswhere.test/InstanceSelectorTests.cpp +++ b/test/vswhere.test/InstanceSelectorTests.cpp @@ -697,4 +697,82 @@ TEST_CLASS(InstanceSelectorTests) Assert::AreEqual(S_OK, selected[1]->GetInstanceId(bstrInstanceId.GetAddress())); Assert::AreEqual(L"a", bstrInstanceId); } + + BEGIN_TEST_METHOD_ATTRIBUTE(Select_RequiresPattern_Workload) + TEST_WORKITEM(276) + END_TEST_METHOD_ATTRIBUTE() + TEST_METHOD(Select_RequiresPattern_Workload) + { + TestPackageReference product = + { + { L"Id", L"Microsoft.VisualStudio.Product.Enterprise" }, + }; + + TestPackageReference managedDesktop = { { L"Id", L"Microsoft.VisualStudio.Workload.ManagedDesktop" } }; + TestPackageReference nativeDesktop = { { L"Id", L"Microsoft.VisualStudio.Workload.NativeDesktop" } }; + vector packages = + { + &managedDesktop, + &nativeDesktop, + }; + + TestInstance::MapType properties = + { + { L"InstanceId", L"a1b2c3" }, + { L"InstallationName", L"test" }, + }; + + TestInstance instance(&product, packages, properties); + TestEnumInstances instances = + { + &instance, + }; + + CommandArgs args; + args.Parse(L"vswhere.exe -requires microsoft.visualstudio.workload.*desktop"); + + InstanceSelector sut(args); + auto selected = sut.Select(&instances); + + Assert::AreEqual(1, selected.size()); + } + + BEGIN_TEST_METHOD_ATTRIBUTE(Select_RequiresAnyPattern_Workload) + TEST_WORKITEM(276) + END_TEST_METHOD_ATTRIBUTE() + TEST_METHOD(Select_RequiresAnyPattern_Workload) + { + TestPackageReference product = + { + { L"Id", L"Microsoft.VisualStudio.Product.Enterprise" }, + }; + + TestPackageReference managedDesktop = { { L"Id", L"Microsoft.VisualStudio.Workload.ManagedDesktop" } }; + TestPackageReference nativeDesktop = { { L"Id", L"Microsoft.VisualStudio.Workload.NativeDesktop" } }; + vector packages = + { + &managedDesktop, + &nativeDesktop, + }; + + TestInstance::MapType properties = + { + { L"InstanceId", L"a1b2c3" }, + { L"InstallationName", L"test" }, + }; + + TestInstance instance(&product, packages, properties); + TestEnumInstances instances = + { + &instance, + }; + + CommandArgs args; + args.Parse(L"vswhere.exe -requires microsoft.visualstudio.workload.azure microsoft.visualstudio.workload.manageddeskto? -requiresAny"); + + InstanceSelector sut(args); + auto selected = sut.Select(&instances); + + Assert::AreEqual(1, selected.size()); + } }; diff --git a/version.json b/version.json index 3bc9fcf..ee7f430 100644 --- a/version.json +++ b/version.json @@ -1,6 +1,6 @@ { "$schema": "https://raw.githubusercontent.com/AArnott/Nerdbank.GitVersioning/master/src/NerdBank.GitVersioning/version.schema.json", - "version": "3.0", + "version": "3.1", "publicReleaseRefSpec": [ "^refs/heads/master$", "^refs/tags/v\\d+\\.\\d+"