Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Extract class CommandIndex
  • Loading branch information
jiasli committed Jun 12, 2020
commit 8defa9f78e0b2c2fece4f76c2f2936ecbe901733
191 changes: 110 additions & 81 deletions src/azure-cli-core/azure/cli/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,6 @@
EXCLUDED_PARAMS = ['self', 'raw', 'polling', 'custom_headers', 'operation_config',
'content_version', 'kwargs', 'client', 'no_wait']
EVENT_FAILED_EXTENSION_LOAD = 'MainLoader.OnFailedExtensionLoad'
_COMMAND_INDEX = 'commandIndex'
_COMMAND_INDEX_VERSION = 'version'
_COMMAND_INDEX_CLOUD_PROFILE = 'cloudProfile'
# Extensions that will always be loaded if installed. These extensions don't expose commands but hook into CLI core.
ALWAYS_LOADED_EXTENSION_MODNAMES = ['azext_ai_examples', 'azext_ai_did_you_mean_this']

Expand Down Expand Up @@ -182,7 +179,6 @@ def load_command_table(self, args):
_load_module_command_loader, _load_extension_command_loader, BLOCKED_MODS, ExtensionCommandSource)
from azure.cli.core.extension import (
get_extensions, get_extension_path, get_extension_modname)
from azure.cli.core._session import INDEX

def _update_command_table_from_modules(args, command_modules=None):
'''Loads command table(s)
Expand Down Expand Up @@ -351,24 +347,6 @@ def _get_extension_suppressions(mod_loaders):
res.append(sup)
return res

def _update_command_index():
start_time = timeit.default_timer()
INDEX[_COMMAND_INDEX_VERSION] = __version__
INDEX[_COMMAND_INDEX_CLOUD_PROFILE] = self.cli_ctx.cloud.profile
INDEX[_COMMAND_INDEX] = {}
from collections import defaultdict
index = defaultdict(list)
for command_name, command in self.command_table.items():
# Get the top-level name: <vm> create
top_command = command_name.split()[0]
# Get module name, like azure.cli.command_modules.vm, azext_webapp
module_name = command.loader.__module__
if module_name not in index[top_command]:
index[top_command].append(module_name)
elapsed_time = timeit.default_timer() - start_time
logger.debug("Updated command index in %.3f seconds.", elapsed_time)
INDEX[_COMMAND_INDEX] = index

def _roughly_parse_command(args):
# Roughly parse the command part: <az vm create> --name vm1
# Similar to knack.invocation.CommandInvoker._rudimentary_get_command, but we don't need to bother with
Expand All @@ -381,51 +359,16 @@ def _roughly_parse_command(args):
break
return ' '.join(nouns).lower()

# Check the command index for (command: [module]) mapping, like
# "sql": ["azure.cli.command_modules.sql", "azure.cli.command_modules.sqlvm", "azext_sql"]
top_command = None
# This list contains both built-in modules and extensions
index_modules = []
index_builtin_modules = []
index_extensions = []

# Clear the tables to make this method idempotent
self.command_group_table = {}
self.command_table = {}

# If the command index version or cloud profile doesn't match those of the current command,
# invalidate the command index
index_version = INDEX[_COMMAND_INDEX_VERSION]
cloud_profile = INDEX[_COMMAND_INDEX_CLOUD_PROFILE]
if not (index_version and index_version == __version__ and
cloud_profile and self.cli_ctx.cloud.profile):
logger.debug("Command index version or cloud profile is invalid or doesn't match the current command.")
invalidate_command_index()

if args and not args[0].startswith('-'):
# A top level command is provided, like `az version`
top_command = args[0]
index = INDEX[_COMMAND_INDEX]
# Un-comment this line to disable command index
# index = {}
index_modules = index.get(top_command)

if index_modules:
# Found modules from index
logger.debug("Modules found from index for '%s': %s", top_command, index_modules)
command_module_prefix = 'azure.cli.command_modules.'
for m in index_modules:
if m.startswith(command_module_prefix):
# The top-level command is from a command module
index_builtin_modules.append(m[len(command_module_prefix):])
elif m.startswith('azext_'):
# The top-level command is from an extension
index_extensions.append(m)
else:
logger.warning("Unrecognized module: %s", m)

if index_builtin_modules:
_update_command_table_from_modules(args, index_builtin_modules)
self.command_group_table.clear()
self.command_table.clear()

command_index = CommandIndex(self.cli_ctx)
# command_index = CommandIndex(self.cli_ctx, False)
index_modules, index_extensions = command_index.get(args)
if index_modules or index_extensions:
if index_modules:
_update_command_table_from_modules(args, index_modules)
if index_extensions:
# The index won't contain suppressed extensions
_update_command_table_from_extensions([], index_extensions)
Expand All @@ -440,7 +383,7 @@ def _roughly_parse_command(args):
logger.debug("Could not find a match in the command table for '%s'. The index may be outdated",
command_str)
else:
logger.debug("No module found from index for '%s'", top_command)
logger.debug("No module found from index for '%s'", args)

# No module found from the index. Load all command modules and extensions
logger.debug("Loading all modules and extensions")
Expand All @@ -451,7 +394,7 @@ def _roughly_parse_command(args):
# as an extension could override the commands already loaded.
_update_command_table_from_extensions(ext_suppressions)
logger.debug("Loaded %d groups, %d commands.", len(self.command_group_table), len(self.command_table))
_update_command_index()
command_index.update(self.command_table)

return self.command_table

Expand Down Expand Up @@ -762,18 +705,104 @@ def get_default_cli():
help_cls=AzCliHelp)


def invalidate_command_index():
"""Invalidate the command index.
class CommandIndex:

_COMMAND_INDEX = 'commandIndex'
_COMMAND_INDEX_VERSION = 'version'
_COMMAND_INDEX_CLOUD_PROFILE = 'cloudProfile'

def __init__(self, cli_ctx=None, enabled=True):
from azure.cli.core._session import INDEX
self.INDEX = INDEX
self.version = __version__
self.cloud_profile = cli_ctx.cloud.profile
self.enabled = enabled

def get(self, args):
"""Get the corresponding module and extension list from the command

:param args: command sections separated by spaces, like ['network', 'vnet', 'create', '-h']
:return: list for modules and list for extensions
"""

if not self.enabled:
return None, None

# If the command index version or cloud profile doesn't match those of the current command,
# invalidate the command index.
index_version = self.INDEX[self._COMMAND_INDEX_VERSION]
cloud_profile = self.INDEX[self._COMMAND_INDEX_CLOUD_PROFILE]
if not (index_version and index_version == self.version and
cloud_profile and cloud_profile == self.cloud_profile):
logger.debug("Command index version or cloud profile is invalid or doesn't match the current command.")
self.invalidate()
return None, None

# Make sure the top-level command is provided, like `az version`.
# Skip command index for `az` or `az --help`.
if not args or args[0].startswith('-'):
return None, None

This function MUST be called when installing or updating extensions. Otherwise, when an extension
1. overrides a built-in command, or
2. extends an existing command group,
the command or command group will only be loaded from the command modules as per the stale command index,
making the newly installed extension be ignored.
# Get the top-level command, like `network` in `network vnet create -h`
top_command = args[0]
index = self.INDEX[self._COMMAND_INDEX]
# Check the command index for (command: [module]) mapping, like
# "network": ["azure.cli.command_modules.natgateway", "azure.cli.command_modules.network", "azext_firewall"]
index_modules_extensions = index.get(top_command)

This function can be called when removing extensions and updating cloud profiles for double insurance.
"""
from azure.cli.core._session import INDEX
INDEX[_COMMAND_INDEX_VERSION] = ""
INDEX[_COMMAND_INDEX] = {}
logger.debug("Command index has been invalidated.")
if index_modules_extensions:
# This list contains both built-in modules and extensions
index_builtin_modules = []
index_extensions = []
# Found modules from index
logger.debug("Modules found from index for '%s': %s", top_command, index_modules_extensions)
command_module_prefix = 'azure.cli.command_modules.'
for m in index_modules_extensions:
if m.startswith(command_module_prefix):
# The top-level command is from a command module
index_builtin_modules.append(m[len(command_module_prefix):])
elif m.startswith('azext_'):
# The top-level command is from an extension
index_extensions.append(m)
else:
logger.warning("Unrecognized module: %s", m)
return index_builtin_modules, index_extensions

return None, None

def update(self, command_table):
"""Update the command index according to cli_ctx.invocation.commands_loader.command_table
"""
start_time = timeit.default_timer()
self.INDEX[self._COMMAND_INDEX_VERSION] = __version__
self.INDEX[self._COMMAND_INDEX_CLOUD_PROFILE] = self.cloud_profile
from collections import defaultdict
index = defaultdict(list)

# self.cli_ctx.invocation.commands_loader.command_table doesn't exist in DummyCli due to the lack of invocation
for command_name, command in command_table.items():
# Get the top-level name: <vm> create
top_command = command_name.split()[0]
# Get module name, like azure.cli.command_modules.vm, azext_webapp
module_name = command.loader.__module__
if module_name not in index[top_command]:
index[top_command].append(module_name)
elapsed_time = timeit.default_timer() - start_time
self.INDEX[self._COMMAND_INDEX] = index
logger.debug("Updated command index in %.3f seconds.", elapsed_time)

def invalidate(self):
"""Invalidate the command index.

This function MUST be called when installing or updating extensions. Otherwise, when an extension
1. overrides a built-in command, or
2. extends an existing command group,
the command or command group will only be loaded from the command modules as per the stale command index,
making the newly installed extension be ignored.

This function can be called when removing extensions and updating cloud profiles for double insurance.
"""
self.INDEX[self._COMMAND_INDEX_VERSION] = ""
self.INDEX[self._COMMAND_INDEX_CLOUD_PROFILE] = ""
self.INDEX[self._COMMAND_INDEX] = {}
logger.debug("Command index has been invalidated.")
8 changes: 4 additions & 4 deletions src/azure-cli-core/azure/cli/core/extension/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import requests
from pkg_resources import parse_version

from azure.cli.core import invalidate_command_index
from azure.cli.core import CommandIndex
from azure.cli.core.util import CLIError, reload_module
from azure.cli.core.extension import (extension_exists, build_extension_path, get_extensions, get_extension_modname,
get_extension, ext_compat_with_cli,
Expand Down Expand Up @@ -235,7 +235,7 @@ def add_extension(cmd, source=None, extension_name=None, index_url=None, yes=Non
"Please use with discretion.", extension_name)
elif extension_name and ext.preview:
logger.warning("The installed extension '%s' is in preview.", extension_name)
invalidate_command_index()
CommandIndex().invalidate()
except ExtensionNotInstalledException:
pass

Expand All @@ -255,7 +255,7 @@ def log_err(func, path, exc_info):
# We call this just before we remove the extension so we can get the metadata before it is gone
_augment_telemetry_with_ext_info(extension_name, ext)
shutil.rmtree(ext.path, onerror=log_err)
invalidate_command_index()
CommandIndex().invalidate()
except ExtensionNotInstalledException as e:
raise CLIError(e)

Expand Down Expand Up @@ -308,7 +308,7 @@ def update_extension(cmd, extension_name, index_url=None, pip_extra_index_urls=N
logger.debug('Copying %s to %s', backup_dir, extension_path)
shutil.copytree(backup_dir, extension_path)
raise CLIError('Failed to update. Rolled {} back to {}.'.format(extension_name, cur_version))
invalidate_command_index()
CommandIndex().invalidate()
except ExtensionNotInstalledException as e:
raise CLIError(e)

Expand Down
Loading