Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 16 additions & 10 deletions src/strands/models/bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -414,9 +414,13 @@ def _format_request_message_content(self, content: ContentBlock) -> dict[str, An
if "format" in document:
result["format"] = document["format"]

# Handle source
# Handle source (supports both bytes and s3Location)
if "source" in document:
result["source"] = {"bytes": document["source"]["bytes"]}
source = document["source"]
if "bytes" in source:
result["source"] = {"bytes": source["bytes"]}
elif "s3Location" in source:
result["source"] = {"s3Location": source["s3Location"]}

# Handle optional fields
if "citations" in document and document["citations"] is not None:
Expand All @@ -437,11 +441,12 @@ def _format_request_message_content(self, content: ContentBlock) -> dict[str, An
if "image" in content:
image = content["image"]
source = image["source"]
formatted_source = {}
image_source: dict[str, Any] = {}
if "bytes" in source:
formatted_source = {"bytes": source["bytes"]}
result = {"format": image["format"], "source": formatted_source}
return {"image": result}
image_source = {"bytes": source["bytes"]}
elif "s3Location" in source:
image_source = {"s3Location": source["s3Location"]}
return {"image": {"format": image["format"], "source": image_source}}

# https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ReasoningContentBlock.html
if "reasoningContent" in content:
Expand Down Expand Up @@ -502,11 +507,12 @@ def _format_request_message_content(self, content: ContentBlock) -> dict[str, An
if "video" in content:
video = content["video"]
source = video["source"]
formatted_source = {}
video_source: dict[str, Any] = {}
if "bytes" in source:
formatted_source = {"bytes": source["bytes"]}
result = {"format": video["format"], "source": formatted_source}
return {"video": result}
video_source = {"bytes": source["bytes"]}
elif "s3Location" in source:
video_source = {"s3Location": source["s3Location"]}
return {"video": {"format": video["format"], "source": video_source}}

# https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_CitationsContentBlock.html
if "citationsContent" in content:
Expand Down
25 changes: 22 additions & 3 deletions src/strands/types/media.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,18 +11,33 @@

from .citations import CitationsConfig


class S3Location(TypedDict, total=False):
"""S3 location for media content.

Attributes:
uri: The S3 URI of the content.
bucketOwner: The account ID of the bucket owner.
"""

uri: str
bucketOwner: str


DocumentFormat = Literal["pdf", "csv", "doc", "docx", "xls", "xlsx", "html", "txt", "md"]
"""Supported document formats."""


class DocumentSource(TypedDict):
class DocumentSource(TypedDict, total=False):
"""Contains the content of a document.

Attributes:
bytes: The binary content of the document.
s3Location: The S3 location of the document.
"""

bytes: bytes
s3Location: S3Location


class DocumentContent(TypedDict, total=False):
Expand All @@ -45,14 +60,16 @@ class DocumentContent(TypedDict, total=False):
"""Supported image formats."""


class ImageSource(TypedDict):
class ImageSource(TypedDict, total=False):
"""Contains the content of an image.

Attributes:
bytes: The binary content of the image.
s3Location: The S3 location of the image.
"""

bytes: bytes
s3Location: S3Location


class ImageContent(TypedDict):
Expand All @@ -71,14 +88,16 @@ class ImageContent(TypedDict):
"""Supported video formats."""


class VideoSource(TypedDict):
class VideoSource(TypedDict, total=False):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These are all breaking changes moving from required to optional fields. We'll need to discuss this as a team if this is an acceptable breakage. As of now, I am not inclined to merge this.

"""Contains the content of a video.

Attributes:
bytes: The binary content of the video.
s3Location: The S3 location of the video.
"""

bytes: bytes
s3Location: S3Location


class VideoContent(TypedDict):
Expand Down
73 changes: 73 additions & 0 deletions tests/strands/models/test_bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -2240,3 +2240,76 @@ async def test_format_request_with_guardrail_latest_message(model):
# Latest user message image should also be wrapped
assert "guardContent" in formatted_messages[2]["content"][1]
assert formatted_messages[2]["content"][1]["guardContent"]["image"]["format"] == "png"


def test_format_request_s3_location_document_source(model, model_id):
"""Test that s3Location source is supported for documents when bytes is not present."""
messages = [
{
"role": "user",
"content": [
{
"document": {
"name": "test.pdf",
"format": "pdf",
"source": {"s3Location": {"uri": "s3://bucket/key.pdf"}},
}
},
],
}
]

formatted_request = model._format_request(messages)

document_block = formatted_request["messages"][0]["content"][0]["document"]
expected = {"name": "test.pdf", "format": "pdf", "source": {"s3Location": {"uri": "s3://bucket/key.pdf"}}}
assert document_block == expected


def test_format_request_s3_location_image_source(model, model_id):
"""Test that s3Location source is supported for images when bytes is not present."""
messages = [
{
"role": "user",
"content": [
{
"image": {
"format": "png",
"source": {"s3Location": {"uri": "s3://bucket/image.png"}},
}
},
],
}
]

formatted_request = model._format_request(messages)

image_block = formatted_request["messages"][0]["content"][0]["image"]
expected = {"format": "png", "source": {"s3Location": {"uri": "s3://bucket/image.png"}}}
assert image_block == expected


def test_format_request_s3_location_video_source(model, model_id):
"""Test that s3Location source is supported for videos when bytes is not present."""
messages = [
{
"role": "user",
"content": [
{
"video": {
"format": "mp4",
"source": {"s3Location": {"uri": "s3://bucket/video.mp4", "bucketOwner": "123456789012"}},
}
},
],
}
]

formatted_request = model._format_request(messages)

video_block = formatted_request["messages"][0]["content"][0]["video"]
expected = {
"format": "mp4",
"source": {"s3Location": {"uri": "s3://bucket/video.mp4", "bucketOwner": "123456789012"}},
}
assert video_block == expected
53 changes: 53 additions & 0 deletions tests/strands/tools/mcp/test_mcp_client_tool_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -824,3 +824,56 @@ def short_names_only(tool) -> bool:
# Should only include short tool (name length <= 10)
assert len(result) == 1
assert result[0] is mock_agent_tool1


def test_is_session_active_with_close_future_done():
"""Test that _is_session_active returns False when close_future is done."""
from unittest.mock import Mock

client = MCPClient(transport_callable=lambda: Mock())

# Mock background thread as alive
client._background_thread = Mock()
client._background_thread.is_alive.return_value = True

# Mock close_future as done
client._close_future = Mock()
client._close_future.done.return_value = True

# Should return False because close_future is done
assert client._is_session_active() is False


def test_is_session_active_with_close_future_not_done():
"""Test that _is_session_active returns True when close_future is not done."""
from unittest.mock import Mock

client = MCPClient(transport_callable=lambda: Mock())

# Mock background thread as alive
client._background_thread = Mock()
client._background_thread.is_alive.return_value = True

# Mock close_future as not done
client._close_future = Mock()
client._close_future.done.return_value = False

# Should return True
assert client._is_session_active() is True


def test_is_session_active_with_none_close_future():
"""Test that _is_session_active returns True when close_future is None."""
from unittest.mock import Mock

client = MCPClient(transport_callable=lambda: Mock())

# Mock background thread as alive
client._background_thread = Mock()
client._background_thread.is_alive.return_value = True

# close_future is None
client._close_future = None

# Should return True
assert client._is_session_active() is True