|
7 | 7 | except ImportError: |
8 | 8 | from urlparse import parse_qs, urlparse, urlunparse |
9 | 9 | from urllib import urlencode, quote_plus |
| 10 | +import inspect |
10 | 11 | import logging |
11 | 12 | import warnings |
12 | 13 | import time |
@@ -104,6 +105,15 @@ def __init__( |
104 | 105 | or a raw JWT assertion in bytes (which we will relay to http layer). |
105 | 106 | It can also be a callable (recommended), |
106 | 107 | so that we will do lazy creation of an assertion. |
| 108 | +
|
| 109 | + The callable may accept zero arguments (legacy) or one |
| 110 | + required positional argument. Callables whose positional |
| 111 | + parameters all have default values (e.g. |
| 112 | + ``lambda token=token: token``) are treated as zero-arg. |
| 113 | + When the callable declares a required positional parameter, |
| 114 | + it will receive a dict containing ``"client_id"``, |
| 115 | + ``"token_endpoint"``, and optionally ``"fmi_path"`` |
| 116 | + (when an FMI path is set on the current request). |
107 | 117 | client_assertion_type (str): |
108 | 118 | The type of your :attr:`client_assertion` parameter. |
109 | 119 | It is typically the value of :attr:`CLIENT_ASSERTION_TYPE_SAML2` or |
@@ -168,6 +178,41 @@ def __init__( |
168 | 178 | # A workaround for requests not supporting session-wide timeout |
169 | 179 | self._http_client.request, timeout=timeout) |
170 | 180 |
|
| 181 | + @staticmethod |
| 182 | + def _accepts_context(func): |
| 183 | + """Check if a callable requires at least one positional argument. |
| 184 | +
|
| 185 | + Returns True only when the callable has a positional parameter |
| 186 | + **without** a default value. This ensures that legacy zero-arg |
| 187 | + callables — including ``lambda token=token: token`` patterns |
| 188 | + where every positional param has a default — are still invoked |
| 189 | + with no arguments. |
| 190 | + """ |
| 191 | + try: |
| 192 | + sig = inspect.signature(func) |
| 193 | + for p in sig.parameters.values(): |
| 194 | + if p.kind in ( |
| 195 | + inspect.Parameter.POSITIONAL_ONLY, |
| 196 | + inspect.Parameter.POSITIONAL_OR_KEYWORD, |
| 197 | + ) and p.default is inspect.Parameter.empty: |
| 198 | + return True |
| 199 | + return False |
| 200 | + except (ValueError, TypeError): |
| 201 | + return False # Signature not inspectable; treat as zero-arg |
| 202 | + |
| 203 | + def _invoke_assertion_callable(self, assertion_callable, data=None): |
| 204 | + """Invoke an assertion callable, passing context if it accepts one.""" |
| 205 | + if self._accepts_context(assertion_callable): |
| 206 | + context = { |
| 207 | + "client_id": self.client_id, |
| 208 | + "token_endpoint": self.configuration.get( |
| 209 | + "token_endpoint", ""), |
| 210 | + } |
| 211 | + if data and data.get("fmi_path"): |
| 212 | + context["fmi_path"] = data["fmi_path"] |
| 213 | + return assertion_callable(context) |
| 214 | + return assertion_callable() |
| 215 | + |
171 | 216 | def _build_auth_request_params(self, response_type, **kwargs): |
172 | 217 | # response_type is a string defined in |
173 | 218 | # https://tools.ietf.org/html/rfc6749#section-3.1.1 |
@@ -198,11 +243,11 @@ def _obtain_token( # The verb "obtain" is influenced by OAUTH2 RFC 6749 |
198 | 243 | # See https://tools.ietf.org/html/rfc7521#section-4.2 |
199 | 244 | encoder = self.client_assertion_encoders.get( |
200 | 245 | self.default_body["client_assertion_type"], lambda a: a) |
201 | | - _data["client_assertion"] = encoder( |
202 | | - self.client_assertion() # Do lazy on-the-fly computation |
203 | | - if callable(self.client_assertion) else self.client_assertion |
204 | | - ) # The type is bytes, which is preferable. See also: |
205 | | - # https://github.com/psf/requests/issues/4503#issuecomment-455001070 |
| 246 | + if callable(self.client_assertion): |
| 247 | + raw = self._invoke_assertion_callable(self.client_assertion, data) |
| 248 | + else: |
| 249 | + raw = self.client_assertion |
| 250 | + _data["client_assertion"] = encoder(raw) |
206 | 251 |
|
207 | 252 | _data.update(self.default_body) # It may contain authen parameters |
208 | 253 | _data.update(data or {}) # So the content in data param prevails |
@@ -770,6 +815,34 @@ class initialization. |
770 | 815 | data.update(scope=scope) |
771 | 816 | return self._obtain_token("client_credentials", data=data, **kwargs) |
772 | 817 |
|
| 818 | + def obtain_token_by_user_fic( |
| 819 | + self, scope, assertion, username=None, user_object_id=None, |
| 820 | + **kwargs): |
| 821 | + """Obtain token using the ``user_fic`` grant type. |
| 822 | +
|
| 823 | + This exchanges a federated identity credential (e.g. an agent |
| 824 | + instance token) for a user-scoped access token. |
| 825 | +
|
| 826 | + :param scope: Scopes for the target resource (already decorated |
| 827 | + with OIDC scopes by the caller). |
| 828 | + :param str assertion: The federated identity credential token. |
| 829 | + :param str username: The target user's UPN (mutually exclusive |
| 830 | + with *user_object_id*). |
| 831 | + :param str user_object_id: The target user's Object ID (mutually |
| 832 | + exclusive with *username*). |
| 833 | + """ |
| 834 | + data = kwargs.pop("data", {}) |
| 835 | + data.update( |
| 836 | + scope=scope, |
| 837 | + user_federated_identity_credential=assertion, |
| 838 | + client_info="1", |
| 839 | + ) |
| 840 | + if user_object_id: |
| 841 | + data["user_id"] = str(user_object_id) |
| 842 | + elif username: |
| 843 | + data["username"] = username |
| 844 | + return self._obtain_token("user_fic", data=data, **kwargs) |
| 845 | + |
773 | 846 | def __init__(self, |
774 | 847 | server_configuration, client_id, |
775 | 848 | on_obtaining_tokens=lambda event: None, # event is defined in _obtain_token(...) |
|
0 commit comments