diff --git a/src/strands/hooks/bedrock.py b/src/strands/hooks/bedrock.py
new file mode 100644
index 000000000..6fab80247
--- /dev/null
+++ b/src/strands/hooks/bedrock.py
@@ -0,0 +1,141 @@
+"""Bedrock-specific hooks for AWS Bedrock features.
+
+This module provides hook implementations for AWS Bedrock-specific functionality,
+such as automatic prompt caching management.
+"""
+
+import logging
+from typing import Any
+
+from ..types.content import ContentBlock
+from . import HookProvider, HookRegistry
+from .events import AfterModelCallEvent, BeforeModelCallEvent
+
+logger = logging.getLogger(__name__)
+
+# Cache point object for Bedrock prompt caching
+# See: https://docs.aws.amazon.com/bedrock/latest/userguide/prompt-caching.html
+CACHE_POINT_ITEM: ContentBlock = {"cachePoint": {"type": "default"}}
+
+
+class PromptCachingHook(HookProvider):
+    """Hook provider for automatic Bedrock prompt caching management.
+
+    This hook automatically manages cache points for AWS Bedrock's prompt caching feature.
+    It adds a cache point to the last message before model invocation and removes it
+    after the invocation completes.
+
+    AWS Bedrock supports up to 4 cache points per request. This hook adds one cache point
+    to enable the "Simplified Cache Management" feature for Claude models, which automatically
+    checks for cache hits at content block boundaries (looking back approximately 20 content
+    blocks from the cache checkpoint).
+
+    Important Considerations:
+        - This hook adds a cache point to the last message's content array
+        - Bedrock has a maximum of 4 cache points per request
+        - Claude models require minimum token counts (e.g., 1,024 for Claude 3.7 Sonnet)
+        - Cache TTL is 5 minutes from the last access
+
+    Example:
+        ```python
+        from strands import Agent
+        from strands.models import BedrockModel
+        from strands.hooks.bedrock import PromptCachingHook
+
+        model = BedrockModel(model_id="us.anthropic.claude-sonnet-4-20250514-v1:0")
+        agent = Agent(
+            model=model,
+            hooks=[PromptCachingHook()]
+        )
+        ```
+
+    See Also:
+        - AWS Bedrock Prompt Caching: https://docs.aws.amazon.com/bedrock/latest/userguide/prompt-caching.html
+        - Strands Agents Hooks: https://strandsagents.com/latest/documentation/docs/user-guide/concepts/agents/hooks/
+    """
+
+    def register_hooks(self, registry: HookRegistry, **kwargs: Any) -> None:
+        """Register hook callbacks with the registry.
+
+        Args:
+            registry: The hook registry to register callbacks with.
+            **kwargs: Additional keyword arguments for future extensibility.
+        """
+        registry.add_callback(BeforeModelCallEvent, self.on_invocation_start)
+        registry.add_callback(AfterModelCallEvent, self.on_invocation_end)
+
+    def on_invocation_start(self, event: BeforeModelCallEvent) -> None:
+        """Add cache point before model invocation.
+
+        This callback is triggered before the model is invoked. It adds a cache point
+        to the last message's content array to enable prompt caching.
+
+        Args:
+            event: The before model call event containing the agent and its messages.
+
+        Note:
+            If the messages list is empty, this method logs a warning and returns
+            without modifying the messages.
+        """
+        messages = event.agent.messages
+
+        if not messages:
+            logger.warning("Cannot add cache point: messages list is empty")
+            return
+
+        last_message = messages[-1]
+        content = last_message["content"]
+
+        content.append(CACHE_POINT_ITEM)
+        logger.debug(
+            "Added cache point to message | message_index=%d | role=%s | content_blocks=%d",
+            len(messages) - 1,
+            last_message.get("role", "unknown"),
+            len(content),
+        )
+
+    def on_invocation_end(self, event: AfterModelCallEvent) -> None:
+        """Remove cache point after model invocation.
+
+        This callback is triggered after the model invocation completes. It removes
+        the cache point that was added in on_invocation_start to keep the message
+        history clean.
+
+        Args:
+            event: The after model call event containing the agent and its messages.
+
+        Note:
+            Only the last matching cache point in the last message's content array is
+            removed, so cache points placed earlier in the message by the caller are
+            preserved. If no cache point is found, this method logs a warning but does
+            not raise an exception.
+        """
+        messages = event.agent.messages
+
+        if not messages:
+            logger.warning("Cannot remove cache point: messages list is empty")
+            return
+
+        last_message = messages[-1]
+        content = last_message["content"]
+
+        # on_invocation_start appends the cache point at the end of the content array, so
+        # search backwards. list.remove() would delete the first equal block instead, which
+        # may be a cache point the caller deliberately placed earlier in the same message.
+        for index in range(len(content) - 1, -1, -1):
+            if content[index] == CACHE_POINT_ITEM:
+                del content[index]
+                logger.debug(
+                    "Removed cache point from message | message_index=%d | role=%s | content_blocks=%d",
+                    len(messages) - 1,
+                    last_message.get("role", "unknown"),
+                    len(content),
+                )
+                return
+
+        logger.warning(
+            "Cache point not found in content | message_index=%d | role=%s | content_blocks=%d",
+            len(messages) - 1,
+            last_message.get("role", "unknown"),
+            len(content),
+        )
diff --git a/tests/strands/hooks/test_bedrock.py b/tests/strands/hooks/test_bedrock.py
new file mode 100644
index 000000000..1cf16418e
--- /dev/null
+++ b/tests/strands/hooks/test_bedrock.py
@@ -0,0 +1,265 @@
+"""Unit tests for Bedrock-specific hooks."""
+
+import unittest.mock
+from unittest.mock import Mock
+
+import pytest
+
+from strands.hooks import HookRegistry
+from strands.hooks.bedrock import CACHE_POINT_ITEM, PromptCachingHook
+from strands.hooks.events import AfterModelCallEvent, BeforeModelCallEvent
+
+
+@pytest.fixture
+def hook():
+    """Create a PromptCachingHook instance."""
+    return PromptCachingHook()
+
+
+@pytest.fixture
+def mock_agent():
+    """Create a mock agent with a messages list."""
+    agent = Mock()
+    agent.messages = []
+    return agent
+
+
+@pytest.fixture
+def before_event(mock_agent):
+    """Create a BeforeModelCallEvent with a mock agent."""
+    return BeforeModelCallEvent(agent=mock_agent)
+
+
+@pytest.fixture
+def after_event(mock_agent):
+    """Create an AfterModelCallEvent with a mock agent."""
+    return AfterModelCallEvent(agent=mock_agent)
+
+
+class TestPromptCachingHookRegistration:
+    """Test hook registration functionality."""
+
+    def test_register_hooks(self, hook):
+        """Test that register_hooks adds callbacks to the registry."""
+        registry = HookRegistry()
+        hook.register_hooks(registry)
+
+        # Verify callbacks are registered
+        assert BeforeModelCallEvent in registry._registered_callbacks
+        assert AfterModelCallEvent in registry._registered_callbacks
+        assert len(registry._registered_callbacks[BeforeModelCallEvent]) == 1
+        assert len(registry._registered_callbacks[AfterModelCallEvent]) == 1
+
+
+class TestPromptCachingHookAddCachePoint:
+    """Test adding cache points before model invocation."""
+
+    def test_add_cache_point_success(self, hook, before_event):
+        """Test successfully adding a cache point to the last message."""
+        before_event.agent.messages = [
+            {"role": "user", "content": [{"text": "Hello"}]},
+        ]
+
+        hook.on_invocation_start(before_event)
+
+        # Verify cache point was added
+        assert len(before_event.agent.messages[-1]["content"]) == 2
+        assert before_event.agent.messages[-1]["content"][-1] == CACHE_POINT_ITEM
+
+    def test_add_cache_point_to_message_with_multiple_content_blocks(self, hook, before_event):
+        """Test adding cache point to a message with multiple content blocks."""
+        before_event.agent.messages = [
+            {
+                "role": "user",
+                "content": [
+                    {"text": "First block"},
+                    {"text": "Second block"},
+                    {"image": {"format": "png", "source": {"bytes": b"data"}}},
+                ],
+            },
+        ]
+
+        hook.on_invocation_start(before_event)
+
+        # Verify cache point was added at the end
+        assert len(before_event.agent.messages[-1]["content"]) == 4
+        assert before_event.agent.messages[-1]["content"][-1] == CACHE_POINT_ITEM
+        # Verify original content is intact
+        assert before_event.agent.messages[-1]["content"][0] == {"text": "First block"}
+
+    def test_add_cache_point_with_multiple_messages(self, hook, before_event):
+        """Test that cache point is added only to the last message."""
+        before_event.agent.messages = [
+            {"role": "user", "content": [{"text": "First message"}]},
+            {"role": "assistant", "content": [{"text": "Response"}]},
+            {"role": "user", "content": [{"text": "Second message"}]},
+        ]
+
+        hook.on_invocation_start(before_event)
+
+        # Verify cache point was added only to the last message
+        assert len(before_event.agent.messages[0]["content"]) == 1
+        assert len(before_event.agent.messages[1]["content"]) == 1
+        assert len(before_event.agent.messages[2]["content"]) == 2
+        assert before_event.agent.messages[2]["content"][-1] == CACHE_POINT_ITEM
+
+    def test_add_cache_point_empty_messages_list(self, hook, before_event):
+        """Test handling of empty messages list."""
+        before_event.agent.messages = []
+
+        with unittest.mock.patch("strands.hooks.bedrock.logger") as mock_logger:
+            hook.on_invocation_start(before_event)
+            mock_logger.warning.assert_called_once_with("Cannot add cache point: messages list is empty")
+
+        # Verify no error was raised and messages remain empty
+        assert before_event.agent.messages == []
+
+
+class TestPromptCachingHookRemoveCachePoint:
+    """Test removing cache points after model invocation."""
+
+    def test_remove_cache_point_success(self, hook, after_event):
+        """Test successfully removing a cache point from the last message."""
+        after_event.agent.messages = [
+            {"role": "user", "content": [{"text": "Hello"}, CACHE_POINT_ITEM]},
+        ]
+
+        hook.on_invocation_end(after_event)
+
+        # Verify cache point was removed
+        assert len(after_event.agent.messages[-1]["content"]) == 1
+        assert after_event.agent.messages[-1]["content"][0] == {"text": "Hello"}
+
+    def test_remove_cache_point_from_message_with_multiple_blocks(self, hook, after_event):
+        """Test removing cache point from a message with multiple content blocks."""
+        after_event.agent.messages = [
+            {
+                "role": "user",
+                "content": [
+                    {"text": "First block"},
+                    {"text": "Second block"},
+                    CACHE_POINT_ITEM,
+                ],
+            },
+        ]
+
+        hook.on_invocation_end(after_event)
+
+        # Verify only cache point was removed
+        assert len(after_event.agent.messages[-1]["content"]) == 2
+        assert after_event.agent.messages[-1]["content"][0] == {"text": "First block"}
+        assert after_event.agent.messages[-1]["content"][1] == {"text": "Second block"}
+
+    def test_remove_cache_point_with_multiple_messages(self, hook, after_event):
+        """Test that cache point is removed only from the last message."""
+        after_event.agent.messages = [
+            {"role": "user", "content": [{"text": "First message"}, CACHE_POINT_ITEM]},
+            {"role": "assistant", "content": [{"text": "Response"}]},
+            {"role": "user", "content": [{"text": "Second message"}, CACHE_POINT_ITEM]},
+        ]
+
+        hook.on_invocation_end(after_event)
+
+        # Verify cache point was removed only from the last message
+        assert len(after_event.agent.messages[0]["content"]) == 2  # Unchanged
+        assert len(after_event.agent.messages[1]["content"]) == 1  # No cache point
+        assert len(after_event.agent.messages[2]["content"]) == 1  # Cache point removed
+
+    def test_remove_cache_point_preserves_earlier_cache_point(self, hook, after_event):
+        """Test that only the trailing cache point is removed from the last message."""
+        after_event.agent.messages = [
+            {
+                "role": "user",
+                "content": [
+                    {"text": "Cached prefix"},
+                    CACHE_POINT_ITEM,
+                    {"text": "Question"},
+                    CACHE_POINT_ITEM,
+                ],
+            },
+        ]
+
+        hook.on_invocation_end(after_event)
+
+        # Verify the earlier cache point stays in place and only the trailing one is removed
+        assert after_event.agent.messages[-1]["content"] == [
+            {"text": "Cached prefix"},
+            CACHE_POINT_ITEM,
+            {"text": "Question"},
+        ]
+
+    def test_remove_cache_point_empty_messages_list(self, hook, after_event):
+        """Test handling of empty messages list."""
+        after_event.agent.messages = []
+
+        with unittest.mock.patch("strands.hooks.bedrock.logger") as mock_logger:
+            hook.on_invocation_end(after_event)
+            mock_logger.warning.assert_called_once_with("Cannot remove cache point: messages list is empty")
+
+        # Verify no error was raised
+        assert after_event.agent.messages == []
+
+    def test_remove_cache_point_not_found(self, hook, after_event):
+        """Test handling when cache point is not found in content."""
+        after_event.agent.messages = [
+            {"role": "user", "content": [{"text": "Hello"}]},  # No cache point
+        ]
+
+        with unittest.mock.patch("strands.hooks.bedrock.logger") as mock_logger:
+            hook.on_invocation_end(after_event)
+            mock_logger.warning.assert_called_once()
+            assert "Cache point not found" in mock_logger.warning.call_args[0][0]
+
+        # Verify content was not modified
+        assert after_event.agent.messages[0]["content"] == [{"text": "Hello"}]
+
+
+class TestPromptCachingHookEndToEnd:
+    """Test end-to-end scenarios with both add and remove operations."""
+
+    def test_add_and_remove_cache_point_lifecycle(self, hook, mock_agent):
+        """Test the full lifecycle of adding and removing a cache point."""
+        mock_agent.messages = [
+            {"role": "user", "content": [{"text": "Hello"}]},
+        ]
+
+        # Add cache point
+        before_event = BeforeModelCallEvent(agent=mock_agent)
+        hook.on_invocation_start(before_event)
+
+        # Verify cache point was added
+        assert len(mock_agent.messages[-1]["content"]) == 2
+        assert mock_agent.messages[-1]["content"][-1] == CACHE_POINT_ITEM
+
+        # Remove cache point
+        after_event = AfterModelCallEvent(agent=mock_agent)
+        hook.on_invocation_end(after_event)
+
+        # Verify cache point was removed and original content is intact
+        assert len(mock_agent.messages[-1]["content"]) == 1
+        assert mock_agent.messages[-1]["content"][0] == {"text": "Hello"}
+
+    def test_logging_on_successful_operations(self, hook, mock_agent):
+        """Test that debug logs are generated on successful operations."""
+        mock_agent.messages = [
+            {"role": "user", "content": [{"text": "Test"}]},
+        ]
+
+        with unittest.mock.patch("strands.hooks.bedrock.logger") as mock_logger:
+            # Add cache point
+            before_event = BeforeModelCallEvent(agent=mock_agent)
+            hook.on_invocation_start(before_event)
+
+            # Verify debug log for adding
+            mock_logger.debug.assert_called_once()
+            assert "Added cache point" in mock_logger.debug.call_args[0][0]
+
+            mock_logger.reset_mock()
+
+            # Remove cache point
+            after_event = AfterModelCallEvent(agent=mock_agent)
+            hook.on_invocation_end(after_event)
+
+            # Verify debug log for removing
+            mock_logger.debug.assert_called_once()
+            assert "Removed cache point" in mock_logger.debug.call_args[0][0]