Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 44 additions & 20 deletions src/azure-cli-core/azure/cli/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,12 @@
EXCLUDED_PARAMS = ['self', 'raw', 'polling', 'custom_headers', 'operation_config',
'content_version', 'kwargs', 'client', 'no_wait']
EVENT_FAILED_EXTENSION_LOAD = 'MainLoader.OnFailedExtensionLoad'
# Extensions that will always be loaded if installed. These extensions don't expose commands but hook into CLI core.
ALWAYS_LOADED_EXTENSION_MODNAMES = ['azext_ai_examples', 'azext_ai_did_you_mean_this']

# [Reserved, in case of future usage]
# Modules that will always be loaded. They don't expose commands but hook into CLI core.
ALWAYS_LOADED_MODULES = []
# Extensions that will always be loaded if installed. They don't expose commands but hook into CLI core.
ALWAYS_LOADED_EXTENSIONS = ['azext_ai_examples', 'azext_ai_did_you_mean_this']


class AzCli(CLI):
Expand Down Expand Up @@ -153,7 +157,7 @@ def save_local_context(self, parsed_args, argument_definitions, specified_argume
class MainCommandsLoader(CLICommandsLoader):

# Format string for pretty-print the command module table
header_mod = "%-20s %10s %9s %9s" % ("Extension", "Load Time", "Groups", "Commands")
header_mod = "%-20s %10s %9s %9s" % ("Name", "Load Time", "Groups", "Commands")
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fix an incorrect naming.

item_format_string = "%-20s %10.3f %9d %9d"
header_ext = header_mod + " Directory"
item_ext_format_string = item_format_string + " %s"
Expand Down Expand Up @@ -181,12 +185,19 @@ def load_command_table(self, args):
get_extensions, get_extension_path, get_extension_modname)

def _update_command_table_from_modules(args, command_modules=None):
'''Loads command table(s)
When `module_name` is specified, only commands from that module will be loaded.
If the module is not found, all commands are loaded.
'''

if not command_modules:
"""Loads command tables from modules and merge into the main command table.

:param args: Arguments of the command.
:param list command_modules: Command modules to load, in the format like ['resource', 'profile'].
If None, will do module discovery and load all modules.
If [], only ALWAYS_LOADED_MODULES will be loaded.
Otherwise, the list will be extended using ALWAYS_LOADED_MODULES.
"""

# As command modules are built-in, the existence of modules in ALWAYS_LOADED_MODULES is NOT checked
if command_modules is not None:
command_modules.extend(ALWAYS_LOADED_MODULES)
else:
# Perform module discovery
command_modules = []
try:
Expand Down Expand Up @@ -234,6 +245,15 @@ def _update_command_table_from_modules(args, command_modules=None):
cumulative_group_count, cumulative_command_count)

def _update_command_table_from_extensions(ext_suppressions, extension_modname=None):
"""Loads command tables from extensions and merge into the main command table.

:param ext_suppressions: Extension suppression information.
:param extension_modname: Command modules to load, in the format like ['azext_timeseriesinsights'].
If None, will do extension discovery and load all extensions.
If [], only ALWAYS_LOADED_EXTENSIONS will be loaded.
Otherwise, the list will be extended using ALWAYS_LOADED_EXTENSIONS.
If the extensions in the list are not installed, it will be skipped.
"""

from azure.cli.core.extension.operations import check_version_compatibility

Expand All @@ -251,19 +271,20 @@ def _handle_extension_suppressions(extensions):
def _filter_modname(extensions):
# Extension's name may not be the same as its modname. eg. name: virtual-wan, modname: azext_vwan
filtered_extensions = []
extension_modname.extend(ALWAYS_LOADED_EXTENSION_MODNAMES)
for ext in extensions:
ext_name = ext.name
ext_dir = ext.path or get_extension_path(ext.name)
ext_mod = get_extension_modname(ext_name, ext_dir=ext_dir)
ext_mod = get_extension_modname(ext.name, ext.path)
# Filter the extensions according to the index
if ext_mod in extension_modname:
filtered_extensions.append(ext)
extension_modname.remove(ext_mod)
if extension_modname:
logger.debug("These extensions are not installed and will be skipped: %s", extension_modname)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will show something like

These extensions are not installed and will be skipped: ['azext_ai_examples', 'azext_ai_did_you_mean_this']

return filtered_extensions

extensions = get_extensions()
if extensions:
if extension_modname:
if extension_modname is not None:
extension_modname.extend(ALWAYS_LOADED_EXTENSIONS)
extensions = _filter_modname(extensions)
allowed_extensions = _handle_extension_suppressions(extensions)
module_commands = set(self.command_table.keys())
Expand Down Expand Up @@ -375,18 +396,21 @@ def _roughly_parse_command(args):
index_result = command_index.get(args)
if index_result:
index_modules, index_extensions = index_result
if index_modules:
_update_command_table_from_modules(args, index_modules)
if index_extensions:
# The index won't contain suppressed extensions
_update_command_table_from_extensions([], index_extensions)
# Always load modules and extensions, because some of them (like those in
# ALWAYS_LOADED_EXTENSIONS) don't expose a command, but hooks into handlers in CLI core
_update_command_table_from_modules(args, index_modules)
# The index won't contain suppressed extensions
_update_command_table_from_extensions([], index_extensions)

logger.debug("Loaded %d groups, %d commands.", len(self.command_group_table), len(self.command_table))
# The index may be outdated. Make sure the command appears in the loaded command table
command_str = _roughly_parse_command(args)
if command_str in self.command_table or command_str in self.command_group_table:
if command_str in self.command_table:
logger.debug("Found a match in the command table for '%s'", command_str)
return self.command_table
if command_str in self.command_group_table:
logger.debug("Found a match in the command group table for '%s'", command_str)
return self.command_table
Comment on lines +411 to +413
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Separate the debug log for found in command table and found in the command group table.


logger.debug("Could not find a match in the command table for '%s'. The index may be outdated",
command_str)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ def _prepare_test_commands_loader(loader_cls, cli_ctx, command):

class TestCommandRegistration(unittest.TestCase):

test_hook = []

@classmethod
def setUpClass(cls):
# Ensure initialization has occurred correctly
Expand Down Expand Up @@ -149,16 +151,19 @@ def _mock_iter_modules(_):
return [(None, "hello", None),
(None, "extra", None)]

def _mock_extension_modname(ext_name, ext_dir):
def _mock_get_extension_modname(ext_name, ext_dir):
if ext_name.endswith('.ExtCommandsLoader'):
return "azext_hello1"
if ext_name.endswith('.Ext2CommandsLoader'):
return "azext_hello2"
if ext_name.endswith('.ExtAlwaysLoadedCommandsLoader'):
return "azext_always_loaded"

def _mock_get_extensions():
MockExtension = namedtuple('Extension', ['name', 'preview', 'experimental', 'path', 'get_metadata'])
return [MockExtension(name=__name__ + '.ExtCommandsLoader', preview=False, experimental=False, path=None, get_metadata=lambda: {}),
MockExtension(name=__name__ + '.Ext2CommandsLoader', preview=False, experimental=False, path=None, get_metadata=lambda: {})]
MockExtension(name=__name__ + '.Ext2CommandsLoader', preview=False, experimental=False, path=None, get_metadata=lambda: {}),
MockExtension(name=__name__ + '.ExtAlwaysLoadedCommandsLoader', preview=False, experimental=False, path=None, get_metadata=lambda: {})]

def _mock_load_command_loader(loader, args, name, prefix):

Expand All @@ -181,7 +186,7 @@ def load_command_table(self, args):
self.__module__ = "azure.cli.command_modules.extra"
return self.command_table

# A command from an extension
# Extend existing group by adding a new command
class ExtCommandsLoader(AzCommandsLoader):

def load_command_table(self, args):
Expand All @@ -191,7 +196,7 @@ def load_command_table(self, args):
self.__module__ = "azext_hello1"
return self.command_table

# A command from an extension that overrides the original command
# Override existing command
class Ext2CommandsLoader(AzCommandsLoader):

def load_command_table(self, args):
Expand All @@ -201,10 +206,21 @@ def load_command_table(self, args):
self.__module__ = "azext_hello2"
return self.command_table

# Contain no command, but hook into CLI core
class ExtAlwaysLoadedCommandsLoader(AzCommandsLoader):

def load_command_table(self, args):
# Hook something fake into the test_hook
TestCommandRegistration.test_hook = "FAKE_HANDLER"
self.__module__ = "azext_always_loaded"
return self.command_table

if prefix == 'azure.cli.command_modules.':
command_loaders = {'hello': TestCommandsLoader, 'extra': Test2CommandsLoader}
else:
command_loaders = {'azext_hello1': ExtCommandsLoader, 'azext_hello2': Ext2CommandsLoader}
command_loaders = {'azext_hello1': ExtCommandsLoader,
'azext_hello2': Ext2CommandsLoader,
'azext_always_loaded': ExtAlwaysLoadedCommandsLoader}

module_command_table = {}
for mod_name, loader_cls in command_loaders.items():
Expand All @@ -221,7 +237,7 @@ def load_command_table(self, args):
@mock.patch('importlib.import_module', _mock_import_lib)
@mock.patch('pkgutil.iter_modules', _mock_iter_modules)
@mock.patch('azure.cli.core.commands._load_command_loader', _mock_load_command_loader)
@mock.patch('azure.cli.core.extension.get_extension_modname', _mock_extension_modname)
@mock.patch('azure.cli.core.extension.get_extension_modname', _mock_get_extension_modname)
@mock.patch('azure.cli.core.extension.get_extensions', _mock_get_extensions)
def test_register_command_from_extension(self):

Expand All @@ -242,11 +258,10 @@ def test_register_command_from_extension(self):
self.assertTrue(isinstance(hello_overridden_cmd.command_source, ExtensionCommandSource))
self.assertTrue(hello_overridden_cmd.command_source.overrides_command)

@mock.patch.dict("os.environ", {"AZURE_CORE_USE_COMMAND_INDEX": "True"})
@mock.patch('importlib.import_module', _mock_import_lib)
@mock.patch('pkgutil.iter_modules', _mock_iter_modules)
@mock.patch('azure.cli.core.commands._load_command_loader', _mock_load_command_loader)
@mock.patch('azure.cli.core.extension.get_extension_modname', _mock_extension_modname)
@mock.patch('azure.cli.core.extension.get_extension_modname', _mock_get_extension_modname)
@mock.patch('azure.cli.core.extension.get_extensions', _mock_get_extensions)
def test_command_index(self):

Expand Down Expand Up @@ -377,6 +392,32 @@ def update_and_check_index():
del INDEX[CommandIndex._COMMAND_INDEX_CLOUD_PROFILE]
del INDEX[CommandIndex._COMMAND_INDEX]

@mock.patch('importlib.import_module', _mock_import_lib)
@mock.patch('pkgutil.iter_modules', _mock_iter_modules)
@mock.patch('azure.cli.core.commands._load_command_loader', _mock_load_command_loader)
@mock.patch('azure.cli.core.extension.get_extension_modname', _mock_get_extension_modname)
@mock.patch('azure.cli.core.extension.get_extensions', _mock_get_extensions)
def test_command_index_always_loaded_extensions(self):

cli = DummyCli()
loader = cli.commands_loader

from azure.cli.core import CommandIndex
index = CommandIndex()
index.invalidate()

# Test azext_always_loaded is loaded when command index is rebuilt
with mock.patch('azure.cli.core.ALWAYS_LOADED_EXTENSIONS', ['azext_always_loaded']):
loader.load_command_table(["hello", "mod-only"])
self.assertEqual(TestCommandRegistration.test_hook, "FAKE_HANDLER")

TestCommandRegistration.test_hook = []

# Test azext_always_loaded is loaded when command index is used
with mock.patch('azure.cli.core.ALWAYS_LOADED_EXTENSIONS', ['azext_always_loaded']):
loader.load_command_table(["hello", "mod-only"])
self.assertEqual(TestCommandRegistration.test_hook, "FAKE_HANDLER")

def test_argument_with_overrides(self):

global_vm_name_type = CLIArgumentType(
Expand Down