Skip to content

Commit 0e28123

Browse files
committed
Add pymongo aggregate funcs
1 parent 12a04ea commit 0e28123

File tree

4 files changed

+204
-7
lines changed

4 files changed

+204
-7
lines changed

fastapi_pagination/ext/pymongo.py

Lines changed: 116 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,21 +2,25 @@
22

33
__all__ = [
44
"apaginate",
5+
"apaginate_aggregate",
56
"paginate",
7+
"paginate_aggregate",
68
]
79

810

911
from collections.abc import Mapping, Sequence
10-
from typing import Any, Optional, TypeVar
12+
from typing import Any, Literal, Optional, TypeVar, Union
1113

1214
from pymongo.asynchronous.collection import AsyncCollection
1315
from pymongo.collection import Collection
1416

1517
from fastapi_pagination.bases import AbstractParams
1618
from fastapi_pagination.config import Config
17-
from fastapi_pagination.flow import flow_expr, run_async_flow, run_sync_flow
18-
from fastapi_pagination.flows import generic_flow
19+
from fastapi_pagination.ext.utils import get_mongo_pipeline_filter_end
20+
from fastapi_pagination.flow import flow, flow_expr, run_async_flow, run_sync_flow
21+
from fastapi_pagination.flows import create_page_flow, generic_flow
1922
from fastapi_pagination.types import AdditionalData, ItemsTransformer, SyncItemsTransformer
23+
from fastapi_pagination.utils import verify_params
2024

2125
T = TypeVar("T", bound=Mapping[str, Any])
2226

@@ -89,3 +93,112 @@ async def apaginate(
8993
config=config,
9094
)
9195
)
96+
97+
98+
@flow
99+
def _aggregate_flow(
100+
is_async: bool,
101+
collection: Union[Collection[T], AsyncCollection[T]],
102+
aggregate_pipeline: Optional[list[dict[Any, Any]]] = None,
103+
params: Optional[AbstractParams] = None,
104+
*,
105+
transformer: Optional[ItemsTransformer] = None,
106+
additional_data: Optional[AdditionalData] = None,
107+
aggregation_filter_end: Optional[Union[int, Literal["auto"]]] = None,
108+
config: Optional[Config] = None,
109+
) -> Any:
110+
params, raw_params = verify_params(params, "limit-offset")
111+
aggregate_pipeline = aggregate_pipeline or []
112+
113+
paginate_data = []
114+
if raw_params.limit is not None:
115+
paginate_data.append({"$limit": raw_params.limit + (raw_params.offset or 0)})
116+
if raw_params.offset is not None:
117+
paginate_data.append({"$skip": raw_params.offset})
118+
119+
if aggregation_filter_end is not None:
120+
if aggregation_filter_end == "auto":
121+
aggregation_filter_end = get_mongo_pipeline_filter_end(aggregate_pipeline)
122+
transform_part = aggregate_pipeline[:aggregation_filter_end]
123+
aggregate_pipeline = aggregate_pipeline[aggregation_filter_end:]
124+
paginate_data.extend(transform_part)
125+
126+
cursor = yield collection.aggregate(
127+
[
128+
*aggregate_pipeline,
129+
{
130+
"$facet": {
131+
"metadata": [{"$count": "total"}],
132+
"data": paginate_data,
133+
},
134+
},
135+
],
136+
)
137+
138+
data, *_ = yield cursor.to_list(length=None)
139+
140+
items = data["data"]
141+
try:
142+
total = data["metadata"][0]["total"]
143+
except IndexError:
144+
total = 0
145+
146+
page = yield from create_page_flow(
147+
items,
148+
params,
149+
total=total,
150+
transformer=transformer,
151+
additional_data=additional_data,
152+
config=config,
153+
async_=is_async,
154+
)
155+
156+
return page
157+
158+
159+
async def apaginate_aggregate(
160+
collection: AsyncCollection[T],
161+
aggregate_pipeline: Optional[list[dict[Any, Any]]] = None,
162+
params: Optional[AbstractParams] = None,
163+
*,
164+
transformer: Optional[ItemsTransformer] = None,
165+
additional_data: Optional[AdditionalData] = None,
166+
aggregation_filter_end: Optional[Union[int, Literal["auto"]]] = None,
167+
config: Optional[Config] = None,
168+
) -> Any:
169+
return await run_async_flow(
170+
_aggregate_flow(
171+
is_async=True,
172+
collection=collection,
173+
aggregate_pipeline=aggregate_pipeline or [],
174+
params=params,
175+
transformer=transformer,
176+
additional_data=additional_data,
177+
aggregation_filter_end=aggregation_filter_end,
178+
config=config,
179+
)
180+
)
181+
182+
183+
def paginate_aggregate(
184+
collection: Collection[T],
185+
aggregate_pipeline: Optional[list[dict[Any, Any]]] = None,
186+
params: Optional[AbstractParams] = None,
187+
*,
188+
transformer: Optional[SyncItemsTransformer] = None,
189+
additional_data: Optional[AdditionalData] = None,
190+
aggregation_filter_end: Optional[Union[int, Literal["auto"]]] = None,
191+
config: Optional[Config] = None,
192+
) -> Any:
193+
return run_sync_flow(
194+
_aggregate_flow(
195+
is_async=False,
196+
collection=collection,
197+
aggregate_pipeline=aggregate_pipeline or [],
198+
params=params,
199+
transformer=transformer,
200+
additional_data=additional_data,
201+
aggregation_filter_end=aggregation_filter_end,
202+
config=config,
203+
)
204+
)

tests/ext/test_pymongo.py

Lines changed: 42 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from pymongo import AsyncMongoClient, MongoClient
33
from pytest_asyncio import fixture as async_fixture
44

5-
from fastapi_pagination.ext.pymongo import apaginate, paginate
5+
from fastapi_pagination.ext.pymongo import apaginate, apaginate_aggregate, paginate, paginate_aggregate
66
from tests.base import BasePaginationTestSuite, async_sync_testsuite
77
from tests.utils import maybe_async
88

@@ -14,9 +14,8 @@ def database_url(mongodb_url) -> str:
1414
return mongodb_url
1515

1616

17-
@mongodb_test
1817
@async_sync_testsuite
19-
class TestPymongo(BasePaginationTestSuite):
18+
class _BasePymongoSuite:
2019
@async_fixture(scope="session")
2120
async def db_client(self, database_url, is_async_db):
2221
if is_async_db:
@@ -26,6 +25,9 @@ async def db_client(self, database_url, is_async_db):
2625
with MongoClient(database_url) as client:
2726
yield client
2827

28+
29+
@mongodb_test
30+
class TestPymongo(_BasePymongoSuite, BasePaginationTestSuite):
2931
@pytest.fixture(scope="session")
3032
def paginate_func(self, is_async_db):
3133
return apaginate if is_async_db else paginate
@@ -39,3 +41,40 @@ async def route():
3941
return await maybe_async(paginate_func(db_client.test.users))
4042

4143
return builder.build()
44+
45+
46+
@mongodb_test
47+
class TestPymongoAggregate(_BasePymongoSuite, BasePaginationTestSuite):
48+
@pytest.fixture(scope="session")
49+
def paginate_func(self, is_async_db):
50+
return apaginate_aggregate if is_async_db else paginate_aggregate
51+
52+
@async_fixture(scope="session")
53+
async def entities(self, db_client):
54+
cursor = await maybe_async(
55+
db_client.test_agg.users.aggregate(
56+
[
57+
{"$group": {"_id": "$name", "name": {"$first": "$name"}}},
58+
{"$sort": {"name": 1}},
59+
],
60+
)
61+
)
62+
return await maybe_async(cursor.to_list(length=None))
63+
64+
@pytest.fixture(scope="session")
65+
def app(self, builder, db_client, paginate_func):
66+
builder = builder.new()
67+
68+
@builder.both.default
69+
async def route():
70+
return await maybe_async(
71+
paginate_func(
72+
db_client.test_agg.users,
73+
[
74+
{"$group": {"_id": "$name", "name": {"$first": "$name"}}},
75+
{"$sort": {"name": 1}},
76+
],
77+
)
78+
)
79+
80+
return builder.build()
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
import pytest
2+
from motor.motor_asyncio import AsyncIOMotorClient
3+
from pytest_asyncio import fixture as async_fixture
4+
5+
from fastapi_pagination.ext.motor import apaginate_aggregate
6+
from tests.base import BasePaginationTestSuite
7+
8+
from .utils import mongodb_test
9+
10+
11+
@pytest.fixture(scope="session")
12+
def db_client(database_url):
13+
client = AsyncIOMotorClient(database_url)
14+
yield client
15+
client.close()
16+
17+
18+
@mongodb_test
19+
class TestMotorAggregate(BasePaginationTestSuite):
20+
@async_fixture(scope="session")
21+
async def entities(self, db_client, raw_data):
22+
await db_client.test_agg.users.delete_many({})
23+
await db_client.test_agg.users.insert_many(raw_data)
24+
25+
cursor = db_client.test_agg.users.aggregate(
26+
[
27+
{"$group": {"_id": "$name", "name": {"$first": "$name"}}},
28+
{"$sort": {"name": 1}},
29+
],
30+
)
31+
return await cursor.to_list(length=None)
32+
33+
@pytest.fixture(scope="session")
34+
def app(self, builder, db_client):
35+
@builder.both.default
36+
async def route():
37+
return await apaginate_aggregate(
38+
db_client.test_agg.users,
39+
[
40+
{"$group": {"_id": "$name", "name": {"$first": "$name"}}},
41+
{"$sort": {"name": 1}},
42+
],
43+
)
44+
45+
return builder.build()

uv.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)