diff --git a/src/azure-cli/azure/cli/command_modules/vm/_vm_utils.py b/src/azure-cli/azure/cli/command_modules/vm/_vm_utils.py index a84048061f9..399161b5ddd 100644 --- a/src/azure-cli/azure/cli/command_modules/vm/_vm_utils.py +++ b/src/azure-cli/azure/cli/command_modules/vm/_vm_utils.py @@ -711,34 +711,48 @@ def import_aaz_by_profile(profile, module_name): def generate_ssh_keys_ed25519(private_key_filepath, public_key_filepath): + def _open(filename, mode): + return os.open(filename, flags=os.O_WRONLY | os.O_TRUNC | os.O_CREAT, mode=mode) + from cryptography.hazmat.primitives import serialization from cryptography.hazmat.primitives.asymmetric.ed25519 import Ed25519PrivateKey ssh_dir = os.path.dirname(private_key_filepath) if not os.path.exists(ssh_dir): - os.makedirs(ssh_dir) - os.chmod(ssh_dir, 0o700) + os.makedirs(ssh_dir, mode=0o700) - private_key = Ed25519PrivateKey.generate() - public_key = private_key.public_key() - private_bytes = private_key.private_bytes( - encoding=serialization.Encoding.PEM, - format=serialization.PrivateFormat.OpenSSH, - encryption_algorithm=serialization.NoEncryption() - ) - - with os.fdopen(os.open(private_key_filepath, flags=os.O_WRONLY | os.O_TRUNC | os.O_CREAT, mode=384, ), "w", ) as f: - f.write( - private_bytes.decode() + if os.path.isfile(private_key_filepath): + # Try to use existing private key if it exists. + with open(private_key_filepath, "rb") as f: + private_bytes = f.read() + private_key = serialization.load_ssh_private_key(private_bytes, password=None) + logger.warning("Private SSH key file '%s' was found in the directory: '%s'. " + "A paired public key file '%s' will be generated.", + private_key_filepath, ssh_dir, public_key_filepath) + + else: + # Otherwise generate new private key. + private_key = Ed25519PrivateKey.generate() + + # The private key will look like: + # -----BEGIN OPENSSH PRIVATE KEY----- + # ... + # -----END OPENSSH PRIVATE KEY----- + private_bytes = private_key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.OpenSSH, + encryption_algorithm=serialization.NoEncryption() ) - os.chmod(private_key_filepath, 0o600) - with open(public_key_filepath, 'w') as public_key_file: - s = public_key.public_bytes( - encoding=serialization.Encoding.OpenSSH, - format=serialization.PublicFormat.OpenSSH) - public_key = s.decode(encoding="utf8").replace("\n", "") - public_key_file.write(public_key) - os.chmod(public_key_filepath, 0o644) + with os.fdopen(_open(private_key_filepath, 0o600), "wb") as f: + f.write(private_bytes) + + public_key = private_key.public_key() + public_bytes = public_key.public_bytes( + encoding=serialization.Encoding.OpenSSH, + format=serialization.PublicFormat.OpenSSH) + + with os.fdopen(_open(public_key_filepath, 0o644), 'wb') as f: + f.write(public_bytes) - return public_key + return public_bytes.decode()