Skip to content

Commit 87ac51a

Browse files
committed
Revamp AccessToken DB strategy to adopt generic model approach
1 parent e271cc1 commit 87ac51a

File tree

7 files changed

+85
-67
lines changed

7 files changed

+85
-67
lines changed

fastapi_users/authentication/strategy/__init__.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,9 @@
33
StrategyDestroyNotSupportedError,
44
)
55
from fastapi_users.authentication.strategy.db import (
6-
A,
6+
AP,
77
AccessTokenDatabase,
8-
BaseAccessToken,
8+
AccessTokenProtocol,
99
DatabaseStrategy,
1010
)
1111
from fastapi_users.authentication.strategy.jwt import JWTStrategy
@@ -16,9 +16,9 @@
1616
pass
1717

1818
__all__ = [
19-
"A",
19+
"AP",
2020
"AccessTokenDatabase",
21-
"BaseAccessToken",
21+
"AccessTokenProtocol",
2222
"DatabaseStrategy",
2323
"JWTStrategy",
2424
"Strategy",
Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from fastapi_users.authentication.strategy.db.adapter import AccessTokenDatabase
2-
from fastapi_users.authentication.strategy.db.models import A, BaseAccessToken
2+
from fastapi_users.authentication.strategy.db.models import AP, AccessTokenProtocol
33
from fastapi_users.authentication.strategy.db.strategy import DatabaseStrategy
44

5-
__all__ = ["A", "AccessTokenDatabase", "BaseAccessToken", "DatabaseStrategy"]
5+
__all__ = ["AP", "AccessTokenDatabase", "AccessTokenProtocol", "DatabaseStrategy"]
Lines changed: 8 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,38 +1,32 @@
11
import sys
22
from datetime import datetime
3-
from typing import Generic, Optional, Type
3+
from typing import Any, Dict, Generic, Optional
44

55
if sys.version_info < (3, 8):
66
from typing_extensions import Protocol # pragma: no cover
77
else:
88
from typing import Protocol # pragma: no cover
99

10-
from fastapi_users.authentication.strategy.db.models import A
10+
from fastapi_users.authentication.strategy.db.models import AP
1111

1212

13-
class AccessTokenDatabase(Protocol, Generic[A]):
14-
"""
15-
Protocol for retrieving, creating and updating access tokens from a database.
16-
17-
:param access_token_model: Pydantic model of an access token.
18-
"""
19-
20-
access_token_model: Type[A]
13+
class AccessTokenDatabase(Protocol, Generic[AP]):
14+
"""Protocol for retrieving, creating and updating access tokens from a database."""
2115

2216
async def get_by_token(
2317
self, token: str, max_age: Optional[datetime] = None
24-
) -> Optional[A]:
18+
) -> Optional[AP]:
2519
"""Get a single access token by token."""
2620
... # pragma: no cover
2721

28-
async def create(self, access_token: A) -> A:
22+
async def create(self, create_dict: Dict[str, Any]) -> AP:
2923
"""Create an access token."""
3024
... # pragma: no cover
3125

32-
async def update(self, access_token: A) -> A:
26+
async def update(self, access_token: AP, update_dict: Dict[str, Any]) -> AP:
3327
"""Update an access token."""
3428
... # pragma: no cover
3529

36-
async def delete(self, access_token: A) -> None:
30+
async def delete(self, access_token: AP) -> None:
3731
"""Delete an access token."""
3832
... # pragma: no cover
Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,23 @@
1-
from datetime import datetime, timezone
1+
import sys
2+
import uuid
3+
from datetime import datetime
24
from typing import TypeVar
35

4-
from pydantic import UUID4, BaseModel, Field
6+
if sys.version_info < (3, 8):
7+
from typing_extensions import Protocol # pragma: no cover
8+
else:
9+
from typing import Protocol # pragma: no cover
510

611

7-
def now_utc():
8-
return datetime.now(timezone.utc)
9-
10-
11-
class BaseAccessToken(BaseModel):
12-
"""Base access token model."""
12+
class AccessTokenProtocol(Protocol):
13+
"""Access token protocol that ORM model should follow."""
1314

1415
token: str
15-
user_id: UUID4
16-
created_at: datetime = Field(default_factory=now_utc)
16+
user_id: uuid.UUID
17+
created_at: datetime
1718

18-
class Config:
19-
orm_mode = True
19+
def __init__(self, *args, **kwargs) -> None:
20+
... # pragma: no cover
2021

2122

22-
A = TypeVar("A", bound=BaseAccessToken)
23+
AP = TypeVar("AP", bound=AccessTokenProtocol)

fastapi_users/authentication/strategy/db/strategy.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,17 @@
11
import secrets
22
from datetime import datetime, timedelta, timezone
3-
from typing import Generic, Optional
3+
from typing import Any, Dict, Generic, Optional
44

55
from fastapi_users import models
66
from fastapi_users.authentication.strategy.base import Strategy
77
from fastapi_users.authentication.strategy.db.adapter import AccessTokenDatabase
8-
from fastapi_users.authentication.strategy.db.models import A
8+
from fastapi_users.authentication.strategy.db.models import AP
99
from fastapi_users.manager import BaseUserManager, UserNotExists
1010

1111

12-
class DatabaseStrategy(Strategy, Generic[models.UP, A]):
12+
class DatabaseStrategy(Strategy, Generic[models.UP, AP]):
1313
def __init__(
14-
self, database: AccessTokenDatabase[A], lifetime_seconds: Optional[int] = None
14+
self, database: AccessTokenDatabase[AP], lifetime_seconds: Optional[int] = None
1515
):
1616
self.database = database
1717
self.lifetime_seconds = lifetime_seconds
@@ -39,15 +39,15 @@ async def read_token(
3939
return None
4040

4141
async def write_token(self, user: models.UP) -> str:
42-
access_token = self._create_access_token(user)
43-
await self.database.create(access_token)
42+
access_token_dict = self._create_access_token_dict(user)
43+
access_token = await self.database.create(access_token_dict)
4444
return access_token.token
4545

4646
async def destroy_token(self, token: str, user: models.UP) -> None:
4747
access_token = await self.database.get_by_token(token)
4848
if access_token is not None:
4949
await self.database.delete(access_token)
5050

51-
def _create_access_token(self, user: models.UP) -> A:
51+
def _create_access_token_dict(self, user: models.UP) -> Dict[str, Any]:
5252
token = secrets.token_urlsafe()
53-
return self.database.access_token_model(token=token, user_id=user.id)
53+
return {"token": token, "user_id": user.id}

fastapi_users/models.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99

1010

1111
class UserProtocol(Protocol):
12+
"""User protocol that ORM model should follow."""
13+
1214
id: uuid.UUID
1315
email: str
1416
hashed_password: str
@@ -21,6 +23,8 @@ def __init__(self, *args, **kwargs) -> None:
2123

2224

2325
class OAuthAccountProtocol(Protocol):
26+
"""OAuth account protocol that ORM model should follow."""
27+
2428
id: uuid.UUID
2529
oauth_name: str
2630
access_token: str
@@ -38,6 +42,8 @@ def __init__(self, *args, **kwargs) -> None:
3842

3943

4044
class UserOAuthProtocol(UserProtocol, Generic[OAP]):
45+
"""User protocol including a list of OAuth accounts."""
46+
4147
oauth_accounts: List[OAP]
4248

4349

tests/test_authentication_strategy_db.py

Lines changed: 43 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,37 @@
1+
import dataclasses
12
import uuid
2-
from datetime import datetime
3-
from typing import Dict, Optional
3+
from datetime import datetime, timezone
4+
from typing import Any, Dict, Optional
45

56
import pytest
67

78
from fastapi_users.authentication.strategy import (
89
AccessTokenDatabase,
9-
BaseAccessToken,
10+
AccessTokenProtocol,
1011
DatabaseStrategy,
1112
)
13+
from tests.conftest import UserModel
1214

1315

14-
class AccessToken(BaseAccessToken):
15-
pass
16+
@dataclasses.dataclass
17+
class AccessTokenModel(AccessTokenProtocol):
18+
token: str
19+
user_id: uuid.UUID
20+
id: uuid.UUID = dataclasses.field(default_factory=uuid.uuid4)
21+
created_at: datetime = dataclasses.field(
22+
default_factory=lambda: datetime.now(timezone.utc)
23+
)
1624

1725

18-
class AccessTokenDatabaseMock(AccessTokenDatabase[AccessToken]):
19-
store: Dict[str, AccessToken]
26+
class AccessTokenDatabaseMock(AccessTokenDatabase[AccessTokenModel]):
27+
store: Dict[str, AccessTokenModel]
2028

2129
def __init__(self):
22-
self.access_token_model = AccessToken
2330
self.store = {}
2431

2532
async def get_by_token(
2633
self, token: str, max_age: Optional[datetime] = None
27-
) -> Optional[AccessToken]:
34+
) -> Optional[AccessTokenModel]:
2835
try:
2936
access_token = self.store[token]
3037
if max_age is not None and access_token.created_at < max_age:
@@ -33,15 +40,20 @@ async def get_by_token(
3340
except KeyError:
3441
return None
3542

36-
async def create(self, access_token: AccessToken) -> AccessToken:
43+
async def create(self, create_dict: Dict[str, Any]) -> AccessTokenModel:
44+
access_token = AccessTokenModel(**create_dict)
3745
self.store[access_token.token] = access_token
3846
return access_token
3947

40-
async def update(self, access_token: AccessToken) -> AccessToken:
48+
async def update(
49+
self, access_token: AccessTokenModel, update_dict: Dict[str, Any]
50+
) -> AccessTokenModel:
51+
for field, value in update_dict.items():
52+
setattr(access_token, field, value)
4153
self.store[access_token.token] = access_token
4254
return access_token
4355

44-
async def delete(self, access_token: AccessToken) -> None:
56+
async def delete(self, access_token: AccessTokenModel) -> None:
4557
try:
4658
del self.store[access_token.token]
4759
except KeyError:
@@ -62,42 +74,47 @@ def database_strategy(access_token_database: AccessTokenDatabaseMock):
6274
class TestReadToken:
6375
@pytest.mark.asyncio
6476
async def test_missing_token(
65-
self, database_strategy: DatabaseStrategy, user_manager
77+
self,
78+
database_strategy: DatabaseStrategy[UserModel, AccessTokenModel],
79+
user_manager,
6680
):
6781
authenticated_user = await database_strategy.read_token(None, user_manager)
6882
assert authenticated_user is None
6983

7084
@pytest.mark.asyncio
7185
async def test_invalid_token(
72-
self, database_strategy: DatabaseStrategy, user_manager
86+
self,
87+
database_strategy: DatabaseStrategy[UserModel, AccessTokenModel],
88+
user_manager,
7389
):
7490
authenticated_user = await database_strategy.read_token("TOKEN", user_manager)
7591
assert authenticated_user is None
7692

7793
@pytest.mark.asyncio
7894
async def test_valid_token_not_existing_user(
7995
self,
80-
database_strategy: DatabaseStrategy,
96+
database_strategy: DatabaseStrategy[UserModel, AccessTokenModel],
8197
access_token_database: AccessTokenDatabaseMock,
8298
user_manager,
8399
):
84100
await access_token_database.create(
85-
AccessToken(
86-
token="TOKEN", user_id=uuid.UUID("d35d213e-f3d8-4f08-954a-7e0d1bea286f")
87-
)
101+
{
102+
"token": "TOKEN",
103+
"user_id": uuid.UUID("d35d213e-f3d8-4f08-954a-7e0d1bea286f"),
104+
}
88105
)
89106
authenticated_user = await database_strategy.read_token("TOKEN", user_manager)
90107
assert authenticated_user is None
91108

92109
@pytest.mark.asyncio
93110
async def test_valid_token(
94111
self,
95-
database_strategy: DatabaseStrategy,
112+
database_strategy: DatabaseStrategy[UserModel, AccessTokenModel],
96113
access_token_database: AccessTokenDatabaseMock,
97114
user_manager,
98-
user,
115+
user: UserModel,
99116
):
100-
await access_token_database.create(AccessToken(token="TOKEN", user_id=user.id))
117+
await access_token_database.create({"token": "TOKEN", "user_id": user.id})
101118
authenticated_user = await database_strategy.read_token("TOKEN", user_manager)
102119
assert authenticated_user is not None
103120
assert authenticated_user.id == user.id
@@ -106,9 +123,9 @@ async def test_valid_token(
106123
@pytest.mark.authentication
107124
@pytest.mark.asyncio
108125
async def test_write_token(
109-
database_strategy: DatabaseStrategy,
126+
database_strategy: DatabaseStrategy[UserModel, AccessTokenModel],
110127
access_token_database: AccessTokenDatabaseMock,
111-
user,
128+
user: UserModel,
112129
):
113130
token = await database_strategy.write_token(user)
114131

@@ -120,11 +137,11 @@ async def test_write_token(
120137
@pytest.mark.authentication
121138
@pytest.mark.asyncio
122139
async def test_destroy_token(
123-
database_strategy: DatabaseStrategy,
140+
database_strategy: DatabaseStrategy[UserModel, AccessTokenModel],
124141
access_token_database: AccessTokenDatabaseMock,
125-
user,
142+
user: UserModel,
126143
):
127-
await access_token_database.create(AccessToken(token="TOKEN", user_id=user.id))
144+
await access_token_database.create({"token": "TOKEN", "user_id": user.id})
128145

129146
await database_strategy.destroy_token("TOKEN", user)
130147

0 commit comments

Comments
 (0)