Skip to content

Commit e271cc1

Browse files
committed
Revamp OAuth account model management
1 parent 83ca318 commit e271cc1

File tree

6 files changed

+107
-55
lines changed

6 files changed

+107
-55
lines changed

fastapi_users/db/base.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from pydantic import UUID4
44

5-
from fastapi_users.models import UP
5+
from fastapi_users.models import OAP, UP
66
from fastapi_users.types import DependencyCallable
77

88

@@ -33,5 +33,15 @@ async def delete(self, user: UP) -> None:
3333
"""Delete a user."""
3434
raise NotImplementedError()
3535

36+
async def add_oauth_account(self, user: UP, create_dict: Dict[str, Any]) -> UP:
37+
"""Create an OAuth account and add it to the user."""
38+
raise NotImplementedError()
39+
40+
async def update_oauth_account(
41+
self, user: UP, oauth_account: OAP, update_dict: Dict[str, Any]
42+
) -> UP:
43+
"""Update an OAuth account on a user."""
44+
raise NotImplementedError()
45+
3646

3747
UserDatabaseDependency = DependencyCallable[BaseUserDatabase[UP]]

fastapi_users/manager.py

Lines changed: 33 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,12 @@ async def create(
171171

172172
async def oauth_callback(
173173
self: "BaseUserManager[models.UOAP]",
174-
oauth_account: models.OAP,
174+
oauth_name: str,
175+
access_token: str,
176+
account_id: str,
177+
account_email: str,
178+
expires_at: Optional[int] = None,
179+
refresh_token: Optional[str] = None,
175180
request: Optional[Request] = None,
176181
) -> models.UOAP:
177182
"""
@@ -185,44 +190,53 @@ async def oauth_callback(
185190
If the user does not exist, it is created and the on_after_register handler
186191
is triggered.
187192
188-
:param oauth_account: The new OAuth account to create.
193+
:param oauth_name: Name of the OAuth client.
194+
:param access_token: Valid access token for the service provider.
195+
:param account_id: ID of the user on the service provider.
196+
:param account_email: E-mail of the user on the service provider.
197+
:param expires_at: Optional timestamp at which the access token expires.
198+
:param refresh_token: Optional refresh token to get a
199+
fresh access token from the service provider.
189200
:param request: Optional FastAPI request that
190201
triggered the operation, defaults to None
191202
:return: A user.
192203
"""
204+
oauth_account_dict = {
205+
"oauth_name": oauth_name,
206+
"access_token": access_token,
207+
"account_id": account_id,
208+
"account_email": account_email,
209+
"expires_at": expires_at,
210+
"refresh_token": refresh_token,
211+
}
212+
193213
try:
194-
user = await self.get_by_oauth_account(
195-
oauth_account.oauth_name, oauth_account.account_id
196-
)
214+
user = await self.get_by_oauth_account(oauth_name, account_id)
197215
except UserNotExists:
198216
try:
199217
# Link account
200-
user = await self.get_by_email(oauth_account.account_email)
201-
oauth_accounts = [*user.oauth_accounts, oauth_account]
202-
await self.user_db.update(user, {"oauth_accounts": oauth_accounts})
218+
user = await self.get_by_email(account_email)
219+
user = await self.user_db.add_oauth_account(user, oauth_account_dict)
203220
except UserNotExists:
204221
# Create account
205222
password = self.password_helper.generate()
206223
user_dict = {
207-
"email": oauth_account.account_email,
224+
"email": account_email,
208225
"hashed_password": self.password_helper.hash(password),
209-
"oauth_accounts": [oauth_account],
210226
}
211227
user = await self.user_db.create(user_dict)
228+
user = await self.user_db.add_oauth_account(user, oauth_account_dict)
212229
await self.on_after_register(user, request)
213230
else:
214231
# Update oauth
215-
updated_oauth_accounts = []
216-
for existing_oauth_account in user.oauth_accounts: # type: ignore
232+
for existing_oauth_account in user.oauth_accounts:
217233
if (
218-
existing_oauth_account.account_id == oauth_account.account_id
219-
and existing_oauth_account.oauth_name == oauth_account.oauth_name
234+
existing_oauth_account.account_id == account_id
235+
and existing_oauth_account.oauth_name == oauth_name
220236
):
221-
oauth_account.id = existing_oauth_account.id
222-
updated_oauth_accounts.append(oauth_account)
223-
else:
224-
updated_oauth_accounts.append(existing_oauth_account)
225-
await self.user_db.update(user, {"oauth_accounts": updated_oauth_accounts})
237+
user = await self.user_db.update_oauth_account(
238+
user, existing_oauth_account, oauth_account_dict
239+
)
226240

227241
return user
228242

fastapi_users/router/oauth.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from httpx_oauth.oauth2 import BaseOAuth2, OAuth2Token
77
from pydantic import BaseModel
88

9-
from fastapi_users import models, schemas
9+
from fastapi_users import models
1010
from fastapi_users.authentication import AuthenticationBackend, Strategy
1111
from fastapi_users.jwt import SecretType, decode_jwt, generate_jwt
1212
from fastapi_users.manager import BaseUserManager, UserManagerDependency
@@ -114,17 +114,16 @@ async def callback(
114114
except jwt.DecodeError:
115115
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST)
116116

117-
new_oauth_account = schemas.BaseOAuthAccount(
118-
oauth_name=oauth_client.name,
119-
access_token=token["access_token"],
120-
expires_at=token.get("expires_at"),
121-
refresh_token=token.get("refresh_token"),
122-
account_id=account_id,
123-
account_email=account_email,
117+
user = await user_manager.oauth_callback(
118+
oauth_client.name,
119+
token["access_token"],
120+
account_id,
121+
account_email,
122+
token.get("expires_at"),
123+
token.get("refresh_token"),
124+
request,
124125
)
125126

126-
user = await user_manager.oauth_callback(new_oauth_account, request)
127-
128127
if not user.is_active:
129128
raise HTTPException(
130129
status_code=status.HTTP_400_BAD_REQUEST,

tests/conftest.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -430,6 +430,32 @@ async def update(
430430
async def delete(self, user: UserOAuthModel) -> None:
431431
pass
432432

433+
async def add_oauth_account(
434+
self, user: UserOAuthModel, create_dict: Dict[str, Any]
435+
) -> UserOAuthModel:
436+
oauth_account = OAuthAccountModel(**create_dict)
437+
user.oauth_accounts.append(oauth_account)
438+
return user
439+
440+
async def update_oauth_account( # type: ignore
441+
self,
442+
user: UserOAuthModel,
443+
oauth_account: OAuthAccountModel,
444+
update_dict: Dict[str, Any],
445+
) -> UserOAuthModel:
446+
for field, value in update_dict.items():
447+
setattr(oauth_account, field, value)
448+
updated_oauth_accounts = []
449+
for existing_oauth_account in user.oauth_accounts:
450+
if (
451+
existing_oauth_account.account_id == oauth_account.account_id
452+
and existing_oauth_account.oauth_name == oauth_account.oauth_name
453+
):
454+
updated_oauth_accounts.append(oauth_account)
455+
else:
456+
updated_oauth_accounts.append(existing_oauth_account)
457+
return user
458+
433459
return MockUserDatabase()
434460

435461

tests/test_db_base.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,20 @@
1+
import uuid
2+
13
import pytest
24

35
from fastapi_users.db import BaseUserDatabase
6+
from tests.conftest import OAuthAccountModel, UserModel
47

58

69
@pytest.mark.asyncio
710
@pytest.mark.db
8-
async def test_not_implemented_methods(user):
9-
base_user_db = BaseUserDatabase()
11+
async def test_not_implemented_methods(
12+
user: UserModel, oauth_account1: OAuthAccountModel
13+
):
14+
base_user_db = BaseUserDatabase[UserModel]()
1015

1116
with pytest.raises(NotImplementedError):
12-
await base_user_db.get("aaa")
17+
await base_user_db.get(uuid.uuid4())
1318

1419
with pytest.raises(NotImplementedError):
1520
await base_user_db.get_by_email("[email protected]")
@@ -25,3 +30,9 @@ async def test_not_implemented_methods(user):
2530

2631
with pytest.raises(NotImplementedError):
2732
await base_user_db.delete(user)
33+
34+
with pytest.raises(NotImplementedError):
35+
await base_user_db.add_oauth_account(user, {})
36+
37+
with pytest.raises(NotImplementedError):
38+
await base_user_db.update_oauth_account(user, oauth_account1, {})

tests/test_manager.py

Lines changed: 14 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -173,14 +173,18 @@ async def test_existing_user_with_oauth(
173173
user_manager_oauth: UserManagerMock[UserOAuthModel],
174174
user_oauth: UserOAuthModel,
175175
):
176-
oauth_account = copy.deepcopy(user_oauth.oauth_accounts[0])
177-
oauth_account.id = uuid.uuid4()
178-
oauth_account.access_token = "UPDATED_TOKEN"
176+
oauth_account = user_oauth.oauth_accounts[0]
179177

180-
user = await user_manager_oauth.oauth_callback(oauth_account)
178+
user = await user_manager_oauth.oauth_callback(
179+
oauth_account.oauth_name,
180+
"UPDATED_TOKEN",
181+
oauth_account.account_id,
182+
oauth_account.account_email,
183+
)
181184

182185
assert user.id == user_oauth.id
183186
assert len(user.oauth_accounts) == 2
187+
assert user.oauth_accounts[0].id == oauth_account.id
184188
assert user.oauth_accounts[0].oauth_name == "service1"
185189
assert user.oauth_accounts[0].access_token == "UPDATED_TOKEN"
186190
assert user.oauth_accounts[1].access_token == "TOKEN"
@@ -193,36 +197,24 @@ async def test_existing_user_without_oauth(
193197
user_manager_oauth: UserManagerMock[UserOAuthModel],
194198
superuser_oauth: UserOAuthModel,
195199
):
196-
oauth_account = OAuthAccountModel(
197-
oauth_name="service1",
198-
access_token="TOKEN",
199-
expires_at=1579000751,
200-
account_id="superuser_oauth1",
201-
account_email=superuser_oauth.email,
200+
user = await user_manager_oauth.oauth_callback(
201+
"service1", "TOKEN", "superuser_oauth1", superuser_oauth.email, 1579000751
202202
)
203203

204-
user = await user_manager_oauth.oauth_callback(oauth_account)
205-
206204
assert user.id == superuser_oauth.id
207205
assert len(user.oauth_accounts) == 1
208-
assert user.oauth_accounts[0].id == oauth_account.id
206+
assert user.oauth_accounts[0].id is not None
209207

210208
assert user_manager_oauth.on_after_register.called is False
211209

212210
async def test_new_user(self, user_manager_oauth: UserManagerMock[UserOAuthModel]):
213-
oauth_account = OAuthAccountModel(
214-
oauth_name="service1",
215-
access_token="TOKEN",
216-
expires_at=1579000751,
217-
account_id="new_user_oauth1",
218-
account_email="[email protected]",
211+
user = await user_manager_oauth.oauth_callback(
212+
"service1", "TOKEN", "new_user_oauth1", "[email protected]", 1579000751
219213
)
220214

221-
user = await user_manager_oauth.oauth_callback(oauth_account)
222-
223215
assert user.email == "[email protected]"
224216
assert len(user.oauth_accounts) == 1
225-
assert user.oauth_accounts[0].id == oauth_account.id
217+
assert user.oauth_accounts[0].id is not None
226218

227219
assert user_manager_oauth.on_after_register.called is True
228220

0 commit comments

Comments
 (0)