Skip to content

Commit 7598f4f

Browse files
committed
feat: add refresh token classes, fix base models not inherited from protocols
1 parent fdb4ea4 commit 7598f4f

File tree

2 files changed

+131
-16
lines changed

2 files changed

+131
-16
lines changed

fastapi_users_db_sqlmodel/__init__.py

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,32 @@
11
"""FastAPI Users database adapter for SQLModel."""
22

33
import uuid
4-
from typing import TYPE_CHECKING, Any, Dict, Generic, Optional, Type
4+
from typing import TYPE_CHECKING, Any, Dict, Generic, Optional, Type, _ProtocolMeta
55

66
from fastapi_users.db.base import BaseUserDatabase
7-
from fastapi_users.models import ID, OAP, UP
7+
from fastapi_users.models import (
8+
ID,
9+
UP,
10+
OAuthAccountProtocol,
11+
UserProtocol,
12+
)
813
from pydantic import UUID4, ConfigDict, EmailStr
914
from pydantic.version import VERSION as PYDANTIC_VERSION
1015
from sqlalchemy.ext.asyncio import AsyncSession
1116
from sqlalchemy.orm import selectinload
1217
from sqlmodel import AutoString, Field, Session, SQLModel, func, select
18+
from sqlmodel.main import SQLModelMetaclass
1319

1420
__version__ = "0.3.0"
1521
PYDANTIC_V2 = PYDANTIC_VERSION.startswith("2.")
1622

1723

18-
class SQLModelBaseUserDB(SQLModel):
19-
__tablename__ = "user"
24+
class SQLModelProtocolMetaclass(SQLModelMetaclass, _ProtocolMeta):
25+
pass
26+
27+
28+
class SQLModelBaseUserDB(SQLModel, UserProtocol, metaclass=SQLModelProtocolMetaclass):
29+
__tablename__ = "user" # type: ignore
2030

2131
id: UUID4 = Field(default_factory=uuid.uuid4, primary_key=True, nullable=False)
2232
if TYPE_CHECKING: # pragma: no cover
@@ -41,8 +51,10 @@ class Config:
4151
orm_mode = True
4252

4353

44-
class SQLModelBaseOAuthAccount(SQLModel):
45-
__tablename__ = "oauthaccount"
54+
class SQLModelBaseOAuthAccount(
55+
SQLModel, OAuthAccountProtocol, metaclass=SQLModelProtocolMetaclass
56+
):
57+
__tablename__ = "oauthaccount" # type: ignore
4658

4759
id: UUID4 = Field(default_factory=uuid.uuid4, primary_key=True)
4860
user_id: UUID4 = Field(foreign_key="user.id", nullable=False)
@@ -54,10 +66,11 @@ class SQLModelBaseOAuthAccount(SQLModel):
5466
account_email: str = Field(nullable=False)
5567

5668
if PYDANTIC_V2:
69+
# pragma: no cover
5770
model_config = ConfigDict(from_attributes=True) # type: ignore
5871
else:
5972

60-
class Config:
73+
class Config: # pragma: no cover
6174
orm_mode = True
6275

6376

@@ -143,7 +156,7 @@ async def add_oauth_account(self, user: UP, create_dict: Dict[str, Any]) -> UP:
143156
return user
144157

145158
async def update_oauth_account(
146-
self, user: UP, oauth_account: OAP, update_dict: Dict[str, Any]
159+
self, user: UP, oauth_account: OAuthAccountProtocol, update_dict: Dict[str, Any]
147160
) -> UP:
148161
if self.oauth_account_model is None:
149162
raise NotImplementedError()
@@ -243,7 +256,7 @@ async def add_oauth_account(self, user: UP, create_dict: Dict[str, Any]) -> UP:
243256
return user
244257

245258
async def update_oauth_account(
246-
self, user: UP, oauth_account: OAP, update_dict: Dict[str, Any]
259+
self, user: UP, oauth_account: OAuthAccountProtocol, update_dict: Dict[str, Any]
247260
) -> UP:
248261
if self.oauth_account_model is None:
249262
raise NotImplementedError()

fastapi_users_db_sqlmodel/access_token.py

Lines changed: 109 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,17 @@
11
from datetime import datetime
22
from typing import Any, Dict, Generic, Optional, Type
33

4-
from fastapi_users.authentication.strategy.db import AP, AccessTokenDatabase
4+
from fastapi_users.authentication.strategy.db import (
5+
AP,
6+
APE,
7+
AccessRefreshTokenDatabase,
8+
AccessTokenDatabase,
9+
)
10+
from fastapi_users.authentication.strategy.db.adapter import BaseAccessTokenDatabase
11+
from fastapi_users.authentication.strategy.db.models import (
12+
AccessRefreshTokenProtocol,
13+
AccessTokenProtocol,
14+
)
515
from pydantic import UUID4, ConfigDict
616
from pydantic.version import VERSION as PYDANTIC_VERSION
717
from sqlalchemy import types
@@ -10,14 +20,23 @@
1020

1121
from fastapi_users_db_sqlmodel.generics import TIMESTAMPAware, now_utc
1222

23+
from . import SQLModelProtocolMetaclass
24+
1325
PYDANTIC_V2 = PYDANTIC_VERSION.startswith("2.")
14-
class SQLModelBaseAccessToken(SQLModel):
15-
__tablename__ = "accesstoken"
1626

17-
token: str = Field(sa_type=types.String(length=43), primary_key=True)
27+
28+
class SQLModelBaseAccessToken(
29+
SQLModel, AccessTokenProtocol, metaclass=SQLModelProtocolMetaclass
30+
):
31+
__tablename__ = "accesstoken" # type: ignore
32+
33+
token: str = Field(
34+
sa_type=types.String(length=43), # type: ignore
35+
primary_key=True,
36+
)
1837
created_at: datetime = Field(
1938
default_factory=now_utc,
20-
sa_type=TIMESTAMPAware(timezone=True),
39+
sa_type=TIMESTAMPAware(timezone=True), # type: ignore
2140
nullable=False,
2241
index=True,
2342
)
@@ -26,11 +45,26 @@ class SQLModelBaseAccessToken(SQLModel):
2645
if PYDANTIC_V2: # pragma: no cover
2746
model_config = ConfigDict(from_attributes=True) # type: ignore
2847
else: # pragma: no cover
48+
2949
class Config:
3050
orm_mode = True
3151

3252

33-
class SQLModelAccessTokenDatabase(Generic[AP], AccessTokenDatabase[AP]):
53+
class SQLModelBaseAccessRefreshToken(
54+
SQLModelBaseAccessToken,
55+
AccessRefreshTokenProtocol,
56+
metaclass=SQLModelProtocolMetaclass,
57+
):
58+
__tablename__ = "accessrefreshtoken"
59+
60+
refresh_token: str = Field(
61+
sa_type=types.String(length=43), # type: ignore
62+
unique=True,
63+
index=True,
64+
)
65+
66+
67+
class BaseSQLModelAccessTokenDatabase(Generic[AP], BaseAccessTokenDatabase[str, AP]):
3468
"""
3569
Access token database adapter for SQLModel.
3670
@@ -77,7 +111,47 @@ async def delete(self, access_token: AP) -> None:
77111
self.session.commit()
78112

79113

80-
class SQLModelAccessTokenDatabaseAsync(Generic[AP], AccessTokenDatabase[AP]):
114+
class SQLModelAccessTokenDatabase(
115+
Generic[AP], BaseSQLModelAccessTokenDatabase[AP], AccessTokenDatabase[AP]
116+
):
117+
"""
118+
Access token database adapter for SQLModel.
119+
120+
:param session: SQLAlchemy session.
121+
:param access_token_model: SQLModel access token model.
122+
"""
123+
124+
125+
class SQLModelAccessRefreshTokenDatabase(
126+
Generic[APE], BaseSQLModelAccessTokenDatabase[APE], AccessRefreshTokenDatabase[APE]
127+
):
128+
"""
129+
Access token database adapter for SQLModel.
130+
131+
:param session: SQLAlchemy session.
132+
:param access_token_model: SQLModel access refresh token model.
133+
"""
134+
135+
async def get_by_refresh_token(
136+
self, refresh_token: str, max_age: Optional[datetime] = None
137+
) -> Optional[APE]:
138+
statement = select(self.access_token_model).where( # type: ignore
139+
self.access_token_model.refresh_token == refresh_token
140+
)
141+
if max_age is not None:
142+
statement = statement.where(self.access_token_model.created_at >= max_age)
143+
144+
results = self.session.exec(statement)
145+
access_token = results.first()
146+
if access_token is None:
147+
return None
148+
149+
return access_token
150+
151+
152+
class BaseSQLModelAccessTokenDatabaseAsync(
153+
Generic[AP], BaseAccessTokenDatabase[str, AP]
154+
):
81155
"""
82156
Access token database adapter for SQLModel working purely asynchronously.
83157
@@ -122,3 +196,31 @@ async def update(self, access_token: AP, update_dict: Dict[str, Any]) -> AP:
122196
async def delete(self, access_token: AP) -> None:
123197
await self.session.delete(access_token)
124198
await self.session.commit()
199+
200+
201+
class SQLModelAccessTokenDatabaseAsync(
202+
BaseSQLModelAccessTokenDatabaseAsync[AP], AccessTokenDatabase[AP], Generic[AP]
203+
):
204+
pass
205+
206+
207+
class SQLModelAccessRefreshTokenDatabaseAsync(
208+
BaseSQLModelAccessTokenDatabaseAsync[APE],
209+
AccessRefreshTokenDatabase[APE],
210+
Generic[APE],
211+
):
212+
async def get_by_refresh_token(
213+
self, refresh_token: str, max_age: Optional[datetime] = None
214+
) -> Optional[APE]:
215+
statement = select(self.access_token_model).where( # type: ignore
216+
self.access_token_model.refresh_token == refresh_token
217+
)
218+
if max_age is not None:
219+
statement = statement.where(self.access_token_model.created_at >= max_age)
220+
221+
results = await self.session.execute(statement)
222+
access_token = results.first()
223+
if access_token is None:
224+
return None
225+
226+
return access_token[0]

0 commit comments

Comments
 (0)