Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
117 changes: 71 additions & 46 deletions src/azure-cli/azure/cli/command_modules/servicefabric/custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from azure.cli.core.profiles import ResourceType, get_sdk
from azure.cli.command_modules.servicefabric._arm_deployment_utils import validate_and_deploy_arm_template
from azure.cli.command_modules.servicefabric._sf_utils import _get_resource_group_by_name, _create_resource_group_name
from azure.core.exceptions import ResourceNotFoundError

from azure.mgmt.servicefabric.models import (ClusterUpdateParameters,
ClientCertificateThumbprint,
Expand Down Expand Up @@ -302,7 +303,8 @@ def add_app_cert(cmd,
return client.get(resource_group_name, cluster_name)


def add_client_cert(client,
def add_client_cert(cmd,
client,
resource_group_name,
cluster_name,
is_admin=False,
Expand All @@ -312,6 +314,7 @@ def add_client_cert(client,
admin_client_thumbprints=None,
readonly_client_thumbprints=None,
client_certificate_common_names=None):
cli_ctx = cmd.cli_ctx
if thumbprint:
if certificate_common_name or certificate_issuer_thumbprint or admin_client_thumbprints or readonly_client_thumbprints or client_certificate_common_names:
raise CLIError(
Expand Down Expand Up @@ -376,16 +379,19 @@ def _add_common_name(cluster, is_admin, certificate_common_name, certificate_iss

patch_request = ClusterUpdateParameters(client_certificate_thumbprints=cluster.client_certificate_thumbprints,
client_certificate_common_names=cluster.client_certificate_common_names)
return client.update(resource_group_name, cluster_name, patch_request)
update_cluster_poll = client.begin_update(resource_group_name, cluster_name, patch_request)
return LongRunningOperation(cli_ctx)(update_cluster_poll)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is recommended to use sdk_no_wait instead of LongRunningOperation and support the --no-wait parameter.
However, you don't need to complete these in this PR, maybe you can submit a PR separately for this optimization.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok thanks, I will work on this recommendation on a separate pr



def remove_client_cert(client,
def remove_client_cert(cmd,
client,
resource_group_name,
cluster_name,
thumbprints=None,
certificate_common_name=None,
certificate_issuer_thumbprint=None,
client_certificate_common_names=None):
cli_ctx = cmd.cli_ctx
if thumbprints:
if certificate_common_name or certificate_issuer_thumbprint or client_certificate_common_names:
raise CLIError("--thumbprint can only specified alone")
Expand Down Expand Up @@ -441,7 +447,8 @@ def _remove_common_name(cluster, certificate_common_name, certificate_issuer_thu
patch_request = ClusterUpdateParameters(client_certificate_thumbprints=cluster.client_certificate_thumbprints,
client_certificate_common_names=cluster.client_certificate_common_names)

return client.update(resource_group_name, cluster_name, patch_request)
update_cluster_poll = client.begin_update(resource_group_name, cluster_name, patch_request)
return LongRunningOperation(cli_ctx)(update_cluster_poll)


def add_cluster_node(cmd, client, resource_group_name, cluster_name, node_type, number_of_nodes_to_add):
Expand All @@ -468,7 +475,8 @@ def add_cluster_node(cmd, client, resource_group_name, cluster_name, node_type,
node_type.vm_instance_count = vmss.sku.capacity
patch_request = ClusterUpdateParameters(node_types=cluster.node_types)

return client.update(resource_group_name, cluster_name, patch_request)
update_cluster_poll = client.begin_update(resource_group_name, cluster_name, patch_request)
return LongRunningOperation(cli_ctx)(update_cluster_poll)


def remove_cluster_node(cmd, client, resource_group_name, cluster_name, node_type, number_of_nodes_to_remove):
Expand Down Expand Up @@ -500,13 +508,14 @@ def remove_cluster_node(cmd, client, resource_group_name, cluster_name, node_typ
node_type.vm_instance_count = vmss.sku.capacity
patch_request = ClusterUpdateParameters(node_types=cluster.node_types)

return client.update(resource_group_name, cluster_name, patch_request)
sfrp_poll = client.begin_update(resource_group_name, cluster_name, patch_request)
return LongRunningOperation(cli_ctx)(sfrp_poll)


def update_cluster_durability(cmd, client, resource_group_name, cluster_name, node_type, durability_level):
cli_ctx = cmd.cli_ctx

# get cluster node type durablity
# get cluster node type durability
cluster = client.get(resource_group_name, cluster_name)
node_type_refs = [n for n in cluster.node_types if n.name.lower() == node_type.lower()]
if not node_type_refs:
Expand All @@ -517,7 +526,7 @@ def update_cluster_durability(cmd, client, resource_group_name, cluster_name, no
# get vmss extension durability
compute_client = compute_client_factory(cli_ctx)
vmss = _get_cluster_vmss_by_node_type(compute_client, resource_group_name, cluster.cluster_id, node_type)
_get_sf_vm_extension(vmss)

fabric_ext_ref = _get_sf_vm_extension(vmss)
if fabric_ext_ref is None:
raise CLIError("Failed to find service fabric extension.")
Expand All @@ -535,7 +544,7 @@ def update_cluster_durability(cmd, client, resource_group_name, cluster_name, no
if curr_node_type_durability.lower() != durability_level.lower():
node_type_ref.durability_level = durability_level
patch_request = ClusterUpdateParameters(node_types=cluster.node_types)
update_cluster_poll = client.update(resource_group_name, cluster_name, patch_request)
update_cluster_poll = client.begin_update(resource_group_name, cluster_name, patch_request)
LongRunningOperation(cli_ctx)(update_cluster_poll)

# update vmss sf extension durability
Expand All @@ -548,7 +557,8 @@ def update_cluster_durability(cmd, client, resource_group_name, cluster_name, no
return client.get(resource_group_name, cluster_name)


def update_cluster_upgrade_type(client,
def update_cluster_upgrade_type(cmd,
client,
resource_group_name,
cluster_name,
upgrade_mode,
Expand All @@ -557,6 +567,7 @@ def update_cluster_upgrade_type(client,
raise CLIError(
'--upgrade-mode can either be \'manual\' or \'automatic\'')

cli_ctx = cmd.cli_ctx
cluster = client.get(resource_group_name, cluster_name)
patch_request = ClusterUpdateParameters(node_types=cluster.node_types)
if upgrade_mode.lower() == 'manual':
Expand All @@ -566,16 +577,20 @@ def update_cluster_upgrade_type(client,
patch_request.cluster_code_version = version

patch_request.upgrade_mode = upgrade_mode
return client.update(resource_group_name, cluster_name, patch_request)
update_cluster_poll = client.begin_update(resource_group_name, cluster_name, patch_request)
return LongRunningOperation(cli_ctx)(update_cluster_poll)


def set_cluster_setting(client,
def set_cluster_setting(cmd,
client,
resource_group_name,
cluster_name,
section=None,
parameter=None,
value=None,
settings_section_description=None):
cli_ctx = cmd.cli_ctx

def _set(setting_dict, section, parameter, value):
if section not in setting_dict:
setting_dict[section] = {}
Expand All @@ -601,15 +616,19 @@ def _set(setting_dict, section, parameter, value):
setting_dict = _set(setting_dict, section, parameter, value)
settings = _dict_to_fabric_settings(setting_dict)
patch_request = ClusterUpdateParameters(fabric_settings=settings)
return client.update(resource_group_name, cluster_name, patch_request)
update_cluster_poll = client.begin_update(resource_group_name, cluster_name, patch_request)
return LongRunningOperation(cli_ctx)(update_cluster_poll)


def remove_cluster_setting(client,
def remove_cluster_setting(cmd,
client,
resource_group_name,
cluster_name,
section=None,
parameter=None,
settings_section_description=None):
cli_ctx = cmd.cli_ctx

def _remove(setting_dict, section, parameter):
if section not in setting_dict:
raise CLIError(
Expand All @@ -636,7 +655,8 @@ def _remove(setting_dict, section, parameter):

settings = _dict_to_fabric_settings(setting_dict)
patch_request = ClusterUpdateParameters(fabric_settings=settings)
return client.update(resource_group_name, cluster_name, patch_request)
update_cluster_poll = client.begin_update(resource_group_name, cluster_name, patch_request)
return LongRunningOperation(cli_ctx)(update_cluster_poll)


def update_cluster_reliability_level(cmd,
Expand Down Expand Up @@ -670,7 +690,8 @@ def update_cluster_reliability_level(cmd,
node_type.vm_instance_count = vmss.sku.capacity
patch_request = ClusterUpdateParameters(
node_types=cluster.node_types, reliability_level=reliability_level)
return client.update(resource_group_name, cluster_name, patch_request)
update_cluster_poll = client.begin_update(resource_group_name, cluster_name, patch_request)
return LongRunningOperation(cli_ctx)(update_cluster_poll)


def add_cluster_node_type(cmd,
Expand All @@ -692,8 +713,8 @@ def add_cluster_node_type(cmd,
if any(n for n in cluster.node_types if n.name.lower() == node_type):
raise CLIError("node type {} already exists in the cluster".format(node_type))

_create_vmss(cmd, resource_group_name, cluster_name, cluster, node_type, durability_level, vm_password, vm_user_name, vm_sku, vm_tier, capacity)
_add_node_type_to_sfrp(cmd, client, resource_group_name, cluster_name, cluster, node_type, capacity, durability_level)
_create_vmss(cmd, resource_group_name, cluster_name, cluster, node_type, durability_level, vm_password, vm_user_name, vm_sku, vm_tier, capacity)

return client.get(resource_group_name, cluster_name)

Expand All @@ -711,8 +732,8 @@ def _add_node_type_to_sfrp(cmd, client, resource_group_name, cluster_name, clust
start_port=DEFAULT_EPHEMERAL_START, end_port=DEFAULT_EPHEMERAL_END)))

patch_request = ClusterUpdateParameters(node_types=cluster.node_types)
poller = client.update(resource_group_name, cluster_name, patch_request)
LongRunningOperation(cmd.cli_ctx)(poller)
poller = client.begin_update(resource_group_name, cluster_name, patch_request)
return LongRunningOperation(cmd.cli_ctx)(poller)


def _create_vmss(cmd, resource_group_name, cluster_name, cluster, node_type_name, durability_level, vm_password, vm_user_name, vm_sku, vm_tier, capacity):
Expand Down Expand Up @@ -885,11 +906,13 @@ def create_vhd(cli_ctx, resource_group_name, cluster_name, node_type, location):
def create_storage_account(cli_ctx, resource_group_name, storage_name, location):
from azure.mgmt.storage.models import Sku, SkuName
storage_client = storage_client_factory(cli_ctx)
LongRunningOperation(cli_ctx)(storage_client.storage_accounts.create(resource_group_name,
storage_name,
StorageAccountCreateParameters(sku=Sku(name=SkuName.standard_lrs),
kind='storage',
location=location)))
storage_poll = storage_client.storage_accounts.begin_create(resource_group_name,
storage_name,
StorageAccountCreateParameters(sku=Sku(name=SkuName.standard_lrs),
kind='storage',
location=location))

LongRunningOperation(cli_ctx)(storage_poll)

acc_prop = storage_client.storage_accounts.get_properties(
resource_group_name, storage_name)
Expand Down Expand Up @@ -924,7 +947,7 @@ def create_storage_account(cli_ctx, resource_group_name, storage_name, location)

diagnostics_ext = None
fabric_ext = None
diagnostics_exts = [e for e in vmss_reference.virtual_machine_profile.extension_profile.extensions if e.type1.lower(
diagnostics_exts = [e for e in vmss_reference.virtual_machine_profile.extension_profile.extensions if e.type_properties_type.lower(
) == 'IaaSDiagnostics'.lower()]
if any(diagnostics_exts):
diagnostics_ext = diagnostics_exts[0]
Expand All @@ -940,8 +963,8 @@ def create_storage_account(cli_ctx, resource_group_name, storage_name, location)
json_data['storageAccountEndPoint'] = "https://core.windows.net/"
diagnostics_ext.protected_settings = json_data

fabric_exts = [e for e in vmss_reference.virtual_machine_profile.extension_profile.extensions if e.type1.lower(
) == SERVICE_FABRIC_WINDOWS_NODE_EXT_NAME or e.type1.lower() == SERVICE_FABRIC_LINUX_NODE_EXT_NAME]
fabric_exts = [e for e in vmss_reference.virtual_machine_profile.extension_profile.extensions if e.type_properties_type.lower(
) == SERVICE_FABRIC_WINDOWS_NODE_EXT_NAME or e.type_properties_type.lower() == SERVICE_FABRIC_LINUX_NODE_EXT_NAME]
if any(fabric_exts):
fabric_ext = fabric_exts[0]

Expand Down Expand Up @@ -1100,12 +1123,7 @@ def _create_certificate(cmd,
else:
if vault is None:
logger.info("Creating key vault")
if cmd.supported_api_version(resource_type=ResourceType.MGMT_KEYVAULT, min_api='2018-02-14'):
vault = _create_keyvault(
cmd, cli_ctx, vault_resource_group_name, vault_name, location, enabled_for_deployment=True).result()
else:
vault = _create_keyvault(
cmd, cli_ctx, vault_resource_group_name, vault_name, location, enabled_for_deployment=True)
vault = _create_keyvault(cmd, cli_ctx, vault_resource_group_name, vault_name, location, enabled_for_deployment=True)
logger.info("Wait for key vault ready")
time.sleep(20)
vault_uri = vault.properties.vault_uri
Expand Down Expand Up @@ -1346,6 +1364,8 @@ def _safe_get_vault(cli_ctx, resource_group_name, vault_name):
try:
vault = key_vault_client.get(resource_group_name, vault_name)
return vault
except ResourceNotFoundError:
return None
except CloudError as ex:
if ex.error.error == 'ResourceNotFound':
return None
Expand Down Expand Up @@ -1621,15 +1641,15 @@ def _create_keyvault(cmd,
tenant_id,
base_url=cli_ctx.cloud.endpoints.active_directory_graph_resource_id)
subscription = profile.get_subscription()
VaultCreateOrUpdateParameters = cmd.get_models('VaultCreateOrUpdateParameters', resource_type=ResourceType.MGMT_KEYVAULT)
VaultProperties = cmd.get_models('VaultProperties', resource_type=ResourceType.MGMT_KEYVAULT)
KeyVaultSku = cmd.get_models('Sku', resource_type=ResourceType.MGMT_KEYVAULT)
AccessPolicyEntry = cmd.get_models('AccessPolicyEntry', resource_type=ResourceType.MGMT_KEYVAULT)
Permissions = cmd.get_models('Permissions', resource_type=ResourceType.MGMT_KEYVAULT)
CertificatePermissions = get_sdk(cli_ctx, ResourceType.MGMT_KEYVAULT, 'models#CertificatePermissions')
KeyPermissions = get_sdk(cli_ctx, ResourceType.MGMT_KEYVAULT, 'models#KeyPermissions')
SecretPermissions = get_sdk(cli_ctx, ResourceType.MGMT_KEYVAULT, 'models#SecretPermissions')
KeyVaultSkuName = cmd.get_models('SkuName', resource_type=ResourceType.MGMT_KEYVAULT)
VaultCreateOrUpdateParameters = cmd.get_models('VaultCreateOrUpdateParameters', resource_type=ResourceType.MGMT_KEYVAULT, operation_group='vaults')
VaultProperties = cmd.get_models('VaultProperties', resource_type=ResourceType.MGMT_KEYVAULT, operation_group='vaults')
KeyVaultSku = cmd.get_models('Sku', resource_type=ResourceType.MGMT_KEYVAULT, operation_group='vaults')
AccessPolicyEntry = cmd.get_models('AccessPolicyEntry', resource_type=ResourceType.MGMT_KEYVAULT, operation_group='vaults')
Permissions = cmd.get_models('Permissions', resource_type=ResourceType.MGMT_KEYVAULT, operation_group='vaults')
CertificatePermissions = get_sdk(cli_ctx, ResourceType.MGMT_KEYVAULT, 'models#CertificatePermissions', operation_group='vaults')
KeyPermissions = get_sdk(cli_ctx, ResourceType.MGMT_KEYVAULT, 'models#KeyPermissions', operation_group='vaults')
SecretPermissions = get_sdk(cli_ctx, ResourceType.MGMT_KEYVAULT, 'models#SecretPermissions', operation_group='vaults')
KeyVaultSkuName = cmd.get_models('SkuName', resource_type=ResourceType.MGMT_KEYVAULT, operation_group='vaults')

if not sku:
sku = KeyVaultSkuName.standard.value
Expand Down Expand Up @@ -1678,20 +1698,25 @@ def _create_keyvault(cmd,
access_policies = [AccessPolicyEntry(tenant_id=tenant_id,
object_id=object_id,
permissions=permissions)]

properties = VaultProperties(tenant_id=tenant_id,
sku=KeyVaultSku(name=sku),
access_policies=access_policies,
vault_uri=None,
enabled_for_deployment=enabled_for_deployment,
enabled_for_disk_encryption=enabled_for_disk_encryption,
enabled_for_template_deployment=enabled_for_template_deployment)

parameters = VaultCreateOrUpdateParameters(location=location,
tags=tags,
properties=properties)
client = keyvault_client_factory(cli_ctx).vaults
return client.create_or_update(resource_group_name=resource_group_name,
vault_name=vault_name,
parameters=parameters)

keyvault_client = keyvault_client_factory(cli_ctx)
kv_poll = keyvault_client.vaults.begin_create_or_update(resource_group_name=resource_group_name,
vault_name=vault_name,
parameters=parameters)

return LongRunningOperation(cli_ctx)(kv_poll)


# pylint: disable=inconsistent-return-statements
Expand Down
Loading