1
1
from datetime import datetime
2
2
from typing import Any , Dict , Generic , Optional , Type
3
3
4
- from fastapi_users .authentication .strategy .db import AP , AccessTokenDatabase
4
+ from fastapi_users .authentication .strategy .db import (
5
+ AP ,
6
+ APE ,
7
+ AccessRefreshTokenDatabase ,
8
+ AccessTokenDatabase ,
9
+ )
10
+ from fastapi_users .authentication .strategy .db .adapter import BaseAccessTokenDatabase
11
+ from fastapi_users .authentication .strategy .db .models import (
12
+ AccessRefreshTokenProtocol ,
13
+ AccessTokenProtocol ,
14
+ )
5
15
from pydantic import UUID4 , ConfigDict
6
16
from pydantic .version import VERSION as PYDANTIC_VERSION
7
17
from sqlalchemy import types
10
20
11
21
from fastapi_users_db_sqlmodel .generics import TIMESTAMPAware , now_utc
12
22
23
+ from . import SQLModelProtocolMetaclass
24
+
13
25
PYDANTIC_V2 = PYDANTIC_VERSION .startswith ("2." )
14
- class SQLModelBaseAccessToken (SQLModel ):
15
- __tablename__ = "accesstoken"
16
26
17
- token : str = Field (sa_type = types .String (length = 43 ), primary_key = True )
27
+
28
+ class SQLModelBaseAccessToken (
29
+ SQLModel , AccessTokenProtocol , metaclass = SQLModelProtocolMetaclass
30
+ ):
31
+ __tablename__ = "accesstoken" # type: ignore
32
+
33
+ token : str = Field (
34
+ sa_type = types .String (length = 43 ), # type: ignore
35
+ primary_key = True ,
36
+ )
18
37
created_at : datetime = Field (
19
38
default_factory = now_utc ,
20
- sa_type = TIMESTAMPAware (timezone = True ),
39
+ sa_type = TIMESTAMPAware (timezone = True ), # type: ignore
21
40
nullable = False ,
22
41
index = True ,
23
42
)
@@ -26,11 +45,26 @@ class SQLModelBaseAccessToken(SQLModel):
26
45
if PYDANTIC_V2 : # pragma: no cover
27
46
model_config = ConfigDict (from_attributes = True ) # type: ignore
28
47
else : # pragma: no cover
48
+
29
49
class Config :
30
50
orm_mode = True
31
51
32
52
33
- class SQLModelAccessTokenDatabase (Generic [AP ], AccessTokenDatabase [AP ]):
53
+ class SQLModelBaseAccessRefreshToken (
54
+ SQLModelBaseAccessToken ,
55
+ AccessRefreshTokenProtocol ,
56
+ metaclass = SQLModelProtocolMetaclass ,
57
+ ):
58
+ __tablename__ = "accessrefreshtoken"
59
+
60
+ refresh_token : str = Field (
61
+ sa_type = types .String (length = 43 ), # type: ignore
62
+ unique = True ,
63
+ index = True ,
64
+ )
65
+
66
+
67
+ class BaseSQLModelAccessTokenDatabase (Generic [AP ], BaseAccessTokenDatabase [str , AP ]):
34
68
"""
35
69
Access token database adapter for SQLModel.
36
70
@@ -77,7 +111,47 @@ async def delete(self, access_token: AP) -> None:
77
111
self .session .commit ()
78
112
79
113
80
- class SQLModelAccessTokenDatabaseAsync (Generic [AP ], AccessTokenDatabase [AP ]):
114
+ class SQLModelAccessTokenDatabase (
115
+ Generic [AP ], BaseSQLModelAccessTokenDatabase [AP ], AccessTokenDatabase [AP ]
116
+ ):
117
+ """
118
+ Access token database adapter for SQLModel.
119
+
120
+ :param session: SQLAlchemy session.
121
+ :param access_token_model: SQLModel access token model.
122
+ """
123
+
124
+
125
+ class SQLModelAccessRefreshTokenDatabase (
126
+ Generic [APE ], BaseSQLModelAccessTokenDatabase [APE ], AccessRefreshTokenDatabase [APE ]
127
+ ):
128
+ """
129
+ Access token database adapter for SQLModel.
130
+
131
+ :param session: SQLAlchemy session.
132
+ :param access_token_model: SQLModel access refresh token model.
133
+ """
134
+
135
+ async def get_by_refresh_token (
136
+ self , refresh_token : str , max_age : Optional [datetime ] = None
137
+ ) -> Optional [APE ]:
138
+ statement = select (self .access_token_model ).where ( # type: ignore
139
+ self .access_token_model .refresh_token == refresh_token
140
+ )
141
+ if max_age is not None :
142
+ statement = statement .where (self .access_token_model .created_at >= max_age )
143
+
144
+ results = self .session .exec (statement )
145
+ access_token = results .first ()
146
+ if access_token is None :
147
+ return None
148
+
149
+ return access_token
150
+
151
+
152
+ class BaseSQLModelAccessTokenDatabaseAsync (
153
+ Generic [AP ], BaseAccessTokenDatabase [str , AP ]
154
+ ):
81
155
"""
82
156
Access token database adapter for SQLModel working purely asynchronously.
83
157
@@ -122,3 +196,31 @@ async def update(self, access_token: AP, update_dict: Dict[str, Any]) -> AP:
122
196
async def delete (self , access_token : AP ) -> None :
123
197
await self .session .delete (access_token )
124
198
await self .session .commit ()
199
+
200
+
201
+ class SQLModelAccessTokenDatabaseAsync (
202
+ BaseSQLModelAccessTokenDatabaseAsync [AP ], AccessTokenDatabase [AP ], Generic [AP ]
203
+ ):
204
+ pass
205
+
206
+
207
+ class SQLModelAccessRefreshTokenDatabaseAsync (
208
+ BaseSQLModelAccessTokenDatabaseAsync [APE ],
209
+ AccessRefreshTokenDatabase [APE ],
210
+ Generic [APE ],
211
+ ):
212
+ async def get_by_refresh_token (
213
+ self , refresh_token : str , max_age : Optional [datetime ] = None
214
+ ) -> Optional [APE ]:
215
+ statement = select (self .access_token_model ).where ( # type: ignore
216
+ self .access_token_model .refresh_token == refresh_token
217
+ )
218
+ if max_age is not None :
219
+ statement = statement .where (self .access_token_model .created_at >= max_age )
220
+
221
+ results = await self .session .execute (statement )
222
+ access_token = results .first ()
223
+ if access_token is None :
224
+ return None
225
+
226
+ return access_token [0 ]
0 commit comments