Skip to content

Commit 1944dc3

Browse files
authored
Add option to control sqlalchemy count query (#624)
1 parent 07ed043 commit 1944dc3

File tree

7 files changed

+120
-16
lines changed

7 files changed

+120
-16
lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,3 +134,6 @@ dmypy.json
134134

135135
# IDEA
136136
.idea/
137+
138+
# ruff
139+
.ruff_cache/

fastapi_pagination/ext/sqlalchemy.py

Lines changed: 28 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from contextlib import suppress
1111
from typing import TYPE_CHECKING, Any, Optional, Tuple, Union, overload
1212

13-
from sqlalchemy import func, literal_column, select
13+
from sqlalchemy import func, select
1414
from sqlalchemy.orm import Query, Session, noload
1515
from typing_extensions import TypeAlias
1616

@@ -51,9 +51,16 @@ def paginate_query(query: Select, params: AbstractParams) -> Select:
5151
return generic_query_apply_params(query, params.to_raw_params().as_limit_offset())
5252

5353

54-
def count_query(query: Select) -> Select:
55-
count_subquery = query.order_by(None).options(noload("*")).subquery()
56-
return select(func.count(literal_column("*"))).select_from(count_subquery)
54+
def count_query(query: Select, *, use_subquery: bool = True) -> Select:
55+
query = query.order_by(None).options(noload("*"))
56+
57+
if use_subquery:
58+
return select(func.count()).select_from(query.subquery())
59+
60+
return query.with_only_columns( # noqa: PIE804
61+
func.count(),
62+
**{"maintain_column_froms": True},
63+
)
5764

5865

5966
def _maybe_unique(result: Any, unique: bool) -> Any:
@@ -66,6 +73,7 @@ def exec_pagination(
6673
conn: SyncConn,
6774
transformer: Optional[ItemsTransformer] = None,
6875
additional_data: AdditionalData = None,
76+
subquery_count: bool = True,
6977
unique: bool = True,
7078
async_: bool = False,
7179
) -> AbstractPage[Any]:
@@ -82,6 +90,8 @@ def _apply_items_transformer(*args: Any, **kwargs: Any) -> Any:
8290
if is_cursor(raw_params):
8391
if paging is None:
8492
raise ImportError("sqlakeyset is not installed")
93+
if not getattr(query, "_order_by_clauses", True):
94+
raise ValueError("Cursor pagination requires ordering")
8595

8696
page = paging.select_page(
8797
conn,
@@ -100,7 +110,7 @@ def _apply_items_transformer(*args: Any, **kwargs: Any) -> Any:
100110
**(additional_data or {}),
101111
)
102112

103-
total = conn.scalar(count_query(query))
113+
total = conn.scalar(count_query(query, use_subquery=subquery_count))
104114
query = paginate_query(query, params)
105115
items = _maybe_unique(conn.execute(query), unique)
106116
items = unwrap_scalars(items)
@@ -130,6 +140,7 @@ def paginate(
130140
query: Query[Any],
131141
params: Optional[AbstractParams] = None,
132142
*,
143+
subquery_count: bool = True,
133144
transformer: Optional[SyncItemsTransformer] = None,
134145
additional_data: AdditionalData = None,
135146
) -> Any:
@@ -142,6 +153,7 @@ def paginate(
142153
query: Select,
143154
params: Optional[AbstractParams] = None,
144155
*,
156+
subquery_count: bool = True,
145157
transformer: Optional[SyncItemsTransformer] = None,
146158
additional_data: AdditionalData = None,
147159
unique: bool = True,
@@ -155,6 +167,7 @@ async def paginate(
155167
query: Select,
156168
params: Optional[AbstractParams] = None,
157169
*,
170+
subquery_count: bool = True,
158171
transformer: Optional[AsyncItemsTransformer] = None,
159172
additional_data: AdditionalData = None,
160173
unique: bool = True,
@@ -166,9 +179,9 @@ def paginate(*args: Any, **kwargs: Any) -> Any:
166179
try:
167180
assert args
168181
assert isinstance(args[0], Query)
169-
query, conn, params, transformer, additional_data, unique = _old_paginate_sign(*args, **kwargs)
182+
query, conn, params, transformer, additional_data, unique, subquery_count = _old_paginate_sign(*args, **kwargs)
170183
except (TypeError, AssertionError):
171-
query, conn, params, transformer, additional_data, unique = _new_paginate_sign(*args, **kwargs)
184+
query, conn, params, transformer, additional_data, unique, subquery_count = _new_paginate_sign(*args, **kwargs)
172185

173186
params, _ = verify_params(params, "limit-offset", "cursor")
174187

@@ -181,20 +194,22 @@ def paginate(*args: Any, **kwargs: Any) -> Any:
181194
sync_conn,
182195
transformer,
183196
additional_data,
197+
subquery_count,
184198
unique,
185199
async_=True,
186200
)
187201

188-
return exec_pagination(query, params, conn, transformer, additional_data, unique, async_=False)
202+
return exec_pagination(query, params, conn, transformer, additional_data, subquery_count, unique, async_=False)
189203

190204

191205
def _old_paginate_sign(
192206
query: Query[Any],
193207
params: Optional[AbstractParams] = None,
194208
*,
209+
subquery_count: bool = True,
195210
transformer: Optional[ItemsTransformer] = None,
196211
additional_data: AdditionalData = None,
197-
) -> Tuple[Select, SyncConn, Optional[AbstractParams], Optional[ItemsTransformer], AdditionalData, bool]:
212+
) -> Tuple[Select, SyncConn, Optional[AbstractParams], Optional[ItemsTransformer], AdditionalData, bool, bool]:
198213
if query.session is None:
199214
raise ValueError("query.session is None")
200215

@@ -205,16 +220,17 @@ def _old_paginate_sign(
205220
stacklevel=3,
206221
)
207222

208-
return query, query.session, params, transformer, additional_data, True # type: ignore
223+
return query, query.session, params, transformer, additional_data, True, subquery_count # type: ignore
209224

210225

211226
def _new_paginate_sign(
212227
conn: SyncConn,
213228
query: Select,
214229
params: Optional[AbstractParams] = None,
215230
*,
231+
subquery_count: bool = True,
216232
transformer: Optional[ItemsTransformer] = None,
217233
additional_data: AdditionalData = None,
218234
unique: bool = True,
219-
) -> Tuple[Select, SyncConn, Optional[AbstractParams], Optional[ItemsTransformer], AdditionalData, bool]:
220-
return query, conn, params, transformer, additional_data, unique
235+
) -> Tuple[Select, SyncConn, Optional[AbstractParams], Optional[ItemsTransformer], AdditionalData, bool, bool]:
236+
return query, conn, params, transformer, additional_data, unique, subquery_count

fastapi_pagination/ext/sqlmodel.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ def paginate(
2020
query: Select[TSQLModel],
2121
params: Optional[AbstractParams] = None,
2222
*,
23+
subquery_count: bool = True,
2324
transformer: Optional[SyncItemsTransformer] = None,
2425
additional_data: AdditionalData = None,
2526
unique: bool = True,
@@ -33,6 +34,7 @@ def paginate(
3334
query: SelectOfScalar[T],
3435
params: Optional[AbstractParams] = None,
3536
*,
37+
subquery_count: bool = True,
3638
transformer: Optional[SyncItemsTransformer] = None,
3739
additional_data: AdditionalData = None,
3840
unique: bool = True,
@@ -46,6 +48,7 @@ def paginate(
4648
query: Type[TSQLModel],
4749
params: Optional[AbstractParams] = None,
4850
*,
51+
subquery_count: bool = True,
4952
transformer: Optional[SyncItemsTransformer] = None,
5053
additional_data: AdditionalData = None,
5154
unique: bool = True,
@@ -59,6 +62,7 @@ async def paginate(
5962
query: Select[TSQLModel],
6063
params: Optional[AbstractParams] = None,
6164
*,
65+
subquery_count: bool = True,
6266
transformer: Optional[AsyncItemsTransformer] = None,
6367
additional_data: AdditionalData = None,
6468
unique: bool = True,
@@ -72,6 +76,7 @@ async def paginate(
7276
query: SelectOfScalar[T],
7377
params: Optional[AbstractParams] = None,
7478
*,
79+
subquery_count: bool = True,
7580
transformer: Optional[AsyncItemsTransformer] = None,
7681
additional_data: AdditionalData = None,
7782
unique: bool = True,
@@ -85,6 +90,7 @@ async def paginate(
8590
query: Type[TSQLModel],
8691
params: Optional[AbstractParams] = None,
8792
*,
93+
subquery_count: bool = True,
8894
transformer: Optional[AsyncItemsTransformer] = None,
8995
additional_data: AdditionalData = None,
9096
) -> Any:
@@ -97,6 +103,7 @@ def paginate(
97103
query: Any,
98104
params: Optional[AbstractParams] = None,
99105
*,
106+
subquery_count: bool = True,
100107
transformer: Optional[ItemsTransformer] = None,
101108
additional_data: AdditionalData = None,
102109
unique: bool = True,
@@ -108,6 +115,7 @@ def paginate(
108115
session,
109116
query,
110117
params,
118+
subquery_count=subquery_count,
111119
transformer=transformer,
112120
additional_data=additional_data,
113121
unique=unique,

tests/ext/test_sqlalchemy.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,31 @@
11
from typing import Iterator
22

33
from fastapi import Depends, FastAPI
4-
from pytest import fixture
4+
from pytest import fixture, skip
5+
from sqlalchemy import select
56
from sqlalchemy.orm.session import Session
67

78
from fastapi_pagination import LimitOffsetPage, Page, add_pagination
89
from fastapi_pagination.ext.sqlalchemy import paginate
910

1011
from ..base import BasePaginationTestCase
11-
from .utils import sqlalchemy20
12+
from .utils import is_sqlalchemy20, sqlalchemy20
13+
14+
15+
@fixture(
16+
scope="session",
17+
params=[True, False],
18+
ids=["subquery_count", "no_subquery_count"],
19+
)
20+
def use_subquery_count(request):
21+
if request.param and not is_sqlalchemy20:
22+
skip("subquery_count is not supported for SQLAlchemy<2.0")
23+
24+
return request.param
1225

1326

1427
@fixture(scope="session")
15-
def app(sa_user, sa_session, model_cls):
28+
def app(sa_user, sa_session, model_cls, use_subquery_count):
1629
app = FastAPI()
1730

1831
def get_db() -> Iterator[Session]:
@@ -25,7 +38,7 @@ def get_db() -> Iterator[Session]:
2538
@app.get("/default", response_model=Page[model_cls])
2639
@app.get("/limit-offset", response_model=LimitOffsetPage[model_cls])
2740
def route(db: Session = Depends(get_db)):
28-
return paginate(db.query(sa_user))
41+
return paginate(db, select(sa_user), subquery_count=use_subquery_count)
2942

3043
return add_pagination(app)
3144

tests/ext/test_sqlalchemy_cursor.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from typing import Iterator, List
22

3+
from _pytest.python_api import raises
34
from fastapi import Depends, FastAPI, status
45
from pydantic import parse_obj_as
56
from pytest import fixture, mark
@@ -30,6 +31,10 @@ def get_db() -> Iterator[Session]:
3031
def route(db: Session = Depends(get_db)):
3132
return paginate(db, select(sa_user).order_by(sa_user.id, sa_user.name))
3233

34+
@app.get("/no-order", response_model=CursorPage[UserOut])
35+
def route_on_order(db: Session = Depends(get_db)):
36+
return paginate(db, select(sa_user))
37+
3338
return add_pagination(app)
3439

3540

@@ -74,3 +79,13 @@ async def test_cursor(app, client, entities):
7479
cursor = data["previous_page"]
7580

7681
assert items == entities
82+
83+
84+
@sqlalchemy20
85+
@mark.asyncio
86+
async def test_no_order(app, client, entities):
87+
with raises(
88+
ValueError,
89+
match="^Cursor pagination requires ordering$",
90+
):
91+
await client.get("/no-order")
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
from typing import Iterator
2+
3+
from fastapi import Depends, FastAPI
4+
from pytest import fixture
5+
from sqlalchemy.orm.session import Session
6+
7+
from fastapi_pagination import LimitOffsetPage, Page, add_pagination
8+
from fastapi_pagination.ext.sqlalchemy import paginate
9+
10+
from ..base import BasePaginationTestCase
11+
from .utils import sqlalchemy20
12+
13+
14+
@fixture(scope="session")
15+
def app(sa_user, sa_session, model_cls):
16+
app = FastAPI()
17+
18+
def get_db() -> Iterator[Session]:
19+
db = sa_session()
20+
try:
21+
yield db
22+
finally:
23+
db.close()
24+
25+
@app.get("/default", response_model=Page[model_cls])
26+
@app.get("/limit-offset", response_model=LimitOffsetPage[model_cls])
27+
def route(db: Session = Depends(get_db)):
28+
return paginate(db.query(sa_user))
29+
30+
return add_pagination(app)
31+
32+
33+
@sqlalchemy20
34+
class TestSQLAlchemy(BasePaginationTestCase):
35+
pass

tests/ext/utils.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,17 @@
11
from pytest import mark
22

3+
try:
4+
from sqlalchemy import __version__ as sqlalchemy_version
5+
6+
is_sqlalchemy20 = tuple(map(int, sqlalchemy_version.split("."))) >= (2, 0, 0)
7+
8+
del sqlalchemy_version
9+
except (ImportError, AttributeError):
10+
is_sqlalchemy20 = False
11+
312
sqlalchemy20 = mark.sqlalchemy20
13+
14+
only_sqlalchemy20 = mark.skipif(
15+
lambda: not is_sqlalchemy20,
16+
reason="Only for SQLAlchemy 2.0",
17+
)

0 commit comments

Comments
 (0)