Skip to content
Next Next commit
Making all platform-dependent parameters optional
  • Loading branch information
rayluo committed Dec 8, 2021
commit 42d482da5f2c5e93316abbe4534651639278431c
18 changes: 11 additions & 7 deletions msal_extensions/persistence.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import abc
import os
import errno
import hashlib
import logging
import sys
try:
Expand Down Expand Up @@ -50,6 +51,9 @@ def _mkdir_p(path):
else:
raise

def _auto_hash(input_string):
return hashlib.sha256(input_string.encode('utf-8')).hexdigest()


# We do not aim to wrap every os-specific exception.
# Here we define only the most common one,
Expand Down Expand Up @@ -197,19 +201,18 @@ class KeychainPersistence(BasePersistence):
and protected by native Keychain libraries on OSX"""
is_encrypted = True

def __init__(self, signal_location, service_name, account_name):
def __init__(self, signal_location, service_name=None, account_name=None):
"""Initialization could fail due to unsatisfied dependency.

:param signal_location: See :func:`persistence.LibsecretPersistence.__init__`
"""
if not (service_name and account_name): # It would hang on OSX
raise ValueError("service_name and account_name are required")
from .osx import Keychain, KeychainError # pylint: disable=import-outside-toplevel
self._file_persistence = FilePersistence(signal_location) # Favor composition
self._Keychain = Keychain # pylint: disable=invalid-name
self._KeychainError = KeychainError # pylint: disable=invalid-name
self._service_name = service_name
self._account_name = account_name
default_service_name = "msal-extensions" # This is also our package name
self._service_name = service_name or default_service_name
self._account_name = account_name or _auto_hash(signal_location)

def save(self, content):
with self._Keychain() as locker:
Expand Down Expand Up @@ -247,7 +250,7 @@ class LibsecretPersistence(BasePersistence):
and protected by native libsecret libraries on Linux"""
is_encrypted = True

def __init__(self, signal_location, schema_name, attributes, **kwargs):
def __init__(self, signal_location, schema_name=None, attributes=None, **kwargs):
"""Initialization could fail due to unsatisfied dependency.

:param string signal_location:
Expand All @@ -262,7 +265,8 @@ def __init__(self, signal_location, schema_name, attributes, **kwargs):
from .libsecret import ( # This uncertain import is deferred till runtime
LibSecretAgent, trial_run)
trial_run()
self._agent = LibSecretAgent(schema_name, attributes, **kwargs)
self._agent = LibSecretAgent(
schema_name or _auto_hash(signal_location), attributes or {}, **kwargs)
self._file_persistence = FilePersistence(signal_location) # Favor composition

def save(self, content):
Expand Down
5 changes: 2 additions & 3 deletions sample/persistence_sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ def build_persistence(location, fallback_to_plaintext=False):
if sys.platform.startswith('win'):
return FilePersistenceWithDataProtection(location)
if sys.platform.startswith('darwin'):
return KeychainPersistence(location, "my_service_name", "my_account_name")
return KeychainPersistence(location)
if sys.platform.startswith('linux'):
try:
return LibsecretPersistence(
Expand All @@ -21,8 +21,6 @@ def build_persistence(location, fallback_to_plaintext=False):
# unless there would frequently be a desktop session and
# a remote ssh session being active simultaneously.
location,
schema_name="my_schema_name",
attributes={"my_attr1": "foo", "my_attr2": "bar"},
)
except: # pylint: disable=bare-except
if not fallback_to_plaintext:
Expand All @@ -31,6 +29,7 @@ def build_persistence(location, fallback_to_plaintext=False):
return FilePersistence(location)

persistence = build_persistence("storage.bin", fallback_to_plaintext=False)
print("Type of persistence: {}".format(persistence.__class__.__name__))
print("Is this persistence encrypted?", persistence.is_encrypted)

data = { # It can be anything, here we demonstrate an arbitrary json object
Expand Down
5 changes: 2 additions & 3 deletions sample/token_cache_sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ def build_persistence(location, fallback_to_plaintext=False):
if sys.platform.startswith('win'):
return FilePersistenceWithDataProtection(location)
if sys.platform.startswith('darwin'):
return KeychainPersistence(location, "my_service_name", "my_account_name")
return KeychainPersistence(location)
if sys.platform.startswith('linux'):
try:
return LibsecretPersistence(
Expand All @@ -21,8 +21,6 @@ def build_persistence(location, fallback_to_plaintext=False):
# unless there would frequently be a desktop session and
# a remote ssh session being active simultaneously.
location,
schema_name="my_schema_name",
attributes={"my_attr1": "foo", "my_attr2": "bar"},
)
except: # pylint: disable=bare-except
if not fallback_to_plaintext:
Expand All @@ -31,6 +29,7 @@ def build_persistence(location, fallback_to_plaintext=False):
return FilePersistence(location)

persistence = build_persistence("token_cache.bin")
print("Type of persistence: {}".format(persistence.__class__.__name__))
print("Is this persistence encrypted?", persistence.is_encrypted)

cache = PersistedTokenCache(persistence)
Expand Down
9 changes: 2 additions & 7 deletions tests/test_persistence.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,7 @@ def test_nonexistent_file_persistence_with_data_protection(temp_location):
not sys.platform.startswith('darwin'),
reason="Requires OSX. Whether running on TRAVIS CI does not seem to matter.")
def test_keychain_persistence(temp_location):
_test_persistence_roundtrip(KeychainPersistence(
temp_location, "my_service_name", "my_account_name"))
_test_persistence_roundtrip(KeychainPersistence(temp_location))

@pytest.mark.skipif(
not sys.platform.startswith('darwin'),
Expand All @@ -69,11 +68,7 @@ def test_nonexistent_keychain_persistence(temp_location):
is_running_on_travis_ci or not sys.platform.startswith('linux'),
reason="Requires Linux Desktop. Headless or SSH session won't work.")
def test_libsecret_persistence(temp_location):
_test_persistence_roundtrip(LibsecretPersistence(
temp_location,
"my_schema_name",
{"my_attr_1": "foo", "my_attr_2": "bar"},
))
_test_persistence_roundtrip(LibsecretPersistence(temp_location))

@pytest.mark.skipif(
is_running_on_travis_ci or not sys.platform.startswith('linux'),
Expand Down