From d1c3eebbe08eb39bc7eea47a4ee32f6fd7d77308 Mon Sep 17 00:00:00 2001 From: moritalous Date: Sun, 18 Jan 2026 05:26:20 +0000 Subject: [PATCH 1/2] feat: add PromptCachingHook for automatic Bedrock prompt caching management - Implement automatic cache point management for Amazon Bedrock - Add cache points before model invocation (BeforeModelCallEvent) - Remove cache points after invocation (AfterModelCallEvent) - Include comprehensive error handling and logging - Add 16 unit tests with full coverage Closes #1508 --- src/strands/hooks/bedrock.py | 164 +++++++++++++++ tests/strands/hooks/test_bedrock.py | 298 ++++++++++++++++++++++++++++ 2 files changed, 462 insertions(+) create mode 100644 src/strands/hooks/bedrock.py create mode 100644 tests/strands/hooks/test_bedrock.py diff --git a/src/strands/hooks/bedrock.py b/src/strands/hooks/bedrock.py new file mode 100644 index 000000000..5a0a3671a --- /dev/null +++ b/src/strands/hooks/bedrock.py @@ -0,0 +1,164 @@ +"""Bedrock-specific hooks for AWS Bedrock features. + +This module provides hook implementations for AWS Bedrock-specific functionality, +such as automatic prompt caching management. +""" + +import logging +from typing import Any + +from . import HookProvider, HookRegistry +from .events import AfterModelCallEvent, BeforeModelCallEvent + +logger = logging.getLogger(__name__) + +# Cache point object for Bedrock prompt caching +# See: https://docs.aws.amazon.com/bedrock/latest/userguide/prompt-caching.html +CACHE_POINT_ITEM: dict[str, Any] = {"cachePoint": {"type": "default"}} + + +class PromptCachingHook(HookProvider): + """Hook provider for automatic Bedrock prompt caching management. + + This hook automatically manages cache points for AWS Bedrock's prompt caching feature. + It adds a cache point to the last message before model invocation and removes it + after the invocation completes. + + AWS Bedrock supports up to 4 cache points per request. This hook adds one cache point + to enable the "Simplified Cache Management" feature for Claude models, which automatically + checks for cache hits at content block boundaries (looking back approximately 20 content + blocks from the cache checkpoint). + + Important Considerations: + - This hook adds a cache point to the last message's content array + - Bedrock has a maximum of 4 cache points per request + - Claude models require minimum token counts (e.g., 1,024 for Claude 3.7 Sonnet) + - Cache TTL is 5 minutes from the last access + + Example: + ```python + from strands import Agent + from strands.models import BedrockModel + from strands.hooks.bedrock import PromptCachingHook + + model = BedrockModel(model_id="us.anthropic.claude-sonnet-4-20250514-v1:0") + agent = Agent( + model=model, + hooks=[PromptCachingHook()] + ) + ``` + + See Also: + - AWS Bedrock Prompt Caching: https://docs.aws.amazon.com/bedrock/latest/userguide/prompt-caching.html + - Strands Agents Hooks: https://strandsagents.com/latest/documentation/docs/user-guide/concepts/agents/hooks/ + """ + + def register_hooks(self, registry: HookRegistry) -> None: + """Register hook callbacks with the registry. + + Args: + registry: The hook registry to register callbacks with. + """ + registry.add_callback(BeforeModelCallEvent, self.on_invocation_start) + registry.add_callback(AfterModelCallEvent, self.on_invocation_end) + + def on_invocation_start(self, event: BeforeModelCallEvent) -> None: + """Add cache point before model invocation. + + This callback is triggered before the model is invoked. It adds a cache point + to the last message's content array to enable prompt caching. + + Args: + event: The before model call event containing the agent and its messages. + + Note: + If the messages list is empty or the last message has no content array, + this method logs a warning and returns without modifying the messages. + """ + messages = event.agent.messages + + # Validate messages structure + if not messages: + logger.warning("Cannot add cache point: messages list is empty") + return + + last_message = messages[-1] + if "content" not in last_message: + logger.warning( + "Cannot add cache point: last message has no content field | role=%s", + last_message.get("role", "unknown"), + ) + return + + content = last_message["content"] + if not isinstance(content, list): + logger.warning( + "Cannot add cache point: content is not a list | type=%s | role=%s", + type(content).__name__, + last_message.get("role", "unknown"), + ) + return + + # Add cache point to the end of the last message's content + content.append(CACHE_POINT_ITEM) + logger.debug( + "Added cache point to message | message_index=%d | role=%s | content_blocks=%d", + len(messages) - 1, + last_message.get("role", "unknown"), + len(content), + ) + + def on_invocation_end(self, event: AfterModelCallEvent) -> None: + """Remove cache point after model invocation. + + This callback is triggered after the model invocation completes. It removes + the cache point that was added in on_invocation_start to keep the message + history clean. + + Args: + event: The after model call event containing the agent and its messages. + + Note: + If the cache point is not found in the last message's content array, + this method logs a warning but does not raise an exception. + """ + messages = event.agent.messages + + # Validate messages structure + if not messages: + logger.warning("Cannot remove cache point: messages list is empty") + return + + last_message = messages[-1] + if "content" not in last_message: + logger.warning( + "Cannot remove cache point: last message has no content field | role=%s", + last_message.get("role", "unknown"), + ) + return + + content = last_message["content"] + if not isinstance(content, list): + logger.warning( + "Cannot remove cache point: content is not a list | type=%s | role=%s", + type(content).__name__, + last_message.get("role", "unknown"), + ) + return + + # Remove cache point from the last message's content + try: + content.remove(CACHE_POINT_ITEM) + logger.debug( + "Removed cache point from message | message_index=%d | role=%s | content_blocks=%d", + len(messages) - 1, + last_message.get("role", "unknown"), + len(content), + ) + except ValueError: + logger.warning( + "Cache point not found in content | message_index=%d | role=%s | content_blocks=%d", + len(messages) - 1, + last_message.get("role", "unknown"), + len(content), + ) diff --git a/tests/strands/hooks/test_bedrock.py b/tests/strands/hooks/test_bedrock.py new file mode 100644 index 000000000..312188e58 --- /dev/null +++ b/tests/strands/hooks/test_bedrock.py @@ -0,0 +1,298 @@ +"""Unit tests for Bedrock-specific hooks.""" + +import unittest.mock +from unittest.mock import Mock + +import pytest + +from strands.hooks import HookRegistry +from strands.hooks.bedrock import CACHE_POINT_ITEM, PromptCachingHook +from strands.hooks.events import AfterModelCallEvent, BeforeModelCallEvent + + +@pytest.fixture +def hook(): + """Create a PromptCachingHook instance.""" + return PromptCachingHook() + + +@pytest.fixture +def mock_agent(): + """Create a mock agent with a messages list.""" + agent = Mock() + agent.messages = [] + return agent + + +@pytest.fixture +def before_event(mock_agent): + """Create a BeforeModelCallEvent with a mock agent.""" + return BeforeModelCallEvent(agent=mock_agent) + + +@pytest.fixture +def after_event(mock_agent): + """Create an AfterModelCallEvent with a mock agent.""" + return AfterModelCallEvent(agent=mock_agent) + + +class TestPromptCachingHookRegistration: + """Test hook registration functionality.""" + + def test_register_hooks(self, hook): + """Test that register_hooks adds callbacks to the registry.""" + registry = HookRegistry() + hook.register_hooks(registry) + + # Verify callbacks are registered + assert BeforeModelCallEvent in registry._registered_callbacks + assert AfterModelCallEvent in registry._registered_callbacks + assert len(registry._registered_callbacks[BeforeModelCallEvent]) == 1 + assert len(registry._registered_callbacks[AfterModelCallEvent]) == 1 + + +class TestPromptCachingHookAddCachePoint: + """Test adding cache points before model invocation.""" + + def test_add_cache_point_success(self, hook, before_event): + """Test successfully adding a cache point to the last message.""" + before_event.agent.messages = [ + {"role": "user", "content": [{"text": "Hello"}]}, + ] + + hook.on_invocation_start(before_event) + + # Verify cache point was added + assert len(before_event.agent.messages[-1]["content"]) == 2 + assert before_event.agent.messages[-1]["content"][-1] == CACHE_POINT_ITEM + + def test_add_cache_point_to_message_with_multiple_content_blocks(self, hook, before_event): + """Test adding cache point to a message with multiple content blocks.""" + before_event.agent.messages = [ + { + "role": "user", + "content": [ + {"text": "First block"}, + {"text": "Second block"}, + {"image": {"format": "png", "source": {"bytes": b"data"}}}, + ], + }, + ] + + hook.on_invocation_start(before_event) + + # Verify cache point was added at the end + assert len(before_event.agent.messages[-1]["content"]) == 4 + assert before_event.agent.messages[-1]["content"][-1] == CACHE_POINT_ITEM + # Verify original content is intact + assert before_event.agent.messages[-1]["content"][0] == {"text": "First block"} + + def test_add_cache_point_with_multiple_messages(self, hook, before_event): + """Test that cache point is added only to the last message.""" + before_event.agent.messages = [ + {"role": "user", "content": [{"text": "First message"}]}, + {"role": "assistant", "content": [{"text": "Response"}]}, + {"role": "user", "content": [{"text": "Second message"}]}, + ] + + hook.on_invocation_start(before_event) + + # Verify cache point was added only to the last message + assert len(before_event.agent.messages[0]["content"]) == 1 + assert len(before_event.agent.messages[1]["content"]) == 1 + assert len(before_event.agent.messages[2]["content"]) == 2 + assert before_event.agent.messages[2]["content"][-1] == CACHE_POINT_ITEM + + def test_add_cache_point_empty_messages_list(self, hook, before_event): + """Test handling of empty messages list.""" + before_event.agent.messages = [] + + with unittest.mock.patch("strands.hooks.bedrock.logger") as mock_logger: + hook.on_invocation_start(before_event) + mock_logger.warning.assert_called_once_with("Cannot add cache point: messages list is empty") + + # Verify no error was raised and messages remain empty + assert before_event.agent.messages == [] + + def test_add_cache_point_message_without_content_field(self, hook, before_event): + """Test handling of message without content field.""" + before_event.agent.messages = [ + {"role": "user"}, # No content field + ] + + with unittest.mock.patch("strands.hooks.bedrock.logger") as mock_logger: + hook.on_invocation_start(before_event) + mock_logger.warning.assert_called_once() + assert "no content field" in mock_logger.warning.call_args[0][0] + + # Verify message was not modified + assert "content" not in before_event.agent.messages[0] + + def test_add_cache_point_content_not_a_list(self, hook, before_event): + """Test handling of content that is not a list.""" + before_event.agent.messages = [ + {"role": "user", "content": "This should be a list"}, + ] + + with unittest.mock.patch("strands.hooks.bedrock.logger") as mock_logger: + hook.on_invocation_start(before_event) + mock_logger.warning.assert_called_once() + assert "content is not a list" in mock_logger.warning.call_args[0][0] + + # Verify content was not modified + assert before_event.agent.messages[0]["content"] == "This should be a list" + + +class TestPromptCachingHookRemoveCachePoint: + """Test removing cache points after model invocation.""" + + def test_remove_cache_point_success(self, hook, after_event): + """Test successfully removing a cache point from the last message.""" + after_event.agent.messages = [ + {"role": "user", "content": [{"text": "Hello"}, CACHE_POINT_ITEM]}, + ] + + hook.on_invocation_end(after_event) + + # Verify cache point was removed + assert len(after_event.agent.messages[-1]["content"]) == 1 + assert after_event.agent.messages[-1]["content"][0] == {"text": "Hello"} + + def test_remove_cache_point_from_message_with_multiple_blocks(self, hook, after_event): + """Test removing cache point from a message with multiple content blocks.""" + after_event.agent.messages = [ + { + "role": "user", + "content": [ + {"text": "First block"}, + {"text": "Second block"}, + CACHE_POINT_ITEM, + ], + }, + ] + + hook.on_invocation_end(after_event) + + # Verify only cache point was removed + assert len(after_event.agent.messages[-1]["content"]) == 2 + assert after_event.agent.messages[-1]["content"][0] == {"text": "First block"} + assert after_event.agent.messages[-1]["content"][1] == {"text": "Second block"} + + def test_remove_cache_point_with_multiple_messages(self, hook, after_event): + """Test that cache point is removed only from the last message.""" + after_event.agent.messages = [ + {"role": "user", "content": [{"text": "First message"}, CACHE_POINT_ITEM]}, + {"role": "assistant", "content": [{"text": "Response"}]}, + {"role": "user", "content": [{"text": "Second message"}, CACHE_POINT_ITEM]}, + ] + + hook.on_invocation_end(after_event) + + # Verify cache point was removed only from the last message + assert len(after_event.agent.messages[0]["content"]) == 2 # Unchanged + assert len(after_event.agent.messages[1]["content"]) == 1 # No cache point + assert len(after_event.agent.messages[2]["content"]) == 1 # Cache point removed + + def test_remove_cache_point_empty_messages_list(self, hook, after_event): + """Test handling of empty messages list.""" + after_event.agent.messages = [] + + with unittest.mock.patch("strands.hooks.bedrock.logger") as mock_logger: + hook.on_invocation_end(after_event) + mock_logger.warning.assert_called_once_with("Cannot remove cache point: messages list is empty") + + # Verify no error was raised + assert after_event.agent.messages == [] + + def test_remove_cache_point_message_without_content_field(self, hook, after_event): + """Test handling of message without content field.""" + after_event.agent.messages = [ + {"role": "user"}, # No content field + ] + + with unittest.mock.patch("strands.hooks.bedrock.logger") as mock_logger: + hook.on_invocation_end(after_event) + mock_logger.warning.assert_called_once() + assert "no content field" in mock_logger.warning.call_args[0][0] + + # Verify message was not modified + assert "content" not in after_event.agent.messages[0] + + def test_remove_cache_point_content_not_a_list(self, hook, after_event): + """Test handling of content that is not a list.""" + after_event.agent.messages = [ + {"role": "user", "content": "This should be a list"}, + ] + + with unittest.mock.patch("strands.hooks.bedrock.logger") as mock_logger: + hook.on_invocation_end(after_event) + mock_logger.warning.assert_called_once() + assert "content is not a list" in mock_logger.warning.call_args[0][0] + + # Verify content was not modified + assert after_event.agent.messages[0]["content"] == "This should be a list" + + def test_remove_cache_point_not_found(self, hook, after_event): + """Test handling when cache point is not found in content.""" + after_event.agent.messages = [ + {"role": "user", "content": [{"text": "Hello"}]}, # No cache point + ] + + with unittest.mock.patch("strands.hooks.bedrock.logger") as mock_logger: + hook.on_invocation_end(after_event) + mock_logger.warning.assert_called_once() + assert "Cache point not found" in mock_logger.warning.call_args[0][0] + + # Verify content was not modified + assert after_event.agent.messages[0]["content"] == [{"text": "Hello"}] + + +class TestPromptCachingHookEndToEnd: + """Test end-to-end scenarios with both add and remove operations.""" + + def test_add_and_remove_cache_point_lifecycle(self, hook, mock_agent): + """Test the full lifecycle of adding and removing a cache point.""" + mock_agent.messages = [ + {"role": "user", "content": [{"text": "Hello"}]}, + ] + + # Add cache point + before_event = BeforeModelCallEvent(agent=mock_agent) + hook.on_invocation_start(before_event) + + # Verify cache point was added + assert len(mock_agent.messages[-1]["content"]) == 2 + assert mock_agent.messages[-1]["content"][-1] == CACHE_POINT_ITEM + + # Remove cache point + after_event = AfterModelCallEvent(agent=mock_agent) + hook.on_invocation_end(after_event) + + # Verify cache point was removed and original content is intact + assert len(mock_agent.messages[-1]["content"]) == 1 + assert mock_agent.messages[-1]["content"][0] == {"text": "Hello"} + + def test_logging_on_successful_operations(self, hook, mock_agent): + """Test that debug logs are generated on successful operations.""" + mock_agent.messages = [ + {"role": "user", "content": [{"text": "Test"}]}, + ] + + with unittest.mock.patch("strands.hooks.bedrock.logger") as mock_logger: + # Add cache point + before_event = BeforeModelCallEvent(agent=mock_agent) + hook.on_invocation_start(before_event) + + # Verify debug log for adding + mock_logger.debug.assert_called_once() + assert "Added cache point" in mock_logger.debug.call_args[0][0] + + mock_logger.reset_mock() + + # Remove cache point + after_event = AfterModelCallEvent(agent=mock_agent) + hook.on_invocation_end(after_event) + + # Verify debug log for removing + mock_logger.debug.assert_called_once() + assert "Removed cache point" in mock_logger.debug.call_args[0][0] From 057f1a10614c6d4132ffb89de04b466c8fea4624 Mon Sep 17 00:00:00 2001 From: moritalous Date: Sun, 18 Jan 2026 05:56:22 +0000 Subject: [PATCH 2/2] fix: simplify bedrock hook by trusting type definitions - Remove redundant error handling for Message.content (guaranteed by TypedDict) - Remove corresponding unit tests for impossible edge cases - Change CACHE_POINT_ITEM type from Any to ContentBlock for type safety --- src/strands/hooks/bedrock.py | 46 +++++------------------- tests/strands/hooks/test_bedrock.py | 56 ----------------------------- 2 files changed, 8 insertions(+), 94 deletions(-) diff --git a/src/strands/hooks/bedrock.py b/src/strands/hooks/bedrock.py index 5a0a3671a..6fab80247 100644 --- a/src/strands/hooks/bedrock.py +++ b/src/strands/hooks/bedrock.py @@ -7,6 +7,7 @@ import logging from typing import Any +from ..types.content import ContentBlock from . import HookProvider, HookRegistry from .events import AfterModelCallEvent, BeforeModelCallEvent @@ -14,7 +15,7 @@ # Cache point object for Bedrock prompt caching # See: https://docs.aws.amazon.com/bedrock/latest/userguide/prompt-caching.html -CACHE_POINT_ITEM: dict[str, Any] = {"cachePoint": {"type": "default"}} +CACHE_POINT_ITEM: ContentBlock = {"cachePoint": {"type": "default"}} class PromptCachingHook(HookProvider): @@ -53,11 +54,12 @@ class PromptCachingHook(HookProvider): - Strands Agents Hooks: https://strandsagents.com/latest/documentation/docs/user-guide/concepts/agents/hooks/ """ - def register_hooks(self, registry: HookRegistry) -> None: + def register_hooks(self, registry: HookRegistry, **kwargs: Any) -> None: """Register hook callbacks with the registry. Args: registry: The hook registry to register callbacks with. + **kwargs: Additional keyword arguments for future extensibility. """ registry.add_callback(BeforeModelCallEvent, self.on_invocation_start) registry.add_callback(AfterModelCallEvent, self.on_invocation_end) @@ -72,34 +74,18 @@ def on_invocation_start(self, event: BeforeModelCallEvent) -> None: event: The before model call event containing the agent and its messages. Note: - If the messages list is empty or the last message has no content array, - this method logs a warning and returns without modifying the messages. + If the messages list is empty, this method logs a warning and returns + without modifying the messages. """ messages = event.agent.messages - # Validate messages structure - if not messages: + if len(messages) == 0: logger.warning("Cannot add cache point: messages list is empty") return last_message = messages[-1] - if "content" not in last_message: - logger.warning( - "Cannot add cache point: last message has no content field | role=%s", - last_message.get("role", "unknown"), - ) - return - content = last_message["content"] - if not isinstance(content, list): - logger.warning( - "Cannot add cache point: content is not a list | type=%s | role=%s", - type(content).__name__, - last_message.get("role", "unknown"), - ) - return - # Add cache point to the end of the last message's content content.append(CACHE_POINT_ITEM) logger.debug( "Added cache point to message | message_index=%d | role=%s | content_blocks=%d", @@ -124,29 +110,13 @@ def on_invocation_end(self, event: AfterModelCallEvent) -> None: """ messages = event.agent.messages - # Validate messages structure - if not messages: + if len(messages) == 0: logger.warning("Cannot remove cache point: messages list is empty") return last_message = messages[-1] - if "content" not in last_message: - logger.warning( - "Cannot remove cache point: last message has no content field | role=%s", - last_message.get("role", "unknown"), - ) - return - content = last_message["content"] - if not isinstance(content, list): - logger.warning( - "Cannot remove cache point: content is not a list | type=%s | role=%s", - type(content).__name__, - last_message.get("role", "unknown"), - ) - return - # Remove cache point from the last message's content try: content.remove(CACHE_POINT_ITEM) logger.debug( diff --git a/tests/strands/hooks/test_bedrock.py b/tests/strands/hooks/test_bedrock.py index 312188e58..1cf16418e 100644 --- a/tests/strands/hooks/test_bedrock.py +++ b/tests/strands/hooks/test_bedrock.py @@ -114,34 +114,6 @@ def test_add_cache_point_empty_messages_list(self, hook, before_event): # Verify no error was raised and messages remain empty assert before_event.agent.messages == [] - def test_add_cache_point_message_without_content_field(self, hook, before_event): - """Test handling of message without content field.""" - before_event.agent.messages = [ - {"role": "user"}, # No content field - ] - - with unittest.mock.patch("strands.hooks.bedrock.logger") as mock_logger: - hook.on_invocation_start(before_event) - mock_logger.warning.assert_called_once() - assert "no content field" in mock_logger.warning.call_args[0][0] - - # Verify message was not modified - assert "content" not in before_event.agent.messages[0] - - def test_add_cache_point_content_not_a_list(self, hook, before_event): - """Test handling of content that is not a list.""" - before_event.agent.messages = [ - {"role": "user", "content": "This should be a list"}, - ] - - with unittest.mock.patch("strands.hooks.bedrock.logger") as mock_logger: - hook.on_invocation_start(before_event) - mock_logger.warning.assert_called_once() - assert "content is not a list" in mock_logger.warning.call_args[0][0] - - # Verify content was not modified - assert before_event.agent.messages[0]["content"] == "This should be a list" - class TestPromptCachingHookRemoveCachePoint: """Test removing cache points after model invocation.""" @@ -204,34 +176,6 @@ def test_remove_cache_point_empty_messages_list(self, hook, after_event): # Verify no error was raised assert after_event.agent.messages == [] - def test_remove_cache_point_message_without_content_field(self, hook, after_event): - """Test handling of message without content field.""" - after_event.agent.messages = [ - {"role": "user"}, # No content field - ] - - with unittest.mock.patch("strands.hooks.bedrock.logger") as mock_logger: - hook.on_invocation_end(after_event) - mock_logger.warning.assert_called_once() - assert "no content field" in mock_logger.warning.call_args[0][0] - - # Verify message was not modified - assert "content" not in after_event.agent.messages[0] - - def test_remove_cache_point_content_not_a_list(self, hook, after_event): - """Test handling of content that is not a list.""" - after_event.agent.messages = [ - {"role": "user", "content": "This should be a list"}, - ] - - with unittest.mock.patch("strands.hooks.bedrock.logger") as mock_logger: - hook.on_invocation_end(after_event) - mock_logger.warning.assert_called_once() - assert "content is not a list" in mock_logger.warning.call_args[0][0] - - # Verify content was not modified - assert after_event.agent.messages[0]["content"] == "This should be a list" - def test_remove_cache_point_not_found(self, hook, after_event): """Test handling when cache point is not found in content.""" after_event.agent.messages = [