22from typing import List , Literal , Optional
33
44import numpy as np
5- from sqlalchemy import Select , and_ , asc , desc , func , literal , or_ , select , union_all
5+ from sqlalchemy import Select , and_ , asc , desc , func , literal , nulls_last , or_ , select , union_all
66from sqlalchemy .sql .expression import exists
77
88from letta import system
@@ -430,23 +430,47 @@ def check_supports_structured_output(model: str, tool_rules: List[ToolRule]) ->
430430 return True
431431
432432
433- def _cursor_filter (created_at_col , id_col , ref_created_at , ref_id , forward : bool ):
433+ def _cursor_filter (sort_col , id_col , ref_sort_col , ref_id , forward : bool , nulls_last : bool = False ):
434434 """
435435 Returns a SQLAlchemy filter expression for cursor-based pagination.
436436
437437 If `forward` is True, returns records after the reference.
438438 If `forward` is False, returns records before the reference.
439+
440+ Handles NULL values in the sort column properly when nulls_last is True.
439441 """
440- if forward :
441- return or_ (
442- created_at_col > ref_created_at ,
443- and_ (created_at_col == ref_created_at , id_col > ref_id ),
444- )
442+ if not nulls_last :
443+ # Simple case: no special NULL handling needed
444+ if forward :
445+ return or_ (
446+ sort_col > ref_sort_col ,
447+ and_ (sort_col == ref_sort_col , id_col > ref_id ),
448+ )
449+ else :
450+ return or_ (
451+ sort_col < ref_sort_col ,
452+ and_ (sort_col == ref_sort_col , id_col < ref_id ),
453+ )
454+
455+ # Handle nulls_last case
456+ # TODO: add tests to check if this works for ascending order but nulls are stil last?
457+ if ref_sort_col is None :
458+ # Reference cursor is at a NULL value
459+ if forward :
460+ # Moving forward (e.g. previous) from NULL: either other NULLs with greater IDs or non-NULLs
461+ return or_ (and_ (sort_col .is_ (None ), id_col > ref_id ), sort_col .isnot (None ))
462+ else :
463+ # Moving backward (e.g. next) from NULL: NULLs with smaller IDs
464+ return and_ (sort_col .is_ (None ), id_col < ref_id )
445465 else :
446- return or_ (
447- created_at_col < ref_created_at ,
448- and_ (created_at_col == ref_created_at , id_col < ref_id ),
449- )
466+ # Reference cursor is at a non-NULL value
467+ if forward :
468+ # Moving forward (e.g. previous) from non-NULL: only greater non-NULL values
469+ # (NULLs are at the end, so we don't include them when moving forward from non-NULL)
470+ return and_ (sort_col .isnot (None ), or_ (sort_col > ref_sort_col , and_ (sort_col == ref_sort_col , id_col > ref_id )))
471+ else :
472+ # Moving backward (e.g. next) from non-NULL: smaller non-NULL values or NULLs
473+ return or_ (sort_col .is_ (None ), or_ (sort_col < ref_sort_col , and_ (sort_col == ref_sort_col , id_col < ref_id )))
450474
451475
452476def _apply_pagination (
@@ -455,30 +479,30 @@ def _apply_pagination(
455479 # Determine the sort column
456480 if sort_by == "last_run_completion" :
457481 sort_column = AgentModel .last_run_completion
482+ sort_nulls_last = True # TODO: handle this as a query param eventually
458483 else :
459484 sort_column = AgentModel .created_at
485+ sort_nulls_last = False
460486
461487 if after :
462- if sort_by == "last_run_completion" :
463- result = session .execute (select (AgentModel .last_run_completion , AgentModel .id ).where (AgentModel .id == after )).first ()
464- else :
465- result = session .execute (select (AgentModel .created_at , AgentModel .id ).where (AgentModel .id == after )).first ()
488+ result = session .execute (select (sort_column , AgentModel .id ).where (AgentModel .id == after )).first ()
466489 if result :
467490 after_sort_value , after_id = result
468- query = query .where (_cursor_filter (sort_column , AgentModel .id , after_sort_value , after_id , forward = ascending ))
491+ query = query .where (
492+ _cursor_filter (sort_column , AgentModel .id , after_sort_value , after_id , forward = ascending , nulls_last = sort_nulls_last )
493+ )
469494
470495 if before :
471- if sort_by == "last_run_completion" :
472- result = session .execute (select (AgentModel .last_run_completion , AgentModel .id ).where (AgentModel .id == before )).first ()
473- else :
474- result = session .execute (select (AgentModel .created_at , AgentModel .id ).where (AgentModel .id == before )).first ()
496+ result = session .execute (select (sort_column , AgentModel .id ).where (AgentModel .id == before )).first ()
475497 if result :
476498 before_sort_value , before_id = result
477- query = query .where (_cursor_filter (sort_column , AgentModel .id , before_sort_value , before_id , forward = not ascending ))
499+ query = query .where (
500+ _cursor_filter (sort_column , AgentModel .id , before_sort_value , before_id , forward = not ascending , nulls_last = sort_nulls_last )
501+ )
478502
479503 # Apply ordering
480504 order_fn = asc if ascending else desc
481- query = query .order_by (order_fn (sort_column ), order_fn (AgentModel .id ))
505+ query = query .order_by (nulls_last ( order_fn ( sort_column )) if sort_nulls_last else order_fn (sort_column ), order_fn (AgentModel .id ))
482506 return query
483507
484508
@@ -488,30 +512,30 @@ async def _apply_pagination_async(
488512 # Determine the sort column
489513 if sort_by == "last_run_completion" :
490514 sort_column = AgentModel .last_run_completion
515+ sort_nulls_last = True # TODO: handle this as a query param eventually
491516 else :
492517 sort_column = AgentModel .created_at
518+ sort_nulls_last = False
493519
494520 if after :
495- if sort_by == "last_run_completion" :
496- result = (await session .execute (select (AgentModel .last_run_completion , AgentModel .id ).where (AgentModel .id == after ))).first ()
497- else :
498- result = (await session .execute (select (AgentModel .created_at , AgentModel .id ).where (AgentModel .id == after ))).first ()
521+ result = (await session .execute (select (sort_column , AgentModel .id ).where (AgentModel .id == after ))).first ()
499522 if result :
500523 after_sort_value , after_id = result
501- query = query .where (_cursor_filter (sort_column , AgentModel .id , after_sort_value , after_id , forward = ascending ))
524+ query = query .where (
525+ _cursor_filter (sort_column , AgentModel .id , after_sort_value , after_id , forward = ascending , nulls_last = sort_nulls_last )
526+ )
502527
503528 if before :
504- if sort_by == "last_run_completion" :
505- result = (await session .execute (select (AgentModel .last_run_completion , AgentModel .id ).where (AgentModel .id == before ))).first ()
506- else :
507- result = (await session .execute (select (AgentModel .created_at , AgentModel .id ).where (AgentModel .id == before ))).first ()
529+ result = (await session .execute (select (sort_column , AgentModel .id ).where (AgentModel .id == before ))).first ()
508530 if result :
509531 before_sort_value , before_id = result
510- query = query .where (_cursor_filter (sort_column , AgentModel .id , before_sort_value , before_id , forward = not ascending ))
532+ query = query .where (
533+ _cursor_filter (sort_column , AgentModel .id , before_sort_value , before_id , forward = not ascending , nulls_last = sort_nulls_last )
534+ )
511535
512536 # Apply ordering
513537 order_fn = asc if ascending else desc
514- query = query .order_by (order_fn (sort_column ), order_fn (AgentModel .id ))
538+ query = query .order_by (nulls_last ( order_fn ( sort_column )) if sort_nulls_last else order_fn (sort_column ), order_fn (AgentModel .id ))
515539 return query
516540
517541
0 commit comments