diff --git a/src/google/adk/sessions/base_session_service.py b/src/google/adk/sessions/base_session_service.py index f2f6f9f22d..2b6dd8bb00 100644 --- a/src/google/adk/sessions/base_session_service.py +++ b/src/google/adk/sessions/base_session_service.py @@ -15,6 +15,7 @@ from __future__ import annotations import abc +import copy from typing import Any from typing import Optional @@ -102,6 +103,94 @@ async def delete_session( ) -> None: """Deletes a session.""" + @abc.abstractmethod + async def clone_session( + self, + *, + app_name: str, + src_user_id: str, + src_session_id: Optional[str] = None, + new_user_id: Optional[str] = None, + new_session_id: Optional[str] = None, + ) -> Session: + """Clones session(s) and their events to a new session. + + This method supports two modes: + + 1. Single session clone: When `src_session_id` is provided, clones that + specific session to the new session. + + 2. All sessions clone: When `src_session_id` is NOT provided, finds all + sessions for `src_user_id` and merges ALL their events into a single + new session. + + Events are automatically deduplicated by event ID - only the first + occurrence of each event ID is kept. + + Args: + app_name: The name of the app. + src_user_id: The source user ID whose session(s) to clone. + src_session_id: The source session ID to clone. If not provided, all + sessions for the source user will be merged into one new session. + new_user_id: The user ID for the new session. If not provided, uses the + same user_id as the source. + new_session_id: The session ID for the new session. If not provided, a + new ID will be auto-generated (UUID4). + + Returns: + The newly created session with cloned events. + + Raises: + ValueError: If no source sessions are found. + AlreadyExistsError: If a session with new_session_id already exists. + """ + + def _prepare_sessions_for_cloning( + self, source_sessions: list[Session] + ) -> tuple[dict[str, Any], list[Event]]: + """Prepares source sessions for cloning by merging states and deduplicating events. + + This is a shared helper method used by all clone_session implementations + to ensure consistent behavior across different session service backends. + + The method: + 1. Sorts sessions by last_update_time for deterministic state merging + 2. Merges states from all sessions (later sessions overwrite earlier ones) + 3. Collects all events, sorts by timestamp, and deduplicates by event ID + + Args: + source_sessions: List of source sessions to process. + + Returns: + A tuple of (merged_state, deduplicated_events): + - merged_state: Combined state from all sessions (deep copied) + - deduplicated_events: Chronologically sorted, deduplicated events + """ + # Sort sessions by update time for deterministic state merging + # Use sorted() to avoid modifying the input list in-place + sorted_sessions = sorted(source_sessions, key=lambda s: s.last_update_time) + + # Merge states from all source sessions + merged_state: dict[str, Any] = {} + for session in sorted_sessions: + merged_state.update(copy.deepcopy(session.state)) + + # Collect all events, sort by timestamp, then deduplicate + # to ensure chronological "first occurrence wins" + all_source_events: list[Event] = [] + for session in sorted_sessions: + all_source_events.extend(session.events) + all_source_events.sort(key=lambda e: e.timestamp) + + all_events: list[Event] = [] + seen_event_ids: set[str] = set() + for event in all_source_events: + if event.id not in seen_event_ids: + seen_event_ids.add(event.id) + all_events.append(event) + + return merged_state, all_events + async def append_event(self, session: Session, event: Event) -> Event: """Appends an event to a session object.""" if event.partial: diff --git a/src/google/adk/sessions/database_session_service.py b/src/google/adk/sessions/database_session_service.py index 863bbfa861..09de7c6417 100644 --- a/src/google/adk/sessions/database_session_service.py +++ b/src/google/adk/sessions/database_session_service.py @@ -415,6 +415,98 @@ async def delete_session( await sql_session.execute(stmt) await sql_session.commit() + @override + async def clone_session( + self, + *, + app_name: str, + src_user_id: str, + src_session_id: Optional[str] = None, + new_user_id: Optional[str] = None, + new_session_id: Optional[str] = None, + ) -> Session: + await self._prepare_tables() + + # Use source values as defaults + new_user_id = new_user_id or src_user_id + + schema = self._get_schema_classes() + + # Collect source sessions and their events + source_sessions = [] + if src_session_id: + # Single session clone - use get_session (no N+1 issue) + session = await self.get_session( + app_name=app_name, + user_id=src_user_id, + session_id=src_session_id, + ) + if not session: + raise ValueError( + f"Source session {src_session_id} not found for user {src_user_id}." + ) + source_sessions.append(session) + else: + # All sessions clone - optimized to avoid N+1 query problem + # Step 1: Get all sessions with state (no events) + list_response = await self.list_sessions( + app_name=app_name, user_id=src_user_id + ) + if not list_response.sessions: + raise ValueError(f"No sessions found for user {src_user_id}.") + + session_ids = [sess.id for sess in list_response.sessions] + + # Step 2: Fetch ALL events for all session IDs in a single query + async with self.database_session_factory() as sql_session: + stmt = ( + select(schema.StorageEvent) + .filter(schema.StorageEvent.app_name == app_name) + .filter(schema.StorageEvent.user_id == src_user_id) + .filter(schema.StorageEvent.session_id.in_(session_ids)) + .order_by(schema.StorageEvent.timestamp.asc()) + ) + result = await sql_session.execute(stmt) + all_storage_events = result.scalars().all() + + # Step 3: Map events back to sessions + events_by_session_id = {} + for storage_event in all_storage_events: + events_by_session_id.setdefault(storage_event.session_id, []).append( + storage_event.to_event() + ) + + # Build full session objects with events + for sess in list_response.sessions: + sess.events = events_by_session_id.get(sess.id, []) + source_sessions.append(sess) + + # Use shared helper for state merging and event deduplication + merged_state, all_events = self._prepare_sessions_for_cloning( + source_sessions + ) + + # Create the new session (new_session_id=None triggers UUID4 generation) + new_session = await self.create_session( + app_name=app_name, + user_id=new_user_id, + state=merged_state, + session_id=new_session_id, + ) + + # Copy events to the new session using bulk insert + async with self.database_session_factory() as sql_session: + new_storage_events = [ + schema.StorageEvent.from_event(new_session, copy.deepcopy(event)) + for event in all_events + ] + sql_session.add_all(new_storage_events) + await sql_session.commit() + + # Return the new session with events (avoid redundant DB query) + new_session.events = all_events + return new_session + @override async def append_event(self, session: Session, event: Event) -> Event: await self._prepare_tables() @@ -436,8 +528,8 @@ async def append_event(self, session: Session, event: Event) -> Event: if storage_session.update_timestamp_tz > session.last_update_time: raise ValueError( "The last_update_time provided in the session object" - f" {datetime.fromtimestamp(session.last_update_time):'%Y-%m-%d %H:%M:%S'} is" - " earlier than the update_time in the storage_session" + f" {datetime.fromtimestamp(session.last_update_time):'%Y-%m-%d %H:%M:%S'}" + " is earlier than the update_time in the storage_session" f" {datetime.fromtimestamp(storage_session.update_timestamp_tz):'%Y-%m-%d %H:%M:%S'}." " Please check if it is a stale session." ) diff --git a/src/google/adk/sessions/in_memory_session_service.py b/src/google/adk/sessions/in_memory_session_service.py index 6ba7f0bb01..47c900cf81 100644 --- a/src/google/adk/sessions/in_memory_session_service.py +++ b/src/google/adk/sessions/in_memory_session_service.py @@ -286,6 +286,104 @@ def _delete_session_impl( self.sessions[app_name][user_id].pop(session_id) + @override + async def clone_session( + self, + *, + app_name: str, + src_user_id: str, + src_session_id: Optional[str] = None, + new_user_id: Optional[str] = None, + new_session_id: Optional[str] = None, + ) -> Session: + return self._clone_session_impl( + app_name=app_name, + src_user_id=src_user_id, + src_session_id=src_session_id, + new_user_id=new_user_id, + new_session_id=new_session_id, + ) + + def _clone_session_impl( + self, + *, + app_name: str, + src_user_id: str, + src_session_id: Optional[str] = None, + new_user_id: Optional[str] = None, + new_session_id: Optional[str] = None, + ) -> Session: + # Use source values as defaults + new_user_id = new_user_id or src_user_id + + # Collect source sessions and their events + source_sessions = [] + if src_session_id: + # Single session clone + session = self._get_session_impl( + app_name=app_name, + user_id=src_user_id, + session_id=src_session_id, + ) + if not session: + raise ValueError( + f'Source session {src_session_id} not found for user {src_user_id}.' + ) + source_sessions.append(session) + else: + # All sessions clone - optimized direct access to avoid N+1 lookups + if ( + app_name not in self.sessions + or src_user_id not in self.sessions[app_name] + ): + raise ValueError(f'No sessions found for user {src_user_id}.') + + user_sessions = self.sessions[app_name][src_user_id] + if not user_sessions: + raise ValueError(f'No sessions found for user {src_user_id}.') + + # Directly access storage sessions and build full session objects + for session_id, storage_session in user_sessions.items(): + # Deep copy the session to avoid mutations + copied_session = copy.deepcopy(storage_session) + # Merge state with app and user state + copied_session = self._merge_state( + app_name, src_user_id, copied_session + ) + source_sessions.append(copied_session) + + # Use shared helper for state merging and event deduplication + merged_state, all_events = self._prepare_sessions_for_cloning( + source_sessions + ) + # Deep copy events for in-memory storage isolation + all_events = [copy.deepcopy(event) for event in all_events] + + # Create the new session (new_session_id=None triggers UUID4 generation) + new_session = self._create_session_impl( + app_name=app_name, + user_id=new_user_id, + state=merged_state, + session_id=new_session_id, + ) + + # Get latest update time explicitly (don't rely on sorting side effects) + latest_update_time = ( + max(s.last_update_time for s in source_sessions) + if source_sessions + else 0.0 + ) + + # Get the storage session and set events + storage_session = self.sessions[app_name][new_user_id][new_session.id] + storage_session.events = all_events + storage_session.last_update_time = latest_update_time + + # Return the new session with events (avoid redundant lookup) + new_session.events = all_events + new_session.last_update_time = latest_update_time + return new_session + @override async def append_event(self, session: Session, event: Event) -> Event: if event.partial: diff --git a/src/google/adk/sessions/sqlite_session_service.py b/src/google/adk/sessions/sqlite_session_service.py index 1d9516ec73..6592709594 100644 --- a/src/google/adk/sessions/sqlite_session_service.py +++ b/src/google/adk/sessions/sqlite_session_service.py @@ -356,6 +356,107 @@ async def delete_session( ) await db.commit() + @override + async def clone_session( + self, + *, + app_name: str, + src_user_id: str, + src_session_id: Optional[str] = None, + new_user_id: Optional[str] = None, + new_session_id: Optional[str] = None, + ) -> Session: + # Use source values as defaults + new_user_id = new_user_id or src_user_id + + # Collect source sessions and their events + source_sessions = [] + if src_session_id: + # Single session clone - use get_session (no N+1 issue) + session = await self.get_session( + app_name=app_name, + user_id=src_user_id, + session_id=src_session_id, + ) + if not session: + raise ValueError( + f"Source session {src_session_id} not found for user {src_user_id}." + ) + source_sessions.append(session) + else: + # All sessions clone - optimized to avoid N+1 query problem + # Step 1: Get all sessions with state (no events) + list_response = await self.list_sessions( + app_name=app_name, user_id=src_user_id + ) + if not list_response.sessions: + raise ValueError(f"No sessions found for user {src_user_id}.") + + session_ids = [sess.id for sess in list_response.sessions] + + # Step 2: Fetch ALL events for all session IDs in a single query + async with self._get_db_connection() as db: + placeholders = ",".join("?" * len(session_ids)) + query = f""" + SELECT session_id, event_data FROM events + WHERE app_name=? AND user_id=? AND session_id IN ({placeholders}) + ORDER BY timestamp ASC + """ + params = [app_name, src_user_id] + session_ids + event_rows = await db.execute_fetchall(query, params) + + # Step 3: Map events back to sessions + events_by_session_id = {} + for row in event_rows: + events_by_session_id.setdefault(row["session_id"], []).append( + Event.model_validate_json(row["event_data"]) + ) + + # Build full session objects with events + for sess in list_response.sessions: + sess.events = events_by_session_id.get(sess.id, []) + source_sessions.append(sess) + + # Use shared helper for state merging and event deduplication + merged_state, all_events = self._prepare_sessions_for_cloning( + source_sessions + ) + + # Create the new session (new_session_id=None triggers UUID4 generation) + new_session = await self.create_session( + app_name=app_name, + user_id=new_user_id, + state=merged_state, + session_id=new_session_id, + ) + + # Copy events to the new session using bulk insert + async with self._get_db_connection() as db: + event_params = [] + for event in all_events: + cloned_event = copy.deepcopy(event) + event_params.append(( + cloned_event.id, + new_session.app_name, + new_session.user_id, + new_session.id, + cloned_event.invocation_id, + cloned_event.timestamp, + cloned_event.model_dump_json(exclude_none=True), + )) + await db.executemany( + """ + INSERT INTO events (id, app_name, user_id, session_id, invocation_id, timestamp, event_data) + VALUES (?, ?, ?, ?, ?, ?, ?) + """, + event_params, + ) + await db.commit() + + # Return the new session with events (avoid redundant DB query) + new_session.events = all_events + return new_session + @override async def append_event(self, session: Session, event: Event) -> Event: if event.partial: diff --git a/src/google/adk/sessions/vertex_ai_session_service.py b/src/google/adk/sessions/vertex_ai_session_service.py index 3f9e514e03..decea6bd13 100644 --- a/src/google/adk/sessions/vertex_ai_session_service.py +++ b/src/google/adk/sessions/vertex_ai_session_service.py @@ -14,6 +14,7 @@ from __future__ import annotations import asyncio +import copy import datetime import json import logging @@ -244,6 +245,89 @@ async def delete_session( logger.error('Error deleting session %s: %s', session_id, e) raise + @override + async def clone_session( + self, + *, + app_name: str, + src_user_id: str, + src_session_id: Optional[str] = None, + new_user_id: Optional[str] = None, + new_session_id: Optional[str] = None, + ) -> Session: + if new_session_id: + raise ValueError( + 'User-provided session id (new_session_id) is not supported for' + ' VertexAISessionService. The session ID is auto-generated by the' + ' Vertex AI backend.' + ) + + # Use source values as defaults + new_user_id = new_user_id or src_user_id + + # Collect source sessions and their events + source_sessions = [] + if src_session_id: + # Single session clone + session = await self.get_session( + app_name=app_name, + user_id=src_user_id, + session_id=src_session_id, + ) + if not session: + raise ValueError( + f'Source session {src_session_id} not found for user {src_user_id}.' + ) + source_sessions.append(session) + else: + # All sessions clone - get all sessions for the user + list_response = await self.list_sessions( + app_name=app_name, user_id=src_user_id + ) + if not list_response.sessions: + raise ValueError(f'No sessions found for user {src_user_id}.') + + # Fetch all sessions with events in parallel using asyncio.gather + # (Vertex AI API doesn't support batch retrieval, so we parallelize) + fetch_tasks = [ + self.get_session( + app_name=app_name, + user_id=src_user_id, + session_id=sess.id, + ) + for sess in list_response.sessions + ] + fetched_sessions = await asyncio.gather(*fetch_tasks) + source_sessions = [s for s in fetched_sessions if s is not None] + if not source_sessions and list_response.sessions: + raise ValueError( + f'Could not retrieve any source sessions for user {src_user_id}. ' + 'They may have been deleted after being listed.' + ) + + # Use shared helper for state merging and event deduplication + merged_state, all_events = self._prepare_sessions_for_cloning( + source_sessions + ) + + # Create the new session (ID is auto-generated by Vertex AI backend) + new_session = await self.create_session( + app_name=app_name, + user_id=new_user_id, + state=merged_state, + ) + + # Copy events to the new session (deep copy to avoid mutations) + # Note: Each event requires a separate API call to Vertex AI. For sessions + # with many events, this may be slow. Vertex AI does not currently support + # batch event appending. + for event in all_events: + await self.append_event(new_session, copy.deepcopy(event)) + + # Return the new session with events (already populated via append_event) + new_session.events = all_events + return new_session + @override async def append_event(self, session: Session, event: Event) -> Event: # Update the in-memory session. diff --git a/tests/unittests/sessions/test_session_service.py b/tests/unittests/sessions/test_session_service.py index 96d2f38726..d909951679 100644 --- a/tests/unittests/sessions/test_session_service.py +++ b/tests/unittests/sessions/test_session_service.py @@ -603,3 +603,276 @@ async def test_partial_events_are_not_persisted(session_service): app_name=app_name, user_id=user_id, session_id=session.id ) assert len(session_got.events) == 0 + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + 'new_user_id,new_session_id', + [ + (None, None), # Basic clone - same user, auto-generated session ID + ('user2', None), # Different user, auto-generated session ID + (None, 'custom_session'), # Same user, custom session ID + ('user2', 'custom_session'), # Different user and custom session ID + ], + ids=['basic', 'different_user', 'custom_session_id', 'both_custom'], +) +async def test_clone_session_single_session( + session_service, new_user_id, new_session_id +): + """Test clone_session with various parameter combinations.""" + app_name = 'my_app' + source_user_id = 'user' + + # Create source session with events + source_session = await session_service.create_session( + app_name=app_name, user_id=source_user_id, state={'key': 'value'} + ) + event1 = Event(invocation_id='inv1', author='user') + event2 = Event(invocation_id='inv2', author='model') + await session_service.append_event(source_session, event1) + await session_service.append_event(source_session, event2) + + # Clone the session + cloned_session = await session_service.clone_session( + app_name=app_name, + src_user_id=source_user_id, + src_session_id=source_session.id, + new_user_id=new_user_id, + new_session_id=new_session_id, + ) + + # Determine expected values + expected_user_id = new_user_id if new_user_id else source_user_id + + # Verify the cloned session + assert cloned_session is not None + assert cloned_session.id != source_session.id + assert cloned_session.app_name == app_name + assert cloned_session.user_id == expected_user_id + assert cloned_session.state == {'key': 'value'} + assert len(cloned_session.events) == 2 + assert cloned_session.events[0].invocation_id == event1.invocation_id + assert cloned_session.events[1].invocation_id == event2.invocation_id + + # Verify custom session ID if provided + if new_session_id: + assert cloned_session.id == new_session_id + + # Verify the cloned session is persisted correctly + fetched_session = await session_service.get_session( + app_name=app_name, user_id=expected_user_id, session_id=cloned_session.id + ) + assert fetched_session is not None + assert fetched_session.user_id == expected_user_id + + +@pytest.mark.asyncio +async def test_clone_session_with_existing_id_raises_error(session_service): + """Test that clone_session raises error if destination session_id exists.""" + app_name = 'my_app' + user_id = 'user' + + # Create source and target sessions + source_session = await session_service.create_session( + app_name=app_name, user_id=user_id, session_id='source' + ) + await session_service.create_session( + app_name=app_name, user_id=user_id, session_id='existing_target' + ) + + # Attempt to clone to existing session ID + with pytest.raises(AlreadyExistsError): + await session_service.clone_session( + app_name=app_name, + src_user_id=user_id, + src_session_id=source_session.id, + new_session_id='existing_target', + ) + + +@pytest.mark.asyncio +async def test_clone_session_preserves_event_content(session_service): + """Test that clone_session preserves full event content.""" + app_name = 'my_app' + user_id = 'user' + + # Create source session with detailed event + source_session = await session_service.create_session( + app_name=app_name, user_id=user_id + ) + event = Event( + invocation_id='invocation', + author='user', + content=types.Content(role='user', parts=[types.Part(text='test_text')]), + actions=EventActions( + artifact_delta={'file': 0}, + transfer_to_agent='agent', + ), + ) + await session_service.append_event(source_session, event) + + # Clone the session + cloned_session = await session_service.clone_session( + app_name=app_name, + src_user_id=user_id, + src_session_id=source_session.id, + ) + + # Verify event content is preserved + assert len(cloned_session.events) == 1 + cloned_event = cloned_session.events[0] + assert cloned_event.invocation_id == event.invocation_id + assert cloned_event.author == event.author + assert cloned_event.content == event.content + assert cloned_event.actions.artifact_delta == event.actions.artifact_delta + assert ( + cloned_event.actions.transfer_to_agent == event.actions.transfer_to_agent + ) + + +@pytest.mark.asyncio +async def test_clone_session_does_not_affect_source(session_service): + """Test that cloning does not modify the source session.""" + app_name = 'my_app' + user_id = 'user' + + # Create source session + source_session = await session_service.create_session( + app_name=app_name, user_id=user_id, session_id='source_session' + ) + event = Event(invocation_id='inv1', author='user') + await session_service.append_event(source_session, event) + + original_source = await session_service.get_session( + app_name=app_name, user_id=user_id, session_id='source_session' + ) + original_event_count = len(original_source.events) + + # Clone the session + cloned_session = await session_service.clone_session( + app_name=app_name, + src_user_id=user_id, + src_session_id='source_session', + ) + + # Add event to cloned session + new_event = Event(invocation_id='inv2', author='model') + await session_service.append_event(cloned_session, new_event) + + # Verify source is unchanged + source_after_clone = await session_service.get_session( + app_name=app_name, user_id=user_id, session_id='source_session' + ) + assert len(source_after_clone.events) == original_event_count + + +@pytest.mark.asyncio +async def test_clone_all_user_sessions(session_service): + """Test clone_session without src_session_id merges all user sessions.""" + app_name = 'my_app' + source_user_id = 'user1' + dest_user_id = 'user2' + + # Create multiple source sessions for user1 + session1 = await session_service.create_session( + app_name=app_name, + user_id=source_user_id, + session_id='session1', + state={'key1': 'value1'}, + ) + session2 = await session_service.create_session( + app_name=app_name, + user_id=source_user_id, + session_id='session2', + state={'key2': 'value2'}, + ) + + # Add events to each session + event1 = Event(invocation_id='inv1', author='user') + event2 = Event(invocation_id='inv2', author='model') + event3 = Event(invocation_id='inv3', author='user') + await session_service.append_event(session1, event1) + await session_service.append_event(session1, event2) + await session_service.append_event(session2, event3) + + # Clone ALL sessions for user1 to user2 (no src_session_id) + cloned_session = await session_service.clone_session( + app_name=app_name, + src_user_id=source_user_id, + new_user_id=dest_user_id, + new_session_id='merged_session', + ) + + # Verify merged session + assert cloned_session is not None + assert cloned_session.user_id == dest_user_id + assert cloned_session.id == 'merged_session' + # Should have all 3 events from both source sessions + assert len(cloned_session.events) == 3 + # State should be merged from both sessions + assert 'key1' in cloned_session.state + assert 'key2' in cloned_session.state + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + 'src_session_id,error_match', + [ + ('non_existent_session', 'not found'), # Specific session not found + (None, 'No sessions found'), # User has no sessions (clone all mode) + ], + ids=['session_not_found', 'no_sessions_for_user'], +) +async def test_clone_session_source_not_found_raises_error( + session_service, src_session_id, error_match +): + """Test clone_session raises ValueError when source cannot be found.""" + app_name = 'my_app' + user_id = 'user_with_no_sessions' + + with pytest.raises(ValueError, match=error_match): + await session_service.clone_session( + app_name=app_name, + src_user_id=user_id, + src_session_id=src_session_id, + ) + + +@pytest.mark.asyncio +async def test_clone_session_deduplicates_events(session_service): + """Test clone_session automatically deduplicates events by ID.""" + app_name = 'my_app' + user_id = 'user' + + # Create two sessions with some events having the same ID + session1 = await session_service.create_session( + app_name=app_name, user_id=user_id, session_id='session1' + ) + session2 = await session_service.create_session( + app_name=app_name, user_id=user_id, session_id='session2' + ) + + # Create events - event1 and event3 have the same ID (duplicate) + event1 = Event(id='shared_event_id', invocation_id='inv1', author='user') + event2 = Event(id='unique_event_1', invocation_id='inv2', author='model') + event3 = Event(id='shared_event_id', invocation_id='inv3', author='user') + event4 = Event(id='unique_event_2', invocation_id='inv4', author='model') + + await session_service.append_event(session1, event1) + await session_service.append_event(session1, event2) + await session_service.append_event(session2, event3) + await session_service.append_event(session2, event4) + + # Clone - should have 3 events (duplicate automatically removed) + cloned_session = await session_service.clone_session( + app_name=app_name, + src_user_id=user_id, + ) + assert len(cloned_session.events) == 3 + # Verify the unique event IDs + event_ids = [e.id for e in cloned_session.events] + assert 'shared_event_id' in event_ids + assert 'unique_event_1' in event_ids + assert 'unique_event_2' in event_ids + # Count occurrences - shared_event_id should appear only once + assert event_ids.count('shared_event_id') == 1