diff --git a/src/google/adk/agents/invocation_context.py b/src/google/adk/agents/invocation_context.py
index 24fdce9d59..8ebfd82222 100644
--- a/src/google/adk/agents/invocation_context.py
+++ b/src/google/adk/agents/invocation_context.py
@@ -14,7 +14,10 @@
 
 from __future__ import annotations
 
+import asyncio
 from typing import Any
+from typing import Awaitable
+from typing import Callable
 from typing import Optional
 import uuid
 
@@ -213,6 +216,19 @@ class InvocationContext(BaseModel):
     of this invocation.
   """
 
+  _tool_call_cache_lock: asyncio.Lock = PrivateAttr(
+      default_factory=asyncio.Lock
+  )
+  _tool_call_cache: dict[tuple[Any, ...], asyncio.Task] = PrivateAttr(
+      default_factory=dict
+  )
+  """Caches tool call results within a single invocation.
+
+  This is used to prevent redundant tool execution when the model repeats the
+  same function call (same tool name + same args) multiple times during a single
+  invocation.
+  """
+
   @property
   def is_resumable(self) -> bool:
     """Returns whether the current invocation is resumable."""
@@ -221,6 +237,76 @@ def is_resumable(self) -> bool:
         and self.resumability_config.is_resumable
     )
 
+  @staticmethod
+  def _canonicalize_tool_args(value: Any) -> Any:
+    """Converts a JSON-like structure into a stable, hashable representation."""
+    if isinstance(value, dict):
+      return tuple(
+          (k, InvocationContext._canonicalize_tool_args(v))
+          for k, v in sorted(value.items())
+      )
+    if isinstance(value, list):
+      return tuple(InvocationContext._canonicalize_tool_args(v) for v in value)
+    if isinstance(value, (str, int, float, bool)) or value is None:
+      return value
+    # Fallback: keep it hashable and stable.
+    return repr(value)
+
+  def _tool_call_cache_key(
+      self, *, tool_name: str, tool_args: dict[str, Any]
+  ) -> tuple[Any, ...]:
+    """Builds a cache key for a tool call within this invocation."""
+    return (
+        self.branch,
+        tool_name,
+        InvocationContext._canonicalize_tool_args(tool_args),
+    )
+
+  async def get_or_execute_deduped_tool_call(
+      self,
+      *,
+      tool_name: str,
+      tool_args: dict[str, Any],
+      execute: Callable[[], Awaitable[Any]],
+      dedupe: bool = False,
+  ) -> tuple[Any, bool]:
+    """Returns cached tool result for identical calls, otherwise executes once.
+
+    Args:
+      tool_name: Tool name.
+      tool_args: Tool arguments from the model.
+      execute: A coroutine factory that executes the tool and returns its
+        response.
+
+    Returns:
+      A tuple of (tool_result, cache_hit).
+    """
+    if not dedupe:
+      return await execute(), False
+
+    key = self._tool_call_cache_key(tool_name=tool_name, tool_args=tool_args)
+
+    async with self._tool_call_cache_lock:
+      task = self._tool_call_cache.get(key)
+      if task is None:
+        task = asyncio.create_task(execute())
+        self._tool_call_cache[key] = task
+        cache_hit = False
+      else:
+        cache_hit = True
+
+    try:
+      result = await task
+    except Exception:
+      # If the execution failed, remove from cache so subsequent calls can
+      # retry instead of returning a cached exception forever.
+      async with self._tool_call_cache_lock:
+        if self._tool_call_cache.get(key) is task:
+          self._tool_call_cache.pop(key, None)
+      raise
+
+    return result, cache_hit
+
   def set_agent_state(
       self,
       agent_name: str,
diff --git a/src/google/adk/agents/run_config.py b/src/google/adk/agents/run_config.py
index ae210ef471..9a6f740577 100644
--- a/src/google/adk/agents/run_config.py
+++ b/src/google/adk/agents/run_config.py
@@ -251,6 +251,19 @@ class RunConfig(BaseModel):
     - Less than or equal to 0: This allows for unbounded number of llm calls.
   """
 
+  dedupe_tool_calls: bool = False
+  """
+  Whether to deduplicate identical tool calls (same tool name + same arguments)
+  within a single invocation.
+
+  This helps prevent redundant tool execution when the model repeats the same
+  function call multiple times (for example, when a tool is slow or the model
+  does not follow the instruction to call a tool only once).
+
+  Note: Only the tool result is reused; tool side effects (state/artifact
+  deltas) are only applied once from the first execution.
+  """
+
   custom_metadata: Optional[dict[str, Any]] = None
   """Custom metadata for the current invocation."""
 
diff --git a/src/google/adk/cli/adk_web_server.py b/src/google/adk/cli/adk_web_server.py
index 752af89c34..8f842f0b7e 100644
--- a/src/google/adk/cli/adk_web_server.py
+++ b/src/google/adk/cli/adk_web_server.py
@@ -1648,6 +1648,66 @@ async def run_agent_live(
       await websocket.close(code=1002, reason="Session not found")
       return
 
+    # Determine if this is a live/audio session
+    # For live sessions, Gemini Live API provides transparent session resumption
+    # where the model automatically replays its last response. Replaying events
+    # manually would cause duplicates (Issue #3395).
+    # For text-only sessions, we need to replay all events (Issue #3573).
+    def is_live_session(events: list) -> bool:
+      """Check if session contains audio/video or transcription data."""
+      # Check last few events for live session indicators
+      for event in reversed(events[-5:] if len(events) > 5 else events):
+        # Check for transcription data (input/output)
+        if hasattr(event, 'input_transcription') and event.input_transcription:
+          return True
+        if hasattr(event, 'output_transcription') and event.output_transcription:
+          return True
+        # Check content for audio/video
+        if event.content:
+          for part in event.content.parts:
+            if part.inline_data and (
+                part.inline_data.mime_type.startswith("audio/")
+                or part.inline_data.mime_type.startswith("video/")
+            ):
+              return True
+            if part.file_data and (
+                part.file_data.mime_type.startswith("audio/")
+                or part.file_data.mime_type.startswith("video/")
+            ):
+              return True
+      return False
+
+    # Replay existing session events for text-only sessions
+    # Skip replay for live/audio sessions to avoid conflicts with
+    # Gemini Live API's built-in session resumption
+    should_replay = session.events and not is_live_session(session.events)
+
+    if should_replay:
+      logger.info(
+          "Replaying %d existing events for text-only session %s",
+          len(session.events),
+          session_id,
+      )
+      for event in session.events:
+        try:
+          await websocket.send_text(
+              event.model_dump_json(exclude_none=True, by_alias=True)
+          )
+        except Exception as e:
+          logger.error(
+              "Failed to replay event %s during session restoration: %s",
+              event.id,
+              e,
+          )
+          # Continue replaying other events even if one fails
+          continue
+    elif session.events and not should_replay:
+      logger.info(
+          "Skipping event replay for live/audio session %s (relying on "
+          "Gemini Live API's transparent session resumption)",
+          session_id,
+      )
+
     live_request_queue = LiveRequestQueue()
 
     async def forward_events():
diff --git a/src/google/adk/flows/llm_flows/functions.py b/src/google/adk/flows/llm_flows/functions.py
index ffe1657be1..233f454d75 100644
--- a/src/google/adk/flows/llm_flows/functions.py
+++ b/src/google/adk/flows/llm_flows/functions.py
@@ -340,75 +340,92 @@ async def _run_on_tool_error_callbacks(
   async def _run_with_trace():
     nonlocal function_args
 
-    # Step 1: Check if plugin before_tool_callback overrides the function
-    # response.
-    function_response = (
-        await invocation_context.plugin_manager.run_before_tool_callback(
-            tool=tool, tool_args=function_args, tool_context=tool_context
-        )
-    )
+    async def _execute_tool_pipeline() -> Any:
+      """Executes tool call pipeline once; result can be cached by invocation."""
+      # Step 1: Check if plugin before_tool_callback overrides the function
+      # response.
+      function_response = (
+          await invocation_context.plugin_manager.run_before_tool_callback(
+              tool=tool, tool_args=function_args, tool_context=tool_context
+          )
+      )
 
-    # Step 2: If no overrides are provided from the plugins, further run the
-    # canonical callback.
-    if function_response is None:
-      for callback in agent.canonical_before_tool_callbacks:
-        function_response = callback(
-            tool=tool, args=function_args, tool_context=tool_context
-        )
-        if inspect.isawaitable(function_response):
-          function_response = await function_response
-        if function_response:
-          break
+      # Step 2: If no overrides are provided from the plugins, further run the
+      # canonical callback.
+      if function_response is None:
+        for callback in agent.canonical_before_tool_callbacks:
+          function_response = callback(
+              tool=tool, args=function_args, tool_context=tool_context
+          )
+          if inspect.isawaitable(function_response):
+            function_response = await function_response
+          if function_response:
+            break
+
+      # Step 3: Otherwise, proceed calling the tool normally.
+      if function_response is None:
+        try:
+          function_response = await __call_tool_async(
+              tool, args=function_args, tool_context=tool_context
+          )
+        except Exception as tool_error:
+          error_response = await _run_on_tool_error_callbacks(
+              tool=tool,
+              tool_args=function_args,
+              tool_context=tool_context,
+              error=tool_error,
+          )
+          if error_response is not None:
+            function_response = error_response
+          else:
+            raise tool_error
 
-    # Step 3: Otherwise, proceed calling the tool normally.
-    if function_response is None:
-      try:
-        function_response = await __call_tool_async(
-            tool, args=function_args, tool_context=tool_context
-        )
-      except Exception as tool_error:
-        error_response = await _run_on_tool_error_callbacks(
-            tool=tool,
-            tool_args=function_args,
-            tool_context=tool_context,
-            error=tool_error,
-        )
-        if error_response is not None:
-          function_response = error_response
-        else:
-          raise tool_error
+      # Step 4: Check if plugin after_tool_callback overrides the function
+      # response.
+      altered_function_response = (
+          await invocation_context.plugin_manager.run_after_tool_callback(
+              tool=tool,
+              tool_args=function_args,
+              tool_context=tool_context,
+              result=function_response,
+          )
+      )
 
-    # Step 4: Check if plugin after_tool_callback overrides the function
-    # response.
-    altered_function_response = (
-        await invocation_context.plugin_manager.run_after_tool_callback(
-            tool=tool,
+      # Step 5: If no overrides are provided from the plugins, further run the
+      # canonical after_tool_callbacks.
+      if altered_function_response is None:
+        for callback in agent.canonical_after_tool_callbacks:
+          altered_function_response = callback(
+              tool=tool,
+              args=function_args,
+              tool_context=tool_context,
+              tool_response=function_response,
+          )
+          if inspect.isawaitable(altered_function_response):
+            altered_function_response = await altered_function_response
+          if altered_function_response:
+            break
+
+      # Step 6: If alternative response exists from after_tool_callback, use it
+      # instead of the original function response.
+      if altered_function_response is not None:
+        function_response = altered_function_response
+
+      return function_response
+
+    should_dedupe = bool(
+        invocation_context.run_config
+        and invocation_context.run_config.dedupe_tool_calls
+    ) or tool.is_long_running
+    function_response, cache_hit = (
+        await invocation_context.get_or_execute_deduped_tool_call(
+            tool_name=tool.name,
             tool_args=function_args,
-            tool_context=tool_context,
-            result=function_response,
+            execute=_execute_tool_pipeline,
+            dedupe=should_dedupe,
         )
     )
 
-    # Step 5: If no overrides are provided from the plugins, further run the
-    # canonical after_tool_callbacks.
-    if altered_function_response is None:
-      for callback in agent.canonical_after_tool_callbacks:
-        altered_function_response = callback(
-            tool=tool,
-            args=function_args,
-            tool_context=tool_context,
-            tool_response=function_response,
-        )
-        if inspect.isawaitable(altered_function_response):
-          altered_function_response = await altered_function_response
-        if altered_function_response:
-          break
-
-    # Step 6: If alternative response exists from after_tool_callback, use it
-    # instead of the original function response.
-    if altered_function_response is not None:
-      function_response = altered_function_response
-
     if tool.is_long_running:
       # Allow long running function to return None to not provide function
       # response.
@@ -423,6 +440,11 @@ async def _run_with_trace():
     function_response_event = __build_response_event(
         tool, function_response, tool_context, invocation_context
     )
+    if cache_hit:
+      function_response_event.custom_metadata = (
+          function_response_event.custom_metadata or {}
+      )
+      function_response_event.custom_metadata['adk_tool_call_cache_hit'] = True
     return function_response_event
 
   with tracer.start_as_current_span(f'execute_tool {tool.name}'):
@@ -517,48 +539,69 @@ async def _execute_single_function_call_live(
   async def _run_with_trace():
     nonlocal function_args
 
-    # Do not use "args" as the variable name, because it is a reserved keyword
-    # in python debugger.
-    # Make a deep copy to avoid being modified.
-    function_response = None
+    async def _execute_tool_pipeline() -> Any:
+      # Do not use "args" as the variable name, because it is a reserved keyword
+      # in python debugger.
+      # Make a deep copy to avoid being modified.
+      function_response = None
 
-    # Handle before_tool_callbacks - iterate through the canonical callback
-    # list
-    for callback in agent.canonical_before_tool_callbacks:
-      function_response = callback(
-          tool=tool, args=function_args, tool_context=tool_context
-      )
-      if inspect.isawaitable(function_response):
-        function_response = await function_response
-      if function_response:
-        break
-
-    if function_response is None:
-      function_response = await _process_function_live_helper(
-          tool,
-          tool_context,
-          function_call,
-          function_args,
-          invocation_context,
-          streaming_lock,
-      )
+      # Handle before_tool_callbacks - iterate through the canonical callback
+      # list.
+      for callback in agent.canonical_before_tool_callbacks:
+        function_response = callback(
+            tool=tool, args=function_args, tool_context=tool_context
+        )
+        if inspect.isawaitable(function_response):
+          function_response = await function_response
+        if function_response:
+          break
 
-    # Calls after_tool_callback if it exists.
-    altered_function_response = None
-    for callback in agent.canonical_after_tool_callbacks:
-      altered_function_response = callback(
-          tool=tool,
-          args=function_args,
-          tool_context=tool_context,
-          tool_response=function_response,
-      )
-      if inspect.isawaitable(altered_function_response):
-        altered_function_response = await altered_function_response
-      if altered_function_response:
-        break
+      if function_response is None:
+        function_response = await _process_function_live_helper(
+            tool,
+            tool_context,
+            function_call,
+            function_args,
+            invocation_context,
+            streaming_lock,
+        )
 
-    if altered_function_response is not None:
-      function_response = altered_function_response
+      # Calls after_tool_callback if it exists.
+      altered_function_response = None
+      for callback in agent.canonical_after_tool_callbacks:
+        altered_function_response = callback(
+            tool=tool,
+            args=function_args,
+            tool_context=tool_context,
+            tool_response=function_response,
+        )
+        if inspect.isawaitable(altered_function_response):
+          altered_function_response = await altered_function_response
+        if altered_function_response:
+          break
+
+      if altered_function_response is not None:
+        function_response = altered_function_response
+
+      return function_response
+
+    # Never cache stop_streaming calls (control operation).
+    if function_call.name == 'stop_streaming':
+      function_response = await _execute_tool_pipeline()
+      cache_hit = False
+    else:
+      should_dedupe = bool(
+          invocation_context.run_config
+          and invocation_context.run_config.dedupe_tool_calls
+      ) or tool.is_long_running
+      function_response, cache_hit = (
+          await invocation_context.get_or_execute_deduped_tool_call(
+              tool_name=tool.name,
+              tool_args=function_args,
+              execute=_execute_tool_pipeline,
+              dedupe=should_dedupe,
+          )
+      )
 
     if tool.is_long_running:
       # Allow async function to return None to not provide function response.
@@ -573,6 +616,11 @@ async def _run_with_trace():
     function_response_event = __build_response_event(
         tool, function_response, tool_context, invocation_context
     )
+    if cache_hit:
+      function_response_event.custom_metadata = (
+          function_response_event.custom_metadata or {}
+      )
+      function_response_event.custom_metadata['adk_tool_call_cache_hit'] = True
     return function_response_event
 
   with tracer.start_as_current_span(f'execute_tool {tool.name}'):
diff --git a/tests/unittests/cli/test_live_session_restoration.py b/tests/unittests/cli/test_live_session_restoration.py
new file mode 100644
index 0000000000..ea9f822201
--- /dev/null
+++ b/tests/unittests/cli/test_live_session_restoration.py
@@ -0,0 +1,374 @@
+# Copyright 2026 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for BIDI live session restoration with SQLite persistence (Issue #3573, #3395)."""
+
+from __future__ import annotations
+
+from unittest.mock import AsyncMock
+from unittest.mock import MagicMock
+from unittest.mock import patch
+
+from google.adk.events.event import Event
+from google.adk.sessions.session import Session
+from google.genai import types
+import pytest
+
+
+@pytest.mark.asyncio
+async def test_live_session_replays_all_events_on_reconnection():
+  """Test that reconnecting to a live session replays all events including user messages.
+
+  This tests the fix for Issue #3573 where user messages were stored in the
+  database but not sent back to the client on reconnection.
+  """
+  # Create a mock session with both user and agent events
+  user_event = Event(
+      id="event-user-1",
+      author="user",
+      content=types.Content(parts=[types.Part(text="Hello, assistant!")]),
+      invocation_id="inv-1",
+  )
+  agent_event = Event(
+      id="event-agent-1",
+      author="test_agent",
+      content=types.Content(
+          parts=[types.Part(text="Hello! How can I help you?")]
+      ),
+      invocation_id="inv-1",
+  )
+  user_event2 = Event(
+      id="event-user-2",
+      author="user",
+      content=types.Content(
+          parts=[types.Part(text="What's the weather today?")]
+      ),
+      invocation_id="inv-2",
+  )
+  agent_event2 = Event(
+      id="event-agent-2",
+      author="test_agent",
+      content=types.Content(
+          parts=[types.Part(text="I can help you check the weather.")]
+      ),
+      invocation_id="inv-2",
+  )
+
+  mock_session = Session(
+      app_name="test_app",
+      user_id="test_user",
+      id="test_session",
+      state={},
+      events=[user_event, agent_event, user_event2, agent_event2],
+      last_update_time=1234567890.0,
+  )
+
+  # Mock WebSocket to capture replayed events
+  mock_websocket = AsyncMock()
+  replayed_events = []
+
+  async def capture_send_text(data):
+    replayed_events.append(data)
+
+  mock_websocket.send_text = capture_send_text
+
+  # Test the core event replay logic that should be in run_agent_live
+  # This simulates what happens when a client reconnects:
+  # 1. Session is loaded (with all events)
+  session = mock_session
+
+  # 2. All existing events should be replayed to the client
+  if session and session.events:
+    for event in session.events:
+      await mock_websocket.send_text(
+          event.model_dump_json(exclude_none=True, by_alias=True)
+      )
+
+  # Verify that all 4 events were replayed (2 user + 2 agent)
+  assert len(replayed_events) == 4
+
+  # Verify that events were sent in order
+  import json
+
+  event_data = [json.loads(data) for data in replayed_events]
+
+  # Check that user messages are included
+  assert event_data[0]["author"] == "user"
+  assert "Hello, assistant!" in event_data[0]["content"]["parts"][0]["text"]
+
+  assert event_data[1]["author"] == "test_agent"
+
+  assert event_data[2]["author"] == "user"
+  assert "weather" in event_data[2]["content"]["parts"][0]["text"]
+
+  assert event_data[3]["author"] == "test_agent"
+
+
+@pytest.mark.asyncio
+async def test_live_session_handles_empty_events_gracefully():
+  """Test that session restoration handles sessions with no events."""
+  mock_session = Session(
+      app_name="test_app",
+      user_id="test_user",
+      id="new_session",
+      state={},
+      events=[],  # No events yet
+      last_update_time=1234567890.0,
+  )
+
+  mock_session_service = AsyncMock()
+  mock_session_service.get_session.return_value = mock_session
+
+  mock_websocket = AsyncMock()
+  replayed_events = []
+
+  async def capture_send_text(data):
+    replayed_events.append(data)
+
+  mock_websocket.send_text = capture_send_text
+
+  # Simulate event replay logic
+  session = await mock_session_service.get_session(
+      app_name="test_app", user_id="test_user", session_id="new_session"
+  )
+
+  if session and session.events:
+    for event in session.events:
+      await mock_websocket.send_text(
+          event.model_dump_json(exclude_none=True, by_alias=True)
+      )
+
+  # Should not send any events for an empty session
+  assert len(replayed_events) == 0
+
+
+@pytest.mark.asyncio
+async def test_live_session_continues_after_replay_failure():
+  """Test that session continues even if one event fails to replay."""
+  # Create events where one might fail to serialize
+  event1 = Event(
+      id="event-1",
+      author="user",
+      content=types.Content(parts=[types.Part(text="First message")]),
+      invocation_id="inv-1",
+  )
+  event2 = Event(
+      id="event-2",
+      author="agent",
+      content=types.Content(parts=[types.Part(text="Second message")]),
+      invocation_id="inv-1",
+  )
+  event3 = Event(
+      id="event-3",
+      author="user",
+      content=types.Content(parts=[types.Part(text="Third message")]),
+      invocation_id="inv-2",
+  )
+
+  mock_session = Session(
+      app_name="test_app",
+      user_id="test_user",
+      id="test_session",
+      state={},
+      events=[event1, event2, event3],
+      last_update_time=1234567890.0,
+  )
+
+  mock_websocket = AsyncMock()
+  replayed_events = []
+  send_call_count = 0
+
+  async def capture_send_text(data):
+    nonlocal send_call_count
+    send_call_count += 1
+    # Simulate failure on second event
+    if send_call_count == 2:
+      raise Exception("Simulated network error")
+    replayed_events.append(data)
+
+  mock_websocket.send_text = capture_send_text
+
+  # Simulate event replay with error handling
+  if mock_session and mock_session.events:
+    for event in mock_session.events:
+      try:
+        await mock_websocket.send_text(
+            event.model_dump_json(exclude_none=True, by_alias=True)
+        )
+      except Exception:
+        # Continue replaying even if one fails
+        continue
+
+  # Should have replayed 2 events successfully (skipped the failing one)
+  assert len(replayed_events) == 2
+  assert send_call_count == 3  # Attempted all 3
+
+
+@pytest.mark.asyncio
+async def test_live_session_skips_replay_for_audio_sessions():
+  """Test that live sessions with audio/transcription skip replay to avoid duplicates.
+
+  This tests the fix for Issue #3395 where Gemini Live API's transparent session
+  resumption would conflict with our event replay, causing duplicate responses.
+  When a session contains audio or transcription data, we rely on the model's
+  session resumption instead of replaying events.
+  """
+  # Create a session with transcription data (indicating live mode)
+  user_event_with_transcription = Event(
+      id="event-user-1",
+      author="user",
+      content=types.Content(parts=[types.Part(text="Hello")]),
+      invocation_id="inv-1",
+      input_transcription=types.Transcription(
+          text="Hello", finished=True
+      ),
+  )
+  agent_event_with_transcription = Event(
+      id="event-agent-1",
+      author="test_agent",
+      content=types.Content(
+          parts=[types.Part(text="Hello! How can I help you?")]
+      ),
+      invocation_id="inv-1",
+      output_transcription=types.Transcription(
+          text="Hello! How can I help you?", finished=True
+      ),
+  )
+
+  mock_session = Session(
+      app_name="test_app",
+      user_id="test_user",
+      id="test_session",
+      state={},
+      events=[user_event_with_transcription, agent_event_with_transcription],
+      last_update_time=1234567890.0,
+  )
+
+  mock_websocket = AsyncMock()
+  replayed_events = []
+
+  async def capture_send_text(data):
+    replayed_events.append(data)
+
+  mock_websocket.send_text = capture_send_text
+
+  # Helper to detect live sessions (copied from adk_web_server.py logic)
+  def is_live_session(events: list) -> bool:
+    for event in reversed(events[-5:] if len(events) > 5 else events):
+      if hasattr(event, 'input_transcription') and event.input_transcription:
+        return True
+      if hasattr(event, 'output_transcription') and event.output_transcription:
+        return True
+      if event.content:
+        for part in event.content.parts:
+          if part.inline_data and (
+              part.inline_data.mime_type.startswith("audio/")
+              or part.inline_data.mime_type.startswith("video/")
+          ):
+            return True
+          if part.file_data and (
+              part.file_data.mime_type.startswith("audio/")
+              or part.file_data.mime_type.startswith("video/")
+          ):
+            return True
+    return False
+
+  # Test the conditional replay logic
+  session = mock_session
+  should_replay = session.events and not is_live_session(session.events)
+
+  if should_replay:
+    for event in session.events:
+      await mock_websocket.send_text(
+          event.model_dump_json(exclude_none=True, by_alias=True)
+      )
+
+  # Expect no events to be replayed because it's a live/audio session
+  assert len(replayed_events) == 0
+
+
+@pytest.mark.asyncio
+async def test_text_session_replays_events_normally():
+  """Test that text-only sessions still replay events as expected.
+
+  This ensures that the fix for Issue #3395 (skipping replay for live sessions)
+  doesn't break the fix for Issue #3573 (replaying events for text sessions).
+  """
+  # Create a session with only text content (no transcriptions)
+  user_event = Event(
+      id="event-user-1",
+      author="user",
+      content=types.Content(parts=[types.Part(text="Hello")]),
+      invocation_id="inv-1",
+  )
+  agent_event = Event(
+      id="event-agent-1",
+      author="test_agent",
+      content=types.Content(
+          parts=[types.Part(text="Hello! How can I help you?")]
+      ),
+      invocation_id="inv-1",
+  )
+
+  mock_session = Session(
+      app_name="test_app",
+      user_id="test_user",
+      id="test_session",
+      state={},
+      events=[user_event, agent_event],
+      last_update_time=1234567890.0,
+  )
+
+  mock_websocket = AsyncMock()
+  replayed_events = []
+
+  async def capture_send_text(data):
+    replayed_events.append(data)
+
+  mock_websocket.send_text = capture_send_text
+
+  # Helper to detect live sessions
+  def is_live_session(events: list) -> bool:
+    for event in reversed(events[-5:] if len(events) > 5 else events):
+      if hasattr(event, 'input_transcription') and event.input_transcription:
+        return True
+      if hasattr(event, 'output_transcription') and event.output_transcription:
+        return True
+      if event.content:
+        for part in event.content.parts:
+          if part.inline_data and (
+              part.inline_data.mime_type.startswith("audio/")
+              or part.inline_data.mime_type.startswith("video/")
+          ):
+            return True
+          if part.file_data and (
+              part.file_data.mime_type.startswith("audio/")
+              or part.file_data.mime_type.startswith("video/")
+          ):
+            return True
+    return False
+
+  # Test the conditional replay logic
+  session = mock_session
+  should_replay = session.events and not is_live_session(session.events)
+
+  if should_replay:
+    for event in session.events:
+      await mock_websocket.send_text(
+          event.model_dump_json(exclude_none=True, by_alias=True)
+      )
+
+  # Expect both events to be replayed for text-only session
+  assert len(replayed_events) == 2
diff --git a/tests/unittests/flows/llm_flows/test_tool_call_deduplication.py b/tests/unittests/flows/llm_flows/test_tool_call_deduplication.py
new file mode 100644
index 0000000000..fb51581ebb
--- /dev/null
+++ b/tests/unittests/flows/llm_flows/test_tool_call_deduplication.py
@@ -0,0 +1,123 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for tool call de-duplication (Issue #3940)."""
+
+from google.adk.agents.llm_agent import Agent
+from google.adk.agents.run_config import RunConfig
+from google.genai import types
+import pytest
+
+from ... import testing_utils
+
+
+def _function_call(name: str, args: dict) -> types.Part:
+  return types.Part.from_function_call(name=name, args=args)
+
+
+@pytest.mark.asyncio
+async def test_dedupe_identical_tool_calls_across_steps():
+  """Identical tool calls should execute once and reuse the cached result."""
+  responses = [
+      _function_call("test_tool", {"x": 1}),
+      _function_call("test_tool", {"x": 1}),
+      "done",
+  ]
+  mock_model = testing_utils.MockModel.create(responses=responses)
+
+  call_count = 0
+
+  def test_tool(x: int) -> dict:
+    nonlocal call_count
+    call_count += 1
+    return {"result": call_count}
+
+  agent = Agent(name="root_agent", model=mock_model, tools=[test_tool])
+  runner = testing_utils.InMemoryRunner(root_agent=agent)
+
+  run_config = RunConfig(dedupe_tool_calls=True)
+  events = []
+  async for event in runner.runner.run_async(
+      user_id=runner.session.user_id,
+      session_id=runner.session.id,
+      new_message=testing_utils.get_user_content("run"),
+      run_config=run_config,
+  ):
+    events.append(event)
+  simplified = testing_utils.simplify_events(events)
+
+  # Tool should execute exactly once even though the model calls it twice.
+  assert call_count == 1
+
+  # Both tool responses should contain the same cached payload.
+  tool_responses = [
+      content
+      for _, content in simplified
+      if isinstance(content, types.Part) and content.function_response
+  ]
+  assert len(tool_responses) == 2
+  assert tool_responses[0].function_response.response == {"result": 1}
+  assert tool_responses[1].function_response.response == {"result": 1}
+
+
+def test_dedupe_identical_tool_calls_within_one_step():
+  """Identical tool calls within the same step should execute once."""
+  responses = [
+      [
+          _function_call("test_tool", {"x": 1}),
+          _function_call("test_tool", {"x": 1}),
+      ],
+      "done",
+  ]
+  mock_model = testing_utils.MockModel.create(responses=responses)
+
+  call_count = 0
+
+  def test_tool(x: int) -> dict:
+    nonlocal call_count
+    call_count += 1
+    return {"result": call_count}
+
+  agent = Agent(name="root_agent", model=mock_model, tools=[test_tool])
+  runner = testing_utils.InMemoryRunner(root_agent=agent)
+
+  run_config = RunConfig(dedupe_tool_calls=True)
+  events = list(
+      runner.runner.run(
+          user_id=runner.session.user_id,
+          session_id=runner.session.id,
+          new_message=testing_utils.get_user_content("run"),
+          run_config=run_config,
+      )
+  )
+  simplified = testing_utils.simplify_events(events)
+
+  assert call_count == 1
+
+  # The merged tool response event contains 2 function_response parts.
+  merged_parts = [
+      content
+      for _, content in simplified
+      if isinstance(content, list)
+      and all(isinstance(p, types.Part) for p in content)
+      and any(p.function_response for p in content)
+  ]
+  assert len(merged_parts) == 1
+  function_responses = [
+      p.function_response.response
+      for p in merged_parts[0]
+      if p.function_response is not None
+  ]
+  assert function_responses == [{"result": 1}, {"result": 1}]
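For reviewers, a minimal usage sketch (not part of the diff itself): it shows where the new RunConfig(dedupe_tool_calls=True) flag and the 'adk_tool_call_cache_hit' marker surface in ordinary Runner usage. The agent, tool, model name, and app/user ids below are illustrative placeholders; only the RunConfig flag and the custom_metadata key come from this change.

# Illustrative sketch only. The agent, tool, and model name are placeholders;
# the pieces introduced by this change are RunConfig(dedupe_tool_calls=True)
# and the 'adk_tool_call_cache_hit' custom_metadata marker.
import asyncio

from google.adk.agents import Agent
from google.adk.agents.run_config import RunConfig
from google.adk.runners import InMemoryRunner
from google.genai import types


def lookup_order(order_id: str) -> dict:
  """A slow tool that the model may mistakenly call twice with the same args."""
  return {"order_id": order_id, "status": "shipped"}


async def main() -> None:
  agent = Agent(
      name="support_agent", model="gemini-2.0-flash", tools=[lookup_order]
  )
  runner = InMemoryRunner(agent=agent, app_name="demo")
  session = await runner.session_service.create_session(
      app_name="demo", user_id="user"
  )

  # Opt in to per-invocation de-duplication of identical tool calls.
  run_config = RunConfig(dedupe_tool_calls=True)
  async for event in runner.run_async(
      user_id="user",
      session_id=session.id,
      new_message=types.Content(
          role="user", parts=[types.Part(text="Where is order 42?")]
      ),
      run_config=run_config,
  ):
    # Tool responses served from the per-invocation cache are marked by the
    # flow via event.custom_metadata["adk_tool_call_cache_hit"] = True.
    if event.custom_metadata and event.custom_metadata.get(
        "adk_tool_call_cache_hit"
    ):
      print(f"cache hit for event {event.id}")


if __name__ == "__main__":
  asyncio.run(main())

With the flag left at its default (False), the flow calls the tool pipeline unconditionally, so existing behavior is unchanged; long-running tools are de-duplicated regardless of the flag, as encoded in should_dedupe above.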