Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
134 changes: 134 additions & 0 deletions src/strands/hooks/bedrock.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
"""Bedrock-specific hooks for AWS Bedrock features.

This module provides hook implementations for AWS Bedrock-specific functionality,
such as automatic prompt caching management.
"""

import copy
import logging
from typing import Any

from ..types.content import ContentBlock
from . import HookProvider, HookRegistry
from .events import AfterModelCallEvent, BeforeModelCallEvent

# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Cache point content block for Bedrock prompt caching.
# PromptCachingHook adds this marker to the last message before a model call and
# looks for an equal block in the last message to remove after the call.
# See: https://docs.aws.amazon.com/bedrock/latest/userguide/prompt-caching.html
CACHE_POINT_ITEM: ContentBlock = {"cachePoint": {"type": "default"}}


class PromptCachingHook(HookProvider):
    """Hook provider for automatic Bedrock prompt caching management.

    This hook automatically manages cache points for AWS Bedrock's prompt caching feature.
    It adds a cache point to the last message before model invocation and removes it
    after the invocation completes.

    AWS Bedrock supports up to 4 cache points per request. This hook adds one cache point
    to enable the "Simplified Cache Management" feature for Claude models, which automatically
    checks for cache hits at content block boundaries (looking back approximately 20 content
    blocks from the cache checkpoint).

    Important Considerations:
        - This hook adds a cache point to the end of the last message's content array
        - Bedrock has a maximum of 4 cache points per request; this hook does not count
          cache points that are already present before adding its own
        - Claude models require minimum token counts (e.g., 1,024 for Claude 3.7 Sonnet)
        - Cache TTL is 5 minutes from the last access
        - The hook keeps no state between callbacks: removal looks for the last cache
          point block in the last message at the time the after-call event fires

    Example:
        ```python
        from strands import Agent
        from strands.models import BedrockModel
        from strands.hooks.bedrock import PromptCachingHook

        model = BedrockModel(model_id="us.anthropic.claude-sonnet-4-20250514-v1:0")
        agent = Agent(
            model=model,
            hooks=[PromptCachingHook()]
        )
        ```

    See Also:
        - AWS Bedrock Prompt Caching: https://docs.aws.amazon.com/bedrock/latest/userguide/prompt-caching.html
        - Strands Agents Hooks: https://strandsagents.com/latest/documentation/docs/user-guide/concepts/agents/hooks/
    """

    def register_hooks(self, registry: HookRegistry, **kwargs: Any) -> None:
        """Register hook callbacks with the registry.

        Args:
            registry: The hook registry to register callbacks with.
            **kwargs: Additional keyword arguments for future extensibility.
        """
        registry.add_callback(BeforeModelCallEvent, self.on_invocation_start)
        registry.add_callback(AfterModelCallEvent, self.on_invocation_end)

    def on_invocation_start(self, event: BeforeModelCallEvent) -> None:
        """Add cache point before model invocation.

        This callback is triggered before the model is invoked. It appends a cache point
        to the end of the last message's content array to enable prompt caching.

        Args:
            event: The before model call event containing the agent and its messages.

        Note:
            If the messages list is empty, this method logs a warning and returns
            without modifying the messages.
        """
        messages = event.agent.messages

        if not messages:
            logger.warning("Cannot add cache point: messages list is empty")
            return

        last_message = messages[-1]
        content = last_message["content"]

        # Append a private copy rather than the module-level constant itself, so that
        # in-place changes to message content can never alter CACHE_POINT_ITEM or leak
        # between messages. The copy still compares equal to the constant.
        content.append(copy.deepcopy(CACHE_POINT_ITEM))
        logger.debug(
            "Added cache point to message | message_index=%d | role=%s | content_blocks=%d",
            len(messages) - 1,
            last_message.get("role", "unknown"),
            len(content),
        )

    def on_invocation_end(self, event: AfterModelCallEvent) -> None:
        """Remove cache point after model invocation.

        This callback is triggered after the model invocation completes. It removes
        the cache point that was added in on_invocation_start to keep the message
        history clean.

        Args:
            event: The after model call event containing the agent and its messages.

        Note:
            If the cache point is not found in the last message's content array,
            this method logs a warning but does not raise an exception.
        """
        messages = event.agent.messages

        if not messages:
            logger.warning("Cannot remove cache point: messages list is empty")
            return

        last_message = messages[-1]
        content = last_message["content"]

        # Search from the end: on_invocation_start appends its marker last, so the final
        # matching block is the one this hook added. Removing the first match instead
        # would drop a cache point the caller placed earlier in the same message and
        # leave the hook's own marker behind.
        for index in range(len(content) - 1, -1, -1):
            if content[index] == CACHE_POINT_ITEM:
                del content[index]
                logger.debug(
                    "Removed cache point from message | message_index=%d | role=%s | content_blocks=%d",
                    len(messages) - 1,
                    last_message.get("role", "unknown"),
                    len(content),
                )
                return

        logger.warning(
            "Cache point not found in content | message_index=%d | role=%s | content_blocks=%d",
            len(messages) - 1,
            last_message.get("role", "unknown"),
            len(content),
        )
242 changes: 242 additions & 0 deletions tests/strands/hooks/test_bedrock.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,242 @@
"""Unit tests for Bedrock-specific hooks."""

import unittest.mock
from unittest.mock import Mock

import pytest

from strands.hooks import HookRegistry
from strands.hooks.bedrock import CACHE_POINT_ITEM, PromptCachingHook
from strands.hooks.events import AfterModelCallEvent, BeforeModelCallEvent


@pytest.fixture
def hook():
    """Provide a fresh PromptCachingHook for each test."""
    caching_hook = PromptCachingHook()
    return caching_hook


@pytest.fixture
def mock_agent():
    """Provide a Mock standing in for an agent, starting with an empty messages list."""
    return Mock(messages=[])


@pytest.fixture
def before_event(mock_agent):
    """Provide a BeforeModelCallEvent bound to the mock agent."""
    event = BeforeModelCallEvent(agent=mock_agent)
    return event


@pytest.fixture
def after_event(mock_agent):
    """Provide an AfterModelCallEvent bound to the mock agent."""
    event = AfterModelCallEvent(agent=mock_agent)
    return event


class TestPromptCachingHookRegistration:
    """Registration of the hook's callbacks with a HookRegistry."""

    def test_register_hooks(self, hook):
        """register_hooks registers exactly one callback for each model-call event type."""
        registry = HookRegistry()
        hook.register_hooks(registry)

        # Inspect the registry's internal callback table directly.
        callbacks = registry._registered_callbacks
        for event_type in (BeforeModelCallEvent, AfterModelCallEvent):
            assert event_type in callbacks
            assert len(callbacks[event_type]) == 1


class TestPromptCachingHookAddCachePoint:
    """Behavior of on_invocation_start: adding a cache point before the model call."""

    def test_add_cache_point_success(self, hook, before_event):
        """A single-block message gains a trailing cache point."""
        before_event.agent.messages = [{"role": "user", "content": [{"text": "Hello"}]}]

        hook.on_invocation_start(before_event)

        content = before_event.agent.messages[-1]["content"]
        assert len(content) == 2
        assert content[-1] == CACHE_POINT_ITEM

    def test_add_cache_point_to_message_with_multiple_content_blocks(self, hook, before_event):
        """A multi-block message gains a trailing cache point and keeps its original blocks."""
        original_blocks = [
            {"text": "First block"},
            {"text": "Second block"},
            {"image": {"format": "png", "source": {"bytes": b"data"}}},
        ]
        before_event.agent.messages = [{"role": "user", "content": original_blocks}]

        hook.on_invocation_start(before_event)

        content = before_event.agent.messages[-1]["content"]
        # The cache point goes at the end ...
        assert len(content) == 4
        assert content[-1] == CACHE_POINT_ITEM
        # ... and the leading block is untouched.
        assert content[0] == {"text": "First block"}

    def test_add_cache_point_with_multiple_messages(self, hook, before_event):
        """Only the last message of a conversation receives the cache point."""
        before_event.agent.messages = [
            {"role": "user", "content": [{"text": "First message"}]},
            {"role": "assistant", "content": [{"text": "Response"}]},
            {"role": "user", "content": [{"text": "Second message"}]},
        ]

        hook.on_invocation_start(before_event)

        block_counts = [len(message["content"]) for message in before_event.agent.messages]
        assert block_counts == [1, 1, 2]
        assert before_event.agent.messages[2]["content"][-1] == CACHE_POINT_ITEM

    def test_add_cache_point_empty_messages_list(self, hook, before_event):
        """An empty messages list produces a warning and is left unchanged."""
        before_event.agent.messages = []

        with unittest.mock.patch("strands.hooks.bedrock.logger") as mock_logger:
            hook.on_invocation_start(before_event)
            mock_logger.warning.assert_called_once_with("Cannot add cache point: messages list is empty")

        # No exception was raised and nothing was added.
        assert before_event.agent.messages == []


class TestPromptCachingHookRemoveCachePoint:
    """Behavior of on_invocation_end: removing the cache point after the model call."""

    def test_remove_cache_point_success(self, hook, after_event):
        """A trailing cache point is stripped from a single-block message."""
        after_event.agent.messages = [
            {"role": "user", "content": [{"text": "Hello"}, CACHE_POINT_ITEM]},
        ]

        hook.on_invocation_end(after_event)

        content = after_event.agent.messages[-1]["content"]
        assert len(content) == 1
        assert content[0] == {"text": "Hello"}

    def test_remove_cache_point_from_message_with_multiple_blocks(self, hook, after_event):
        """Only the cache point is removed; the other blocks keep their order."""
        after_event.agent.messages = [
            {
                "role": "user",
                "content": [{"text": "First block"}, {"text": "Second block"}, CACHE_POINT_ITEM],
            },
        ]

        hook.on_invocation_end(after_event)

        content = after_event.agent.messages[-1]["content"]
        assert len(content) == 2
        assert content[0] == {"text": "First block"}
        assert content[1] == {"text": "Second block"}

    def test_remove_cache_point_with_multiple_messages(self, hook, after_event):
        """Removal touches the last message only; earlier messages are left as they were."""
        after_event.agent.messages = [
            {"role": "user", "content": [{"text": "First message"}, CACHE_POINT_ITEM]},
            {"role": "assistant", "content": [{"text": "Response"}]},
            {"role": "user", "content": [{"text": "Second message"}, CACHE_POINT_ITEM]},
        ]

        hook.on_invocation_end(after_event)

        first, middle, last = (message["content"] for message in after_event.agent.messages)
        assert len(first) == 2  # Unchanged
        assert len(middle) == 1  # No cache point
        assert len(last) == 1  # Cache point removed

    def test_remove_cache_point_empty_messages_list(self, hook, after_event):
        """An empty messages list produces a warning and is left unchanged."""
        after_event.agent.messages = []

        with unittest.mock.patch("strands.hooks.bedrock.logger") as mock_logger:
            hook.on_invocation_end(after_event)
            mock_logger.warning.assert_called_once_with("Cannot remove cache point: messages list is empty")

        # No exception was raised.
        assert after_event.agent.messages == []

    def test_remove_cache_point_not_found(self, hook, after_event):
        """A last message without a cache point produces a warning and is not modified."""
        after_event.agent.messages = [
            {"role": "user", "content": [{"text": "Hello"}]},  # No cache point
        ]

        with unittest.mock.patch("strands.hooks.bedrock.logger") as mock_logger:
            hook.on_invocation_end(after_event)
            mock_logger.warning.assert_called_once()
            warning_template = mock_logger.warning.call_args[0][0]
            assert "Cache point not found" in warning_template

        # Content is exactly what it was before the call.
        assert after_event.agent.messages[0]["content"] == [{"text": "Hello"}]


class TestPromptCachingHookEndToEnd:
    """Scenarios that run the add and remove callbacks back to back."""

    def test_add_and_remove_cache_point_lifecycle(self, hook, mock_agent):
        """Adding then removing a cache point restores the original message content."""
        mock_agent.messages = [{"role": "user", "content": [{"text": "Hello"}]}]

        # Before the model call: the cache point is appended.
        hook.on_invocation_start(BeforeModelCallEvent(agent=mock_agent))

        content = mock_agent.messages[-1]["content"]
        assert len(content) == 2
        assert content[-1] == CACHE_POINT_ITEM

        # After the model call: the cache point is gone again.
        hook.on_invocation_end(AfterModelCallEvent(agent=mock_agent))

        content = mock_agent.messages[-1]["content"]
        assert len(content) == 1
        assert content[0] == {"text": "Hello"}

    def test_logging_on_successful_operations(self, hook, mock_agent):
        """Each successful add and remove emits exactly one debug log record."""
        mock_agent.messages = [{"role": "user", "content": [{"text": "Test"}]}]

        with unittest.mock.patch("strands.hooks.bedrock.logger") as mock_logger:
            # Add phase.
            hook.on_invocation_start(BeforeModelCallEvent(agent=mock_agent))

            mock_logger.debug.assert_called_once()
            added_template = mock_logger.debug.call_args[0][0]
            assert "Added cache point" in added_template

            # Forget the add-phase calls so the remove phase is counted on its own.
            mock_logger.reset_mock()

            # Remove phase.
            hook.on_invocation_end(AfterModelCallEvent(agent=mock_agent))

            mock_logger.debug.assert_called_once()
            removed_template = mock_logger.debug.call_args[0][0]
            assert "Removed cache point" in removed_template