Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions src/azure/cli/_profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,8 +320,8 @@ def __init__(self, auth_ctx_factory=None):
self._load_creds()

def persist_cached_creds(self):
#be compatible with azure-xplat-cli, use 'ascii' so to save w/o a BOM
with codecs_open(self._token_file, 'w', encoding='ascii') as cred_file:
with os.fdopen(os.open(self._token_file, os.O_RDWR|os.O_CREAT|os.O_TRUNC, 0o600),
'w+') as cred_file:
items = self.adal_token_cache.read_items()
all_creds = [entry for _, entry in items]

Expand All @@ -332,6 +332,7 @@ def persist_cached_creds(self):

all_creds.extend(self._service_principal_creds)
cred_file.write(json.dumps(all_creds))

self.adal_token_cache.has_state_changed = False

def retrieve_token_for_user(self, username, tenant, resource):
Expand Down
21 changes: 12 additions & 9 deletions src/azure/cli/tests/test_profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,8 +382,9 @@ def test_credscache_load_tokens_and_sp_creds(self, mock_read_file):
self.assertEqual(creds_cache._service_principal_creds, [test_sp])

@mock.patch('azure.cli._profile._read_file_content', autospec=True)
@mock.patch('azure.cli._profile.codecs_open', autospec=True)
def test_credscache_add_new_sp_creds(self, mock_open_for_write, mock_read_file):
@mock.patch('os.fdopen', autospec=True)
@mock.patch('os.open', autospec=True)
def test_credscache_add_new_sp_creds(self, _, mock_open_for_write, mock_read_file):
test_sp = {
"servicePrincipalId": "myapp",
"servicePrincipalTenant": "mytenant",
Expand All @@ -408,11 +409,12 @@ def test_credscache_add_new_sp_creds(self, mock_open_for_write, mock_read_file):
token_entries = [entry for _, entry in creds_cache.adal_token_cache.read_items()]
self.assertEqual(token_entries, [self.token_entry1])
self.assertEqual(creds_cache._service_principal_creds, [test_sp, test_sp2])
mock_open_for_write.assert_called_with(mock.ANY, 'w', encoding='ascii')
mock_open_for_write.assert_called_with(mock.ANY, 'w+')

@mock.patch('azure.cli._profile._read_file_content', autospec=True)
@mock.patch('azure.cli._profile.codecs_open', autospec=True)
def test_credscache_remove_creds(self, mock_open_for_write, mock_read_file):
@mock.patch('os.fdopen', autospec=True)
@mock.patch('os.open', autospec=True)
def test_credscache_remove_creds(self, _, mock_open_for_write, mock_read_file):
test_sp = {
"servicePrincipalId": "myapp",
"servicePrincipalTenant": "mytenant",
Expand All @@ -435,13 +437,14 @@ def test_credscache_remove_creds(self, mock_open_for_write, mock_read_file):
#assert #2
self.assertEqual(creds_cache._service_principal_creds, [])

mock_open_for_write.assert_called_with(mock.ANY, 'w', encoding='ascii')
mock_open_for_write.assert_called_with(mock.ANY, 'w+')
self.assertEqual(mock_open_for_write.call_count, 2)

@mock.patch('azure.cli._profile._read_file_content', autospec=True)
@mock.patch('azure.cli._profile.codecs_open', autospec=True)
@mock.patch('os.fdopen', autospec=True)
@mock.patch('os.open', autospec=True)
@mock.patch('adal.AuthenticationContext', autospec=True)
def test_credscache_new_token_added_by_adal(self, mock_adal_auth_context, mock_open_for_write, mock_read_file): # pylint: disable=line-too-long
def test_credscache_new_token_added_by_adal(self, mock_adal_auth_context, _, mock_open_for_write, mock_read_file): # pylint: disable=line-too-long
token_entry2 = {
"accessToken": "new token",
"tokenType": "Bearer",
Expand Down Expand Up @@ -469,7 +472,7 @@ def get_auth_context(authority, **kwargs): # pylint: disable=unused-argument
mock.ANY)

#assert
mock_open_for_write.assert_called_with(mock.ANY, 'w', encoding='ascii')
mock_open_for_write.assert_called_with(mock.ANY, 'w+')
self.assertEqual(token, 'new token')
self.assertEqual(token_type, token_entry2['tokenType'])

Expand Down