Skip to content

Commit 38480f2

Browse files
jnjpngJin Peng
andauthored
fix: fix null precedence and pagination for list agents (#2927)
Co-authored-by: Jin Peng <[email protected]>
1 parent e78e7d5 commit 38480f2

File tree

1 file changed

+57
-33
lines changed

1 file changed

+57
-33
lines changed

letta/services/helpers/agent_manager_helper.py

Lines changed: 57 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from typing import List, Literal, Optional
33

44
import numpy as np
5-
from sqlalchemy import Select, and_, asc, desc, func, literal, or_, select, union_all
5+
from sqlalchemy import Select, and_, asc, desc, func, literal, nulls_last, or_, select, union_all
66
from sqlalchemy.sql.expression import exists
77

88
from letta import system
@@ -430,23 +430,47 @@ def check_supports_structured_output(model: str, tool_rules: List[ToolRule]) ->
430430
return True
431431

432432

433-
def _cursor_filter(created_at_col, id_col, ref_created_at, ref_id, forward: bool):
433+
def _cursor_filter(sort_col, id_col, ref_sort_col, ref_id, forward: bool, nulls_last: bool = False):
434434
"""
435435
Returns a SQLAlchemy filter expression for cursor-based pagination.
436436
437437
If `forward` is True, returns records after the reference.
438438
If `forward` is False, returns records before the reference.
439+
440+
Handles NULL values in the sort column properly when nulls_last is True.
439441
"""
440-
if forward:
441-
return or_(
442-
created_at_col > ref_created_at,
443-
and_(created_at_col == ref_created_at, id_col > ref_id),
444-
)
442+
if not nulls_last:
443+
# Simple case: no special NULL handling needed
444+
if forward:
445+
return or_(
446+
sort_col > ref_sort_col,
447+
and_(sort_col == ref_sort_col, id_col > ref_id),
448+
)
449+
else:
450+
return or_(
451+
sort_col < ref_sort_col,
452+
and_(sort_col == ref_sort_col, id_col < ref_id),
453+
)
454+
455+
# Handle nulls_last case
456+
# TODO: add tests to check if this works for ascending order but nulls are stil last?
457+
if ref_sort_col is None:
458+
# Reference cursor is at a NULL value
459+
if forward:
460+
# Moving forward (e.g. previous) from NULL: either other NULLs with greater IDs or non-NULLs
461+
return or_(and_(sort_col.is_(None), id_col > ref_id), sort_col.isnot(None))
462+
else:
463+
# Moving backward (e.g. next) from NULL: NULLs with smaller IDs
464+
return and_(sort_col.is_(None), id_col < ref_id)
445465
else:
446-
return or_(
447-
created_at_col < ref_created_at,
448-
and_(created_at_col == ref_created_at, id_col < ref_id),
449-
)
466+
# Reference cursor is at a non-NULL value
467+
if forward:
468+
# Moving forward (e.g. previous) from non-NULL: only greater non-NULL values
469+
# (NULLs are at the end, so we don't include them when moving forward from non-NULL)
470+
return and_(sort_col.isnot(None), or_(sort_col > ref_sort_col, and_(sort_col == ref_sort_col, id_col > ref_id)))
471+
else:
472+
# Moving backward (e.g. next) from non-NULL: smaller non-NULL values or NULLs
473+
return or_(sort_col.is_(None), or_(sort_col < ref_sort_col, and_(sort_col == ref_sort_col, id_col < ref_id)))
450474

451475

452476
def _apply_pagination(
@@ -455,30 +479,30 @@ def _apply_pagination(
455479
# Determine the sort column
456480
if sort_by == "last_run_completion":
457481
sort_column = AgentModel.last_run_completion
482+
sort_nulls_last = True # TODO: handle this as a query param eventually
458483
else:
459484
sort_column = AgentModel.created_at
485+
sort_nulls_last = False
460486

461487
if after:
462-
if sort_by == "last_run_completion":
463-
result = session.execute(select(AgentModel.last_run_completion, AgentModel.id).where(AgentModel.id == after)).first()
464-
else:
465-
result = session.execute(select(AgentModel.created_at, AgentModel.id).where(AgentModel.id == after)).first()
488+
result = session.execute(select(sort_column, AgentModel.id).where(AgentModel.id == after)).first()
466489
if result:
467490
after_sort_value, after_id = result
468-
query = query.where(_cursor_filter(sort_column, AgentModel.id, after_sort_value, after_id, forward=ascending))
491+
query = query.where(
492+
_cursor_filter(sort_column, AgentModel.id, after_sort_value, after_id, forward=ascending, nulls_last=sort_nulls_last)
493+
)
469494

470495
if before:
471-
if sort_by == "last_run_completion":
472-
result = session.execute(select(AgentModel.last_run_completion, AgentModel.id).where(AgentModel.id == before)).first()
473-
else:
474-
result = session.execute(select(AgentModel.created_at, AgentModel.id).where(AgentModel.id == before)).first()
496+
result = session.execute(select(sort_column, AgentModel.id).where(AgentModel.id == before)).first()
475497
if result:
476498
before_sort_value, before_id = result
477-
query = query.where(_cursor_filter(sort_column, AgentModel.id, before_sort_value, before_id, forward=not ascending))
499+
query = query.where(
500+
_cursor_filter(sort_column, AgentModel.id, before_sort_value, before_id, forward=not ascending, nulls_last=sort_nulls_last)
501+
)
478502

479503
# Apply ordering
480504
order_fn = asc if ascending else desc
481-
query = query.order_by(order_fn(sort_column), order_fn(AgentModel.id))
505+
query = query.order_by(nulls_last(order_fn(sort_column)) if sort_nulls_last else order_fn(sort_column), order_fn(AgentModel.id))
482506
return query
483507

484508

@@ -488,30 +512,30 @@ async def _apply_pagination_async(
488512
# Determine the sort column
489513
if sort_by == "last_run_completion":
490514
sort_column = AgentModel.last_run_completion
515+
sort_nulls_last = True # TODO: handle this as a query param eventually
491516
else:
492517
sort_column = AgentModel.created_at
518+
sort_nulls_last = False
493519

494520
if after:
495-
if sort_by == "last_run_completion":
496-
result = (await session.execute(select(AgentModel.last_run_completion, AgentModel.id).where(AgentModel.id == after))).first()
497-
else:
498-
result = (await session.execute(select(AgentModel.created_at, AgentModel.id).where(AgentModel.id == after))).first()
521+
result = (await session.execute(select(sort_column, AgentModel.id).where(AgentModel.id == after))).first()
499522
if result:
500523
after_sort_value, after_id = result
501-
query = query.where(_cursor_filter(sort_column, AgentModel.id, after_sort_value, after_id, forward=ascending))
524+
query = query.where(
525+
_cursor_filter(sort_column, AgentModel.id, after_sort_value, after_id, forward=ascending, nulls_last=sort_nulls_last)
526+
)
502527

503528
if before:
504-
if sort_by == "last_run_completion":
505-
result = (await session.execute(select(AgentModel.last_run_completion, AgentModel.id).where(AgentModel.id == before))).first()
506-
else:
507-
result = (await session.execute(select(AgentModel.created_at, AgentModel.id).where(AgentModel.id == before))).first()
529+
result = (await session.execute(select(sort_column, AgentModel.id).where(AgentModel.id == before))).first()
508530
if result:
509531
before_sort_value, before_id = result
510-
query = query.where(_cursor_filter(sort_column, AgentModel.id, before_sort_value, before_id, forward=not ascending))
532+
query = query.where(
533+
_cursor_filter(sort_column, AgentModel.id, before_sort_value, before_id, forward=not ascending, nulls_last=sort_nulls_last)
534+
)
511535

512536
# Apply ordering
513537
order_fn = asc if ascending else desc
514-
query = query.order_by(order_fn(sort_column), order_fn(AgentModel.id))
538+
query = query.order_by(nulls_last(order_fn(sort_column)) if sort_nulls_last else order_fn(sort_column), order_fn(AgentModel.id))
515539
return query
516540

517541

0 commit comments

Comments
 (0)