diff --git a/src/bedrock_agentcore/memory/integrations/strands/bedrock_converter.py b/src/bedrock_agentcore/memory/integrations/strands/bedrock_converter.py index 1f0905c..b9c06c4 100644 --- a/src/bedrock_agentcore/memory/integrations/strands/bedrock_converter.py +++ b/src/bedrock_agentcore/memory/integrations/strands/bedrock_converter.py @@ -72,7 +72,7 @@ def events_to_messages(events: list[dict[str, Any]]) -> list[SessionMessage]: list[SessionMessage]: list of SessionMessage objects. """ messages = [] - for event in events: + for event in reversed(events): for payload_item in event.get("payload", []): if "conversational" in payload_item: conv = payload_item["conversational"] @@ -93,7 +93,7 @@ def events_to_messages(events: list[dict[str, Any]]) -> list[SessionMessage]: logger.error("This is not a SessionMessage but just a blob message. Ignoring") except (json.JSONDecodeError, ValueError): logger.error("Failed to parse blob content: %s", payload_item) - return list(reversed(messages)) + return messages @staticmethod def total_length(message: tuple[str, str]) -> int: diff --git a/src/bedrock_agentcore/memory/integrations/strands/config.py b/src/bedrock_agentcore/memory/integrations/strands/config.py index d2d5cef..7017568 100644 --- a/src/bedrock_agentcore/memory/integrations/strands/config.py +++ b/src/bedrock_agentcore/memory/integrations/strands/config.py @@ -29,9 +29,12 @@ class AgentCoreMemoryConfig(BaseModel): session_id: Required unique ID for the session actor_id: Required unique ID for the agent instance/user retrieval_config: Optional dictionary mapping namespaces to retrieval configurations + batch_size: Number of messages to batch before sending to AgentCore Memory. + Default of 1 means immediate sending (no batching). Max 100. """ memory_id: str = Field(min_length=1) session_id: str = Field(min_length=1) actor_id: str = Field(min_length=1) retrieval_config: Optional[Dict[str, RetrievalConfig]] = None + batch_size: int = Field(default=1, ge=1, le=100) diff --git a/src/bedrock_agentcore/memory/integrations/strands/session_manager.py b/src/bedrock_agentcore/memory/integrations/strands/session_manager.py index 09e886d..288d00c 100644 --- a/src/bedrock_agentcore/memory/integrations/strands/session_manager.py +++ b/src/bedrock_agentcore/memory/integrations/strands/session_manager.py @@ -119,6 +119,10 @@ def __init__( session = boto_session or boto3.Session(region_name=region_name) self.has_existing_agent = False + # Batching support - stores pre-processed messages: (session_id, messages, is_blob, timestamp) + self._message_buffer: list[tuple[str, list[tuple[str, str]], bool, datetime]] = [] + self._buffer_lock = threading.Lock() + # Add strands-agents to the request user agent if boto_client_config: existing_user_agent = getattr(boto_client_config, "user_agent_extra", None) @@ -380,6 +384,9 @@ def create_message( ) -> Optional[dict[str, Any]]: """Create a new message in AgentCore Memory. + If batch_size > 1, the message is buffered and sent when the buffer reaches batch_size. + Use _flush_messages() or close() to send any remaining buffered messages. + Args: session_id (str): The session ID to create the message in. agent_id (str): The agent ID associated with the message (only here for the interface. @@ -389,6 +396,7 @@ def create_message( Returns: Optional[dict[str, Any]]: The created event data from AgentCore Memory. + Returns empty dict if message is buffered (batch_size > 1). Raises: SessionException: If session ID doesn't match configuration or message creation fails. 
@@ -409,16 +417,33 @@ def create_message( if session_id != self.config.session_id: raise SessionException(f"Session ID mismatch: expected {self.config.session_id}, got {session_id}") - try: - messages = AgentCoreMemoryConverter.message_to_payload(session_message) - if not messages: - return + # Convert and check size ONCE (not again at flush) + messages = AgentCoreMemoryConverter.message_to_payload(session_message) + if not messages: + return None + + is_blob = AgentCoreMemoryConverter.exceeds_conversational_limit(messages[0]) + + # Parse the original timestamp and use it as desired timestamp + original_timestamp = datetime.fromisoformat(session_message.created_at.replace("Z", "+00:00")) + monotonic_timestamp = self._get_monotonic_timestamp(original_timestamp) + + if self.config.batch_size > 1: + # Buffer the pre-processed message + should_flush = False + with self._buffer_lock: + self._message_buffer.append((session_id, messages, is_blob, monotonic_timestamp)) + should_flush = len(self._message_buffer) >= self.config.batch_size + + # Flush outside the lock to prevent deadlock + if should_flush: + self._flush_messages() - # Parse the original timestamp and use it as desired timestamp - original_timestamp = datetime.fromisoformat(session_message.created_at.replace("Z", "+00:00")) - monotonic_timestamp = self._get_monotonic_timestamp(original_timestamp) + return {} # No eventId yet - if not AgentCoreMemoryConverter.exceeds_conversational_limit(messages[0]): + # Immediate send (batch_size == 1) + try: + if not is_blob: event = self.memory_client.create_event( memory_id=self.config.memory_id, actor_id=self.config.actor_id, @@ -645,3 +670,131 @@ def initialize(self, agent: "Agent", **kwargs: Any) -> None: RepositorySessionManager.initialize(self, agent, **kwargs) # endregion RepositorySessionManager overrides + + # region Batching support + + def _flush_messages(self) -> list[dict[str, Any]]: + """Flush all buffered messages to AgentCore Memory. + + Call this method to send any remaining buffered messages when batch_size > 1. + This is automatically called when the buffer reaches batch_size, but should + also be called explicitly when the session is complete (via close() or context manager). + + Messages are batched by session_id - all conversational messages for the same + session are combined into a single create_event() call to reduce API calls. + Blob messages (>9KB) are sent individually as they require a different API path. + + Returns: + list[dict[str, Any]]: List of created event responses from AgentCore Memory. + + Raises: + SessionException: If any message creation fails. On failure, all messages + remain in the buffer to prevent data loss. 
+ """ + with self._buffer_lock: + messages_to_send = list(self._message_buffer) + + if not messages_to_send: + return [] + + # Group conversational messages by session_id, preserve order + # Structure: {session_id: {"messages": [...], "timestamp": latest_timestamp}} + session_groups: dict[str, dict[str, Any]] = {} + blob_messages: list[tuple[str, list[tuple[str, str]], datetime]] = [] + + for session_id, messages, is_blob, monotonic_timestamp in messages_to_send: + if is_blob: + # Blobs cannot be combined - collect them separately + blob_messages.append((session_id, messages, monotonic_timestamp)) + else: + # Group conversational messages by session_id + if session_id not in session_groups: + session_groups[session_id] = {"messages": [], "timestamp": monotonic_timestamp} + # Extend messages list to preserve order (earlier messages first) + session_groups[session_id]["messages"].extend(messages) + # Use the latest timestamp for the combined event + if monotonic_timestamp > session_groups[session_id]["timestamp"]: + session_groups[session_id]["timestamp"] = monotonic_timestamp + + results = [] + try: + # Send one create_event per session_id with combined messages + for session_id, group in session_groups.items(): + event = self.memory_client.create_event( + memory_id=self.config.memory_id, + actor_id=self.config.actor_id, + session_id=session_id, + messages=group["messages"], + event_timestamp=group["timestamp"], + ) + results.append(event) + logger.debug("Flushed batched event for session %s: %s", session_id, event.get("eventId")) + + # Send blob messages individually (they use a different API path) + for session_id, messages, monotonic_timestamp in blob_messages: + event = self.memory_client.gmdp_client.create_event( + memoryId=self.config.memory_id, + actorId=self.config.actor_id, + sessionId=session_id, + payload=[ + {"blob": json.dumps(messages[0])}, + ], + eventTimestamp=monotonic_timestamp, + ) + results.append(event) + logger.debug("Flushed blob event for session %s: %s", session_id, event.get("eventId")) + + # Clear buffer only after ALL messages succeed + with self._buffer_lock: + self._message_buffer.clear() + + except Exception as e: + logger.error("Failed to flush messages to AgentCore Memory for session: %s", e) + raise SessionException(f"Failed to flush messages: {e}") from e + + logger.info("Flushed %d events to AgentCore Memory", len(results)) + return results + + def pending_message_count(self) -> int: + """Return the number of messages pending in the buffer. + + Returns: + int: Number of buffered messages waiting to be sent. + """ + with self._buffer_lock: + return len(self._message_buffer) + + def close(self) -> None: + """Explicitly flush pending messages and close the session manager. + + Call this method when the session is complete to ensure all buffered + messages are sent to AgentCore Memory. Alternatively, use the context + manager protocol (with statement) for automatic cleanup. + """ + self._flush_messages() + + def __enter__(self) -> "AgentCoreMemorySessionManager": + """Enter the context manager. + + Returns: + AgentCoreMemorySessionManager: This session manager instance. + """ + return self + + def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: + """Exit the context manager and flush any pending messages. + + Args: + exc_type: Exception type if an exception occurred. + exc_val: Exception value if an exception occurred. + exc_tb: Exception traceback if an exception occurred. 
+ """ + try: + self._flush_messages() + except Exception as e: + if exc_type is not None: + logger.error("Failed to flush messages during exception handling: %s", e) + else: + raise + + # endregion Batching support diff --git a/tests/bedrock_agentcore/memory/integrations/strands/test_agentcore_memory_session_manager.py b/tests/bedrock_agentcore/memory/integrations/strands/test_agentcore_memory_session_manager.py index 5461da7..554223c 100644 --- a/tests/bedrock_agentcore/memory/integrations/strands/test_agentcore_memory_session_manager.py +++ b/tests/bedrock_agentcore/memory/integrations/strands/test_agentcore_memory_session_manager.py @@ -1,5 +1,6 @@ """Tests for AgentCoreMemorySessionManager.""" +import logging from unittest.mock import Mock, patch import pytest @@ -48,25 +49,45 @@ def mock_memory_client(): return client -@pytest.fixture -def session_manager(agentcore_config, mock_memory_client): - """Create an AgentCoreMemorySessionManager with mocked dependencies.""" +def _create_session_manager(config, mock_memory_client): + """Helper to create a session manager with mocked dependencies.""" with patch( "bedrock_agentcore.memory.integrations.strands.session_manager.MemoryClient", return_value=mock_memory_client + ), patch("boto3.Session") as mock_boto_session, patch( + "strands.session.repository_session_manager.RepositorySessionManager.__init__", return_value=None ): - with patch("boto3.Session") as mock_boto_session: - mock_session = Mock() - mock_session.region_name = "us-west-2" - mock_session.client.return_value = Mock() - mock_boto_session.return_value = mock_session + mock_session = Mock() + mock_session.region_name = "us-west-2" + mock_session.client.return_value = Mock() + mock_boto_session.return_value = mock_session + + manager = AgentCoreMemorySessionManager(config) + manager.session_id = config.session_id + manager.session = Session(session_id=config.session_id, session_type=SessionType.AGENT) + return manager + + +@pytest.fixture +def session_manager(agentcore_config, mock_memory_client): + """Create an AgentCoreMemorySessionManager with mocked dependencies.""" + return _create_session_manager(agentcore_config, mock_memory_client) - with patch( - "strands.session.repository_session_manager.RepositorySessionManager.__init__", return_value=None - ): - manager = AgentCoreMemorySessionManager(agentcore_config) - manager.session_id = agentcore_config.session_id - manager.session = Session(session_id=agentcore_config.session_id, session_type=SessionType.AGENT) - return manager + +@pytest.fixture +def batching_config(): + """Create a config with batch_size > 1.""" + return AgentCoreMemoryConfig( + memory_id="test-memory-123", + session_id="test-session-456", + actor_id="test-actor-789", + batch_size=10, + ) + + +@pytest.fixture +def batching_session_manager(batching_config, mock_memory_client): + """Create a session manager with batching enabled.""" + return _create_session_manager(batching_config, mock_memory_client) @pytest.fixture @@ -229,17 +250,13 @@ def test_read_agent_no_events(self, session_manager, mock_memory_client): assert result is None - @patch("bedrock_agentcore.memory.integrations.strands.session_manager.time.sleep") - def test_read_agent_legacy_migration(self, mock_sleep, session_manager, mock_memory_client): + def test_read_agent_legacy_migration(self, session_manager, mock_memory_client): """Test reading a legacy agent event triggers migration.""" legacy_agent_data = '{"agent_id": "test-agent-123", "state": {}, "conversation_manager_state": {}}' - # New 
approach with metadata is retried 3 times (all return empty) - # Then legacy actor_id approach returns the legacy event + # New approach with metadata returns empty, then legacy approach returns the event mock_memory_client.list_events.side_effect = [ - [], # New approach - attempt 1 - [], # New approach - attempt 2 - [], # New approach - attempt 3 + [], # New approach with metadata - returns empty [{"eventId": "legacy-agent-event-1", "payload": [{"blob": legacy_agent_data}]}], # Legacy approach ] mock_memory_client.gmdp_client.create_event.return_value = {"event": {"eventId": "new-agent-event-1"}} @@ -580,8 +597,8 @@ def test_load_long_term_memories_with_validation_failure(self, mock_memory_clien # Should not call retrieve_memories due to validation failure assert mock_memory_client.retrieve_memories.call_count == 0 - # No memories should be stored - assert "ltm_memories" not in test_agent.state._state + # No memories should be stored (agent.state is unmodified since we mocked the method) + assert test_agent.state.get("ltm_memories") is None def test_retry_with_backoff_success(self, session_manager): """Test retry mechanism with eventual success.""" @@ -1165,3 +1182,702 @@ def test_list_messages_with_limit_calculates_max_results(self, session_manager, mock_memory_client.list_events.assert_called_once() call_kwargs = mock_memory_client.list_events.call_args[1] assert call_kwargs["max_results"] == 550 # limit + offset + + +class TestBatchingConfig: + """Test batch_size configuration validation.""" + + def test_batch_size_default_value(self): + """Test batch_size defaults to 1 (immediate send).""" + config = AgentCoreMemoryConfig( + memory_id="test-memory", + session_id="test-session", + actor_id="test-actor", + ) + assert config.batch_size == 1 + + def test_batch_size_custom_value(self): + """Test batch_size can be set to a custom value.""" + config = AgentCoreMemoryConfig( + memory_id="test-memory", + session_id="test-session", + actor_id="test-actor", + batch_size=10, + ) + assert config.batch_size == 10 + + def test_batch_size_maximum_value(self): + """Test batch_size accepts maximum value of 100.""" + config = AgentCoreMemoryConfig( + memory_id="test-memory", + session_id="test-session", + actor_id="test-actor", + batch_size=100, + ) + assert config.batch_size == 100 + + def test_batch_size_exceeds_maximum_raises_error(self): + """Test batch_size above 100 raises validation error.""" + with pytest.raises(ValueError): + AgentCoreMemoryConfig( + memory_id="test-memory", + session_id="test-session", + actor_id="test-actor", + batch_size=101, + ) + + def test_batch_size_zero_raises_error(self): + """Test batch_size of 0 raises validation error.""" + with pytest.raises(ValueError): + AgentCoreMemoryConfig( + memory_id="test-memory", + session_id="test-session", + actor_id="test-actor", + batch_size=0, + ) + + def test_batch_size_negative_raises_error(self): + """Test negative batch_size raises validation error.""" + with pytest.raises(ValueError): + AgentCoreMemoryConfig( + memory_id="test-memory", + session_id="test-session", + actor_id="test-actor", + batch_size=-1, + ) + + +class TestBatchingBufferManagement: + """Test batching buffer management and pending_message_count.""" + + @pytest.fixture + def batching_config(self): + """Override with batch_size=5 for buffer management tests.""" + return AgentCoreMemoryConfig( + memory_id="test-memory-123", + session_id="test-session-456", + actor_id="test-actor-789", + batch_size=5, + ) + + @pytest.fixture + def batching_session_manager(self, 
batching_config, mock_memory_client): + """Create a session manager with batch_size=5.""" + return _create_session_manager(batching_config, mock_memory_client) + + def test_pending_message_count_empty_buffer(self, batching_session_manager): + """Test pending_message_count returns 0 for empty buffer.""" + assert batching_session_manager.pending_message_count() == 0 + + def test_pending_message_count_with_buffered_messages(self, batching_session_manager, mock_memory_client): + """Test pending_message_count returns correct count.""" + # Add messages to buffer (batch_size=5, so won't auto-flush) + for i in range(3): + message = SessionMessage( + message={"role": "user", "content": [{"text": f"Message {i}"}]}, + message_id=i, + created_at="2024-01-01T12:00:00Z", + ) + batching_session_manager.create_message("test-session-456", "test-agent", message) + + assert batching_session_manager.pending_message_count() == 3 + # Verify no events were sent (still buffered) + mock_memory_client.create_event.assert_not_called() + + def test_buffer_auto_flushes_at_batch_size(self, batching_session_manager, mock_memory_client): + """Test buffer automatically flushes when reaching batch_size.""" + mock_memory_client.create_event.return_value = {"eventId": "event_123"} + + # Add exactly batch_size messages (5) + for i in range(5): + message = SessionMessage( + message={"role": "user", "content": [{"text": f"Message {i}"}]}, + message_id=i, + created_at="2024-01-01T12:00:00Z", + ) + batching_session_manager.create_message("test-session-456", "test-agent", message) + + # Buffer should have been flushed + assert batching_session_manager.pending_message_count() == 0 + # One batched API call for all messages in the same session + assert mock_memory_client.create_event.call_count == 1 + + def test_create_message_returns_empty_dict_when_buffered(self, batching_session_manager): + """Test create_message returns empty dict when message is buffered.""" + message = SessionMessage( + message={"role": "user", "content": [{"text": "Hello"}]}, + message_id=1, + created_at="2024-01-01T12:00:00Z", + ) + + result = batching_session_manager.create_message("test-session-456", "test-agent", message) + + assert result == {} + + +class TestBatchingFlush: + """Test _flush_messages behavior.""" + + def test__flush_messages_empty_buffer(self, batching_session_manager): + """Test _flush_messages with empty buffer returns empty list.""" + results = batching_session_manager._flush_messages() + assert results == [] + + def test__flush_messages_sends_all_buffered(self, batching_session_manager, mock_memory_client): + """Test _flush_messages sends all buffered messages in a single batched call.""" + mock_memory_client.create_event.return_value = {"eventId": "event_123"} + + # Add 3 messages (below batch_size of 10) + for i in range(3): + message = SessionMessage( + message={"role": "user", "content": [{"text": f"Message {i}"}]}, + message_id=i, + created_at="2024-01-01T12:00:00Z", + ) + batching_session_manager.create_message("test-session-456", "test-agent", message) + + assert batching_session_manager.pending_message_count() == 3 + + # Flush manually + results = batching_session_manager._flush_messages() + + # One batched API call for all messages in the same session + assert len(results) == 1 + assert batching_session_manager.pending_message_count() == 0 + assert mock_memory_client.create_event.call_count == 1 + + def test__flush_messages_maintains_order(self, batching_session_manager, mock_memory_client): + """Test _flush_messages maintains 
message order within batched payload.""" + sent_payloads = [] + + def track_create_event(**kwargs): + sent_payloads.append(kwargs.get("messages")) + return {"eventId": f"event_{len(sent_payloads)}"} + + mock_memory_client.create_event.side_effect = track_create_event + + # Add messages with distinct content + for i in range(3): + message = SessionMessage( + message={"role": "user", "content": [{"text": f"Message_{i}"}]}, + message_id=i, + created_at=f"2024-01-01T12:0{i}:00Z", + ) + batching_session_manager.create_message("test-session-456", "test-agent", message) + + batching_session_manager._flush_messages() + + # Should be one batched call with messages in order + assert len(sent_payloads) == 1 + combined_messages = sent_payloads[0] + assert len(combined_messages) == 3 + for i, msg in enumerate(combined_messages): + assert f"Message_{i}" in msg[0] + + def test__flush_messages_clears_buffer(self, batching_session_manager, mock_memory_client): + """Test _flush_messages clears the buffer after sending.""" + mock_memory_client.create_event.return_value = {"eventId": "event_123"} + + message = SessionMessage( + message={"role": "user", "content": [{"text": "Hello"}]}, + message_id=1, + created_at="2024-01-01T12:00:00Z", + ) + batching_session_manager.create_message("test-session-456", "test-agent", message) + + # First flush + batching_session_manager._flush_messages() + assert batching_session_manager.pending_message_count() == 0 + + # Second flush should be no-op + results = batching_session_manager._flush_messages() + assert results == [] + + def test__flush_messages_exception_handling(self, batching_session_manager, mock_memory_client): + """Test _flush_messages raises SessionException on failure.""" + mock_memory_client.create_event.side_effect = Exception("API Error") + + message = SessionMessage( + message={"role": "user", "content": [{"text": "Hello"}]}, + message_id=1, + created_at="2024-01-01T12:00:00Z", + ) + batching_session_manager.create_message("test-session-456", "test-agent", message) + + with pytest.raises(SessionException, match="Failed to flush messages"): + batching_session_manager._flush_messages() + + def test_partial_flush_failure_preserves_all_messages(self, batching_session_manager, mock_memory_client): + """Test that on flush failure, all messages remain in buffer to prevent data loss.""" + mock_memory_client.create_event.side_effect = Exception("API Error") + + # Add multiple messages + for i in range(3): + message = SessionMessage( + message={"role": "user", "content": [{"text": f"Message {i}"}]}, + message_id=i, + created_at=f"2024-01-01T12:0{i}:00Z", + ) + batching_session_manager.create_message("test-session-456", "test-agent", message) + + assert batching_session_manager.pending_message_count() == 3 + + # Flush should fail + with pytest.raises(SessionException): + batching_session_manager._flush_messages() + + # All messages should still be in buffer (not cleared on failure) + assert batching_session_manager.pending_message_count() == 3 + + # Fix the mock and retry - should succeed now + mock_memory_client.create_event.side_effect = None + mock_memory_client.create_event.return_value = {"eventId": "event_123"} + + results = batching_session_manager._flush_messages() + assert len(results) == 1 # One batched call for all messages + assert batching_session_manager.pending_message_count() == 0 + + def test_batching_combines_messages_for_same_session(self, batching_session_manager, mock_memory_client): + """Test that multiple messages for the same session are 
combined into one API call.""" + sent_payloads = [] + + def track_create_event(**kwargs): + sent_payloads.append(kwargs.get("messages")) + return {"eventId": f"event_{len(sent_payloads)}"} + + mock_memory_client.create_event.side_effect = track_create_event + + # Add 5 messages to the same session + for i in range(5): + message = SessionMessage( + message={"role": "user", "content": [{"text": f"Message_{i}"}]}, + message_id=i, + created_at=f"2024-01-01T12:0{i}:00Z", + ) + batching_session_manager.create_message("test-session-456", "test-agent", message) + + batching_session_manager._flush_messages() + + # Should be ONE API call with all 5 messages combined + assert mock_memory_client.create_event.call_count == 1 + assert len(sent_payloads) == 1 + # The combined payload should have all 5 messages + assert len(sent_payloads[0]) == 5 + # Messages should be in order + for i in range(5): + assert f"Message_{i}" in sent_payloads[0][i][0] + + def test_multiple_sessions_grouped_into_separate_api_calls(self, batching_session_manager, mock_memory_client): + """Test that messages to different sessions are grouped into separate API calls. + + Note: In normal usage, create_message enforces session_id == config.session_id, + so all messages go to one session. This test verifies the internal grouping logic + by directly manipulating the buffer. + """ + from datetime import datetime, timezone + + calls_by_session = {} + + def track_create_event(**kwargs): + session_id = kwargs.get("session_id") + messages = kwargs.get("messages") + calls_by_session[session_id] = messages + return {"eventId": f"event_{session_id}"} + + mock_memory_client.create_event.side_effect = track_create_event + + # Directly populate buffer with messages for multiple sessions + # Buffer format: (session_id, messages, is_blob, monotonic_timestamp) + base_time = datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc) + batching_session_manager._message_buffer = [ + ("session-A", [("SessionA_Message_0", "user")], False, base_time), + ("session-A", [("SessionA_Message_1", "user")], False, base_time), + ("session-B", [("SessionB_Message_0", "user")], False, base_time), + ("session-B", [("SessionB_Message_1", "user")], False, base_time), + ("session-B", [("SessionB_Message_2", "user")], False, base_time), + ("session-A", [("SessionA_Message_2", "user")], False, base_time), # Non-consecutive + ] + + batching_session_manager._flush_messages() + + # Should be TWO API calls - one per session + assert mock_memory_client.create_event.call_count == 2 + assert len(calls_by_session) == 2 + + # Session A should have 3 messages combined + assert "session-A" in calls_by_session + assert len(calls_by_session["session-A"]) == 3 + assert calls_by_session["session-A"][0] == ("SessionA_Message_0", "user") + assert calls_by_session["session-A"][1] == ("SessionA_Message_1", "user") + assert calls_by_session["session-A"][2] == ("SessionA_Message_2", "user") + + # Session B should have 3 messages combined + assert "session-B" in calls_by_session + assert len(calls_by_session["session-B"]) == 3 + for i in range(3): + assert calls_by_session["session-B"][i] == (f"SessionB_Message_{i}", "user") + + def test_latest_timestamp_used_for_combined_events(self, batching_session_manager, mock_memory_client): + """Test that the latest timestamp from grouped messages is used for the combined event.""" + captured_timestamps = [] + + def track_create_event(**kwargs): + captured_timestamps.append(kwargs.get("event_timestamp")) + return {"eventId": "event_123"} + + 
mock_memory_client.create_event.side_effect = track_create_event + + # Add messages with different timestamps (out of order) + timestamps = ["2024-01-01T12:05:00Z", "2024-01-01T12:01:00Z", "2024-01-01T12:10:00Z"] + for i, ts in enumerate(timestamps): + message = SessionMessage( + message={"role": "user", "content": [{"text": f"Message_{i}"}]}, + message_id=i, + created_at=ts, + ) + batching_session_manager.create_message("test-session-456", "test-agent", message) + + batching_session_manager._flush_messages() + + # The combined event should use the latest timestamp (12:10:00) + assert len(captured_timestamps) == 1 + # The timestamp should be the latest one (12:10:00) + from datetime import datetime, timezone + + expected_latest = datetime(2024, 1, 1, 12, 10, 0, tzinfo=timezone.utc) + # Account for monotonic timestamp adjustment (may add microseconds) + assert captured_timestamps[0] >= expected_latest + + def test_partial_failure_multiple_sessions_preserves_buffer(self, batching_session_manager, mock_memory_client): + """Test that when one session fails, ALL messages remain in buffer. + + Note: Tests internal grouping logic by directly manipulating buffer. + """ + from datetime import datetime, timezone + + def fail_on_second_session(**kwargs): + session_id = kwargs.get("session_id") + if session_id == "session-B": + raise Exception("API Error for session B") + return {"eventId": f"event_{session_id}"} + + mock_memory_client.create_event.side_effect = fail_on_second_session + + # Directly populate buffer with messages for multiple sessions + base_time = datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc) + batching_session_manager._message_buffer = [ + ("session-A", [("SessionA_Message_0", "user")], False, base_time), + ("session-A", [("SessionA_Message_1", "user")], False, base_time), + ("session-B", [("SessionB_Message_0", "user")], False, base_time), + ("session-B", [("SessionB_Message_1", "user")], False, base_time), + ] + + assert batching_session_manager.pending_message_count() == 4 + + # Flush should fail + with pytest.raises(SessionException, match="Failed to flush messages"): + batching_session_manager._flush_messages() + + # ALL messages should still be in buffer (even session A's which "succeeded") + # This is because buffer is only cleared after ALL succeed + assert batching_session_manager.pending_message_count() == 4 + + def test_blob_messages_sent_individually_not_batched(self, batching_session_manager, mock_memory_client): + """Test that multiple blob messages are sent as individual API calls, not batched.""" + blob_calls = [] + + def track_blob_event(**kwargs): + blob_calls.append(kwargs) + return {"event": {"eventId": f"blob_event_{len(blob_calls)}"}} + + mock_memory_client.gmdp_client.create_event.side_effect = track_blob_event + mock_memory_client.create_event.return_value = {"eventId": "conv_event"} + + # Add multiple blob messages (>9KB each) + for i in range(3): + large_text = f"blob_{i}_" + "x" * 10000 + message = SessionMessage( + message={"role": "user", "content": [{"text": large_text}]}, + message_id=i, + created_at=f"2024-01-01T12:0{i}:00Z", + ) + batching_session_manager.create_message("test-session-456", "test-agent", message) + + batching_session_manager._flush_messages() + + # Each blob should be sent individually (3 separate API calls) + assert mock_memory_client.gmdp_client.create_event.call_count == 3 + assert len(blob_calls) == 3 + + # Verify each blob was sent separately with correct content + for i, call in enumerate(blob_calls): + assert "payload" in 
call + assert "blob" in call["payload"][0] + assert f"blob_{i}_" in call["payload"][0]["blob"] + + def test_mixed_sessions_with_blobs_and_conversational(self, batching_session_manager, mock_memory_client): + """Test complex scenario: multiple sessions with both blob and conversational messages. + + Note: Tests internal grouping logic by directly manipulating buffer. + """ + import json + from datetime import datetime, timezone + + conv_calls = {} + blob_calls = [] + + def track_conv_event(**kwargs): + session_id = kwargs.get("session_id") + conv_calls[session_id] = kwargs.get("messages") + return {"eventId": f"conv_event_{session_id}"} + + def track_blob_event(**kwargs): + blob_calls.append(kwargs) + return {"event": {"eventId": f"blob_event_{len(blob_calls)}"}} + + mock_memory_client.create_event.side_effect = track_conv_event + mock_memory_client.gmdp_client.create_event.side_effect = track_blob_event + + # Directly populate buffer with mixed messages + base_time = datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc) + blob_content = {"role": "user", "content": [{"text": "blob_A_" + "x" * 10000}]} + batching_session_manager._message_buffer = [ + # Session A: 2 conversational messages + ("session-A", [("SessionA_conv_0", "user")], False, base_time), + ("session-A", [("SessionA_conv_1", "user")], False, base_time), + # Session A: 1 blob message + ("session-A", [blob_content], True, base_time), + # Session B: 1 conversational message + ("session-B", [("SessionB_conv_0", "user")], False, base_time), + ] + + batching_session_manager._flush_messages() + + # Should have: + # - 2 conversational API calls (one per session) + # - 1 blob API call + assert mock_memory_client.create_event.call_count == 2 + assert mock_memory_client.gmdp_client.create_event.call_count == 1 + + # Session A conversational messages should be batched together + assert "session-A" in conv_calls + assert len(conv_calls["session-A"]) == 2 + + # Session B conversational message + assert "session-B" in conv_calls + assert len(conv_calls["session-B"]) == 1 + + # Blob sent separately + assert len(blob_calls) == 1 + assert "blob_A_" in blob_calls[0]["payload"][0]["blob"] + + +class TestBatchingBackwardsCompatibility: + """Test batch_size=1 behaves identically to previous implementation.""" + + def test_batch_size_one_sends_immediately(self, session_manager, mock_memory_client): + """Test batch_size=1 (default) sends message immediately.""" + mock_memory_client.create_event.return_value = {"eventId": "event_123"} + + message = SessionMessage( + message={"role": "user", "content": [{"text": "Hello"}]}, + message_id=1, + created_at="2024-01-01T12:00:00Z", + ) + + result = session_manager.create_message("test-session-456", "test-agent-123", message) + + # Should return event immediately + assert result.get("eventId") == "event_123" + # Should have sent immediately + mock_memory_client.create_event.assert_called_once() + # Buffer should be empty + assert session_manager.pending_message_count() == 0 + + def test_batch_size_one_returns_event_id(self, session_manager, mock_memory_client): + """Test batch_size=1 returns the event with eventId.""" + mock_memory_client.create_event.return_value = {"eventId": "unique_event_id"} + + message = SessionMessage( + message={"role": "user", "content": [{"text": "Hello"}]}, + message_id=1, + created_at="2024-01-01T12:00:00Z", + ) + + result = session_manager.create_message("test-session-456", "test-agent-123", message) + + assert "eventId" in result + assert result["eventId"] == "unique_event_id" + + 
+class TestBatchingContextManager: + """Test context manager (__enter__/__exit__) functionality.""" + + def test_context_manager_returns_self(self, batching_session_manager): + """Test __enter__ returns the session manager instance.""" + with batching_session_manager as ctx: + assert ctx is batching_session_manager + + def test_context_manager_flushes_on_exit(self, batching_session_manager, mock_memory_client): + """Test __exit__ flushes pending messages.""" + mock_memory_client.create_event.return_value = {"eventId": "event_123"} + + with batching_session_manager: + message = SessionMessage( + message={"role": "user", "content": [{"text": "Hello"}]}, + message_id=1, + created_at="2024-01-01T12:00:00Z", + ) + batching_session_manager.create_message("test-session-456", "test-agent", message) + + # Should still be buffered + assert batching_session_manager.pending_message_count() == 1 + + # After exiting context, should have flushed + assert batching_session_manager.pending_message_count() == 0 + mock_memory_client.create_event.assert_called_once() + + def test_context_manager_flushes_on_exception(self, batching_session_manager, mock_memory_client): + """Test __exit__ flushes even when exception occurs.""" + mock_memory_client.create_event.return_value = {"eventId": "event_123"} + + try: + with batching_session_manager: + message = SessionMessage( + message={"role": "user", "content": [{"text": "Hello"}]}, + message_id=1, + created_at="2024-01-01T12:00:00Z", + ) + batching_session_manager.create_message("test-session-456", "test-agent", message) + raise ValueError("Test exception") + except ValueError: + pass + + # Should have flushed despite exception + assert batching_session_manager.pending_message_count() == 0 + mock_memory_client.create_event.assert_called_once() + + def test_exit_preserves_original_exception_when_flush_fails(self, batching_session_manager, mock_memory_client, caplog): + """Test __exit__ logs flush failure and preserves the original exception.""" + mock_memory_client.create_event.side_effect = RuntimeError("flush failed") + + with caplog.at_level(logging.ERROR): + with pytest.raises(ValueError, match="original error"): + with batching_session_manager: + message = SessionMessage( + message={"role": "user", "content": [{"text": "Hello"}]}, + message_id=1, + created_at="2024-01-01T12:00:00Z", + ) + batching_session_manager.create_message("test-session-456", "test-agent", message) + raise ValueError("original error") + + assert any( + "Failed to flush messages during exception handling" in record.message and record.levelno == logging.ERROR + for record in caplog.records + ) + + def test_exit_raises_flush_exception_when_no_original_exception(self, batching_session_manager, mock_memory_client, caplog): + """Test __exit__ still raises flush exceptions when no original exception.""" + mock_memory_client.create_event.side_effect = RuntimeError("flush failed") + + with caplog.at_level(logging.ERROR): + with pytest.raises(SessionException, match="flush failed"): + with batching_session_manager: + message = SessionMessage( + message={"role": "user", "content": [{"text": "Hello"}]}, + message_id=1, + created_at="2024-01-01T12:00:00Z", + ) + batching_session_manager.create_message("test-session-456", "test-agent", message) + + assert not any( + "Failed to flush messages during exception handling" in record.message for record in caplog.records + ) + + +class TestBatchingClose: + """Test close() method functionality.""" + + def test_close_flushes_pending_messages(self, 
batching_session_manager, mock_memory_client): + """Test close() flushes all pending messages in a batched call.""" + mock_memory_client.create_event.return_value = {"eventId": "event_123"} + + # Add messages + for i in range(3): + message = SessionMessage( + message={"role": "user", "content": [{"text": f"Message {i}"}]}, + message_id=i, + created_at="2024-01-01T12:00:00Z", + ) + batching_session_manager.create_message("test-session-456", "test-agent", message) + + assert batching_session_manager.pending_message_count() == 3 + + # Close should flush + batching_session_manager.close() + + assert batching_session_manager.pending_message_count() == 0 + # One batched API call for all messages in the same session + assert mock_memory_client.create_event.call_count == 1 + + def test_close_with_empty_buffer(self, batching_session_manager, mock_memory_client): + """Test close() with empty buffer is a no-op.""" + batching_session_manager.close() + + mock_memory_client.create_event.assert_not_called() + assert batching_session_manager.pending_message_count() == 0 + + +class TestBatchingBlobMessages: + """Test batching handles blob messages (exceeding conversational limit) correctly.""" + + def test_blob_message_sent_via_gmdp_client(self, batching_session_manager, mock_memory_client): + """Test large messages (blobs) are sent via gmdp_client.""" + mock_memory_client.gmdp_client.create_event.return_value = {"event": {"eventId": "blob_event_123"}} + + # Create a message that exceeds CONVERSATIONAL_MAX_SIZE (9000) + large_text = "x" * 10000 + message = SessionMessage( + message={"role": "user", "content": [{"text": large_text}]}, + message_id=1, + created_at="2024-01-01T12:00:00Z", + ) + batching_session_manager.create_message("test-session-456", "test-agent", message) + + # Flush and verify blob path was used + batching_session_manager._flush_messages() + + mock_memory_client.gmdp_client.create_event.assert_called_once() + call_kwargs = mock_memory_client.gmdp_client.create_event.call_args.kwargs + assert "payload" in call_kwargs + assert "blob" in call_kwargs["payload"][0] + + def test_mixed_conversational_and_blob_messages(self, batching_session_manager, mock_memory_client): + """Test batching correctly handles mix of conversational and blob messages.""" + mock_memory_client.create_event.return_value = {"eventId": "conv_event"} + mock_memory_client.gmdp_client.create_event.return_value = {"event": {"eventId": "blob_event"}} + + # Add small (conversational) message + small_message = SessionMessage( + message={"role": "user", "content": [{"text": "Small message"}]}, + message_id=1, + created_at="2024-01-01T12:00:00Z", + ) + batching_session_manager.create_message("test-session-456", "test-agent", small_message) + + # Add large (blob) message + large_text = "x" * 10000 + large_message = SessionMessage( + message={"role": "user", "content": [{"text": large_text}]}, + message_id=2, + created_at="2024-01-01T12:01:00Z", + ) + batching_session_manager.create_message("test-session-456", "test-agent", large_message) + + # Flush + batching_session_manager._flush_messages() + + # Verify both paths were used + assert mock_memory_client.create_event.call_count == 1 # Conversational + assert mock_memory_client.gmdp_client.create_event.call_count == 1 # Blob diff --git a/tests/bedrock_agentcore/memory/integrations/strands/test_bedrock_converter.py b/tests/bedrock_agentcore/memory/integrations/strands/test_bedrock_converter.py index e15a457..2107c55 100644 --- 
a/tests/bedrock_agentcore/memory/integrations/strands/test_bedrock_converter.py +++ b/tests/bedrock_agentcore/memory/integrations/strands/test_bedrock_converter.py @@ -8,6 +8,19 @@ from bedrock_agentcore.memory.integrations.strands.bedrock_converter import AgentCoreMemoryConverter +def _make_conversational_event(session_messages): + """Build one event with multiple conversational payloads.""" + payloads = [] + for sm in session_messages: + payloads.append({ + "conversational": { + "content": {"text": json.dumps(sm.to_dict())}, + "role": sm.message["role"].upper(), + } + }) + return {"payload": payloads} + + class TestAgentCoreMemoryConverter: """Test cases for AgentCoreMemoryConverter.""" @@ -221,3 +234,140 @@ def test_message_to_payload_with_bytes_encodes_before_filtering(self): assert isinstance(encoded_bytes, dict) assert encoded_bytes.get("__bytes_encoded__") is True assert "data" in encoded_bytes + + # --- Ordering tests for events_to_messages --- + + def test_events_to_messages_empty_events(self): + """Test that empty input returns empty output.""" + result = AgentCoreMemoryConverter.events_to_messages([]) + assert result == [] + + def test_events_to_messages_multiple_events_chronological_order(self): + """Test two single-payload events in reverse chronological order produce chronological result.""" + msg_first = SessionMessage( + message_id=1, message={"role": "user", "content": [{"text": "First"}]}, created_at="2023-01-01T00:00:00Z" + ) + msg_second = SessionMessage( + message_id=2, + message={"role": "assistant", "content": [{"text": "Second"}]}, + created_at="2023-01-01T00:00:01Z", + ) + + # API returns newest first + event_newer = _make_conversational_event([msg_second]) + event_older = _make_conversational_event([msg_first]) + events = [event_newer, event_older] + + result = AgentCoreMemoryConverter.events_to_messages(events) + + assert len(result) == 2 + assert result[0].message["content"][0]["text"] == "First" + assert result[1].message["content"][0]["text"] == "Second" + + def test_events_to_messages_single_event_multiple_payloads_preserves_order(self): + """Test one event with 3 conversational payloads preserves payload order.""" + msgs = [ + SessionMessage( + message_id=i, message={"role": "user", "content": [{"text": f"msg{i}"}]}, created_at="2023-01-01T00:00:00Z" + ) + for i in range(1, 4) + ] + + event = _make_conversational_event(msgs) + result = AgentCoreMemoryConverter.events_to_messages([event]) + + assert len(result) == 3 + assert result[0].message["content"][0]["text"] == "msg1" + assert result[1].message["content"][0]["text"] == "msg2" + assert result[2].message["content"][0]["text"] == "msg3" + + def test_events_to_messages_multiple_batched_events_ordering(self): + """Test two multi-payload events: event order reversed, intra-event payload order preserved. + + This is the exact scenario that the original reverse-after-flatten bug broke. 
+ """ + msg1 = SessionMessage( + message_id=1, message={"role": "user", "content": [{"text": "msg1"}]}, created_at="2023-01-01T00:00:00Z" + ) + msg2 = SessionMessage( + message_id=2, + message={"role": "assistant", "content": [{"text": "msg2"}]}, + created_at="2023-01-01T00:00:01Z", + ) + msg3 = SessionMessage( + message_id=3, message={"role": "user", "content": [{"text": "msg3"}]}, created_at="2023-01-01T00:00:02Z" + ) + msg4 = SessionMessage( + message_id=4, + message={"role": "assistant", "content": [{"text": "msg4"}]}, + created_at="2023-01-01T00:00:03Z", + ) + + # API returns newest event first + event_newer = _make_conversational_event([msg3, msg4]) + event_older = _make_conversational_event([msg1, msg2]) + events = [event_newer, event_older] + + result = AgentCoreMemoryConverter.events_to_messages(events) + + assert len(result) == 4 + assert result[0].message["content"][0]["text"] == "msg1" + assert result[1].message["content"][0]["text"] == "msg2" + assert result[2].message["content"][0]["text"] == "msg3" + assert result[3].message["content"][0]["text"] == "msg4" + + def test_events_to_messages_mixed_blob_and_conversational_ordering(self): + """Test blob and conversational events in reverse chronological order produce chronological result.""" + msg_first = SessionMessage( + message_id=1, message={"role": "user", "content": [{"text": "First"}]}, created_at="2023-01-01T00:00:00Z" + ) + msg_second = SessionMessage( + message_id=2, + message={"role": "assistant", "content": [{"text": "Second"}]}, + created_at="2023-01-01T00:00:01Z", + ) + + # Newer event uses blob format, older event uses conversational format + blob_data = [json.dumps(msg_second.to_dict()), "assistant"] + event_newer = {"payload": [{"blob": json.dumps(blob_data)}]} + event_older = _make_conversational_event([msg_first]) + events = [event_newer, event_older] + + result = AgentCoreMemoryConverter.events_to_messages(events) + + assert len(result) == 2 + assert result[0].message["content"][0]["text"] == "First" + assert result[1].message["content"][0]["text"] == "Second" + + @patch("bedrock_agentcore.memory.integrations.strands.bedrock_converter.logger") + def test_events_to_messages_malformed_payload_does_not_break_batch(self, mock_logger): + """Test a malformed blob payload between two valid conversational payloads in a single event.""" + msg1 = SessionMessage( + message_id=1, message={"role": "user", "content": [{"text": "msg1"}]}, created_at="2023-01-01T00:00:00Z" + ) + msg3 = SessionMessage( + message_id=3, message={"role": "user", "content": [{"text": "msg3"}]}, created_at="2023-01-01T00:00:02Z" + ) + + conv1 = { + "conversational": { + "content": {"text": json.dumps(msg1.to_dict())}, + "role": "USER", + } + } + bad_blob = {"blob": "invalid json"} + conv3 = { + "conversational": { + "content": {"text": json.dumps(msg3.to_dict())}, + "role": "USER", + } + } + + events = [{"payload": [conv1, bad_blob, conv3]}] + + result = AgentCoreMemoryConverter.events_to_messages(events) + + assert len(result) == 2 + assert result[0].message["content"][0]["text"] == "msg1" + assert result[1].message["content"][0]["text"] == "msg3" + mock_logger.error.assert_called()