1
+ import dataclasses
1
2
import uuid
2
- from datetime import datetime
3
- from typing import Dict , Optional
3
+ from datetime import datetime , timezone
4
+ from typing import Any , Dict , Optional
4
5
5
6
import pytest
6
7
7
8
from fastapi_users .authentication .strategy import (
8
9
AccessTokenDatabase ,
9
- BaseAccessToken ,
10
+ AccessTokenProtocol ,
10
11
DatabaseStrategy ,
11
12
)
13
+ from tests .conftest import UserModel
12
14
13
15
14
- class AccessToken (BaseAccessToken ):
15
- pass
16
+ @dataclasses .dataclass
17
+ class AccessTokenModel (AccessTokenProtocol ):
18
+ token : str
19
+ user_id : uuid .UUID
20
+ id : uuid .UUID = dataclasses .field (default_factory = uuid .uuid4 )
21
+ created_at : datetime = dataclasses .field (
22
+ default_factory = lambda : datetime .now (timezone .utc )
23
+ )
16
24
17
25
18
- class AccessTokenDatabaseMock (AccessTokenDatabase [AccessToken ]):
19
- store : Dict [str , AccessToken ]
26
+ class AccessTokenDatabaseMock (AccessTokenDatabase [AccessTokenModel ]):
27
+ store : Dict [str , AccessTokenModel ]
20
28
21
29
def __init__ (self ):
22
- self .access_token_model = AccessToken
23
30
self .store = {}
24
31
25
32
async def get_by_token (
26
33
self , token : str , max_age : Optional [datetime ] = None
27
- ) -> Optional [AccessToken ]:
34
+ ) -> Optional [AccessTokenModel ]:
28
35
try :
29
36
access_token = self .store [token ]
30
37
if max_age is not None and access_token .created_at < max_age :
@@ -33,15 +40,20 @@ async def get_by_token(
33
40
except KeyError :
34
41
return None
35
42
36
- async def create (self , access_token : AccessToken ) -> AccessToken :
43
+ async def create (self , create_dict : Dict [str , Any ]) -> AccessTokenModel :
44
+ access_token = AccessTokenModel (** create_dict )
37
45
self .store [access_token .token ] = access_token
38
46
return access_token
39
47
40
- async def update (self , access_token : AccessToken ) -> AccessToken :
48
+ async def update (
49
+ self , access_token : AccessTokenModel , update_dict : Dict [str , Any ]
50
+ ) -> AccessTokenModel :
51
+ for field , value in update_dict .items ():
52
+ setattr (access_token , field , value )
41
53
self .store [access_token .token ] = access_token
42
54
return access_token
43
55
44
- async def delete (self , access_token : AccessToken ) -> None :
56
+ async def delete (self , access_token : AccessTokenModel ) -> None :
45
57
try :
46
58
del self .store [access_token .token ]
47
59
except KeyError :
@@ -62,42 +74,47 @@ def database_strategy(access_token_database: AccessTokenDatabaseMock):
62
74
class TestReadToken :
63
75
@pytest .mark .asyncio
64
76
async def test_missing_token (
65
- self , database_strategy : DatabaseStrategy , user_manager
77
+ self ,
78
+ database_strategy : DatabaseStrategy [UserModel , AccessTokenModel ],
79
+ user_manager ,
66
80
):
67
81
authenticated_user = await database_strategy .read_token (None , user_manager )
68
82
assert authenticated_user is None
69
83
70
84
@pytest .mark .asyncio
71
85
async def test_invalid_token (
72
- self , database_strategy : DatabaseStrategy , user_manager
86
+ self ,
87
+ database_strategy : DatabaseStrategy [UserModel , AccessTokenModel ],
88
+ user_manager ,
73
89
):
74
90
authenticated_user = await database_strategy .read_token ("TOKEN" , user_manager )
75
91
assert authenticated_user is None
76
92
77
93
@pytest .mark .asyncio
78
94
async def test_valid_token_not_existing_user (
79
95
self ,
80
- database_strategy : DatabaseStrategy ,
96
+ database_strategy : DatabaseStrategy [ UserModel , AccessTokenModel ] ,
81
97
access_token_database : AccessTokenDatabaseMock ,
82
98
user_manager ,
83
99
):
84
100
await access_token_database .create (
85
- AccessToken (
86
- token = "TOKEN" , user_id = uuid .UUID ("d35d213e-f3d8-4f08-954a-7e0d1bea286f" )
87
- )
101
+ {
102
+ "token" : "TOKEN" ,
103
+ "user_id" : uuid .UUID ("d35d213e-f3d8-4f08-954a-7e0d1bea286f" ),
104
+ }
88
105
)
89
106
authenticated_user = await database_strategy .read_token ("TOKEN" , user_manager )
90
107
assert authenticated_user is None
91
108
92
109
@pytest .mark .asyncio
93
110
async def test_valid_token (
94
111
self ,
95
- database_strategy : DatabaseStrategy ,
112
+ database_strategy : DatabaseStrategy [ UserModel , AccessTokenModel ] ,
96
113
access_token_database : AccessTokenDatabaseMock ,
97
114
user_manager ,
98
- user ,
115
+ user : UserModel ,
99
116
):
100
- await access_token_database .create (AccessToken ( token = " TOKEN" , user_id = user .id ) )
117
+ await access_token_database .create ({ " token" : " TOKEN" , " user_id" : user .id } )
101
118
authenticated_user = await database_strategy .read_token ("TOKEN" , user_manager )
102
119
assert authenticated_user is not None
103
120
assert authenticated_user .id == user .id
@@ -106,9 +123,9 @@ async def test_valid_token(
106
123
@pytest .mark .authentication
107
124
@pytest .mark .asyncio
108
125
async def test_write_token (
109
- database_strategy : DatabaseStrategy ,
126
+ database_strategy : DatabaseStrategy [ UserModel , AccessTokenModel ] ,
110
127
access_token_database : AccessTokenDatabaseMock ,
111
- user ,
128
+ user : UserModel ,
112
129
):
113
130
token = await database_strategy .write_token (user )
114
131
@@ -120,11 +137,11 @@ async def test_write_token(
120
137
@pytest .mark .authentication
121
138
@pytest .mark .asyncio
122
139
async def test_destroy_token (
123
- database_strategy : DatabaseStrategy ,
140
+ database_strategy : DatabaseStrategy [ UserModel , AccessTokenModel ] ,
124
141
access_token_database : AccessTokenDatabaseMock ,
125
- user ,
142
+ user : UserModel ,
126
143
):
127
- await access_token_database .create (AccessToken ( token = " TOKEN" , user_id = user .id ) )
144
+ await access_token_database .create ({ " token" : " TOKEN" , " user_id" : user .id } )
128
145
129
146
await database_strategy .destroy_token ("TOKEN" , user )
130
147
0 commit comments