diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py
index c4ebc0b54..7126644e6 100644
--- a/src/strands/agent/agent.py
+++ b/src/strands/agent/agent.py
@@ -10,6 +10,7 @@
 """
 
 import logging
+import threading
 import warnings
 from typing import (
     TYPE_CHECKING,
@@ -59,7 +60,7 @@
 from ..types._events import AgentResultEvent, EventLoopStopEvent, InitEventLoopEvent, ModelStreamChunkEvent, TypedEvent
 from ..types.agent import AgentInput
 from ..types.content import ContentBlock, Message, Messages, SystemContentBlock
-from ..types.exceptions import ContextWindowOverflowException
+from ..types.exceptions import ConcurrencyException, ContextWindowOverflowException
 from ..types.traces import AttributeValue
 from .agent_result import AgentResult
 from .conversation_manager import (
@@ -245,6 +246,11 @@ def __init__(
 
         self._interrupt_state = _InterruptState()
 
+        # Initialize lock for guarding concurrent invocations
+        # Using threading.Lock instead of asyncio.Lock because run_async() creates
+        # separate event loops in different threads, so asyncio.Lock wouldn't work
+        self._invocation_lock = threading.Lock()
+
         # Initialize session management functionality
         self._session_manager = session_manager
         if self._session_manager:
@@ -554,6 +560,7 @@ async def stream_async(
             - And other event data provided by the callback handler
 
         Raises:
+            ConcurrencyException: If another invocation is already in progress on this agent instance.
             Exception: Any exceptions from the agent invocation will be propagated to the caller.
 
         Example:
@@ -563,50 +570,63 @@ async def stream_async(
                 yield event["data"]
             ```
         """
-        self._interrupt_state.resume(prompt)
+        # Acquire lock to prevent concurrent invocations
+        # Using threading.Lock instead of asyncio.Lock because run_async() creates
+        # separate event loops in different threads
+        acquired = self._invocation_lock.acquire(blocking=False)
+        if not acquired:
+            raise ConcurrencyException(
+                "Agent is already processing a request. Concurrent invocations are not supported."
+            )
 
-        self.event_loop_metrics.reset_usage_metrics()
+        try:
+            self._interrupt_state.resume(prompt)
 
-        merged_state = {}
-        if kwargs:
-            warnings.warn("`**kwargs` parameter is deprecating, use `invocation_state` instead.", stacklevel=2)
-            merged_state.update(kwargs)
-            if invocation_state is not None:
-                merged_state["invocation_state"] = invocation_state
-        else:
-            if invocation_state is not None:
-                merged_state = invocation_state
+            self.event_loop_metrics.reset_usage_metrics()
 
-        callback_handler = self.callback_handler
-        if kwargs:
-            callback_handler = kwargs.get("callback_handler", self.callback_handler)
+            merged_state = {}
+            if kwargs:
+                warnings.warn("`**kwargs` parameter is deprecating, use `invocation_state` instead.", stacklevel=2)
+                merged_state.update(kwargs)
+                if invocation_state is not None:
+                    merged_state["invocation_state"] = invocation_state
+            else:
+                if invocation_state is not None:
+                    merged_state = invocation_state
 
-        # Process input and get message to add (if any)
-        messages = await self._convert_prompt_to_messages(prompt)
+            callback_handler = self.callback_handler
+            if kwargs:
+                callback_handler = kwargs.get("callback_handler", self.callback_handler)
 
-        self.trace_span = self._start_agent_trace_span(messages)
+            # Process input and get message to add (if any)
+            messages = await self._convert_prompt_to_messages(prompt)
 
-        with trace_api.use_span(self.trace_span):
-            try:
-                events = self._run_loop(messages, merged_state, structured_output_model)
+            self.trace_span = self._start_agent_trace_span(messages)
 
-                async for event in events:
-                    event.prepare(invocation_state=merged_state)
+            with trace_api.use_span(self.trace_span):
+                try:
+                    events = self._run_loop(messages, merged_state, structured_output_model)
+
+                    async for event in events:
+                        event.prepare(invocation_state=merged_state)
 
-                    if event.is_callback_event:
-                        as_dict = event.as_dict()
-                        callback_handler(**as_dict)
-                        yield as_dict
+                        if event.is_callback_event:
+                            as_dict = event.as_dict()
+                            callback_handler(**as_dict)
+                            yield as_dict
 
-                result = AgentResult(*event["stop"])
-                callback_handler(result=result)
-                yield AgentResultEvent(result=result).as_dict()
+                    result = AgentResult(*event["stop"])
+                    callback_handler(result=result)
+                    yield AgentResultEvent(result=result).as_dict()
 
-                self._end_agent_trace_span(response=result)
+                    self._end_agent_trace_span(response=result)
 
-            except Exception as e:
-                self._end_agent_trace_span(error=e)
-                raise
+                except Exception as e:
+                    self._end_agent_trace_span(error=e)
+                    raise
+
+        finally:
+            self._invocation_lock.release()
 
     async def _run_loop(
         self,
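# --- Illustrative sketch (editor's addition, not part of the patch) ---
# The comments in the hunk above give the rationale for threading.Lock over
# asyncio.Lock: run_async() can drive each invocation from its own event loop in a
# separate thread, so only a thread-level primitive can see the overlap. Below is a
# minimal, self-contained demonstration of the same non-blocking acquire pattern;
# `Guarded` and `work` are hypothetical names, not SDK APIs.
import threading
import time


class Guarded:
    def __init__(self) -> None:
        self._lock = threading.Lock()

    def work(self, duration: float) -> str:
        # Fail fast instead of queueing the second caller behind the first.
        if not self._lock.acquire(blocking=False):
            raise RuntimeError("already busy")
        try:
            time.sleep(duration)
            return "done"
        finally:
            self._lock.release()


if __name__ == "__main__":
    guarded = Guarded()
    first = threading.Thread(target=guarded.work, args=(0.2,))
    first.start()
    time.sleep(0.05)
    try:
        guarded.work(0.0)
    except RuntimeError as exc:
        print(f"second caller rejected: {exc}")  # fires while the first call is still running
    first.join()
# --- end sketch ---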
diff --git a/src/strands/tools/_caller.py b/src/strands/tools/_caller.py
index 97485d068..bfec5886d 100644
--- a/src/strands/tools/_caller.py
+++ b/src/strands/tools/_caller.py
@@ -15,6 +15,7 @@
 from ..tools.executors._executor import ToolExecutor
 from ..types._events import ToolInterruptEvent
 from ..types.content import ContentBlock, Message
+from ..types.exceptions import ConcurrencyException
 from ..types.tools import ToolResult, ToolUse
 
 if TYPE_CHECKING:
@@ -73,46 +74,64 @@ def caller(
         if self._agent._interrupt_state.activated:
             raise RuntimeError("cannot directly call tool during interrupt")
 
-        normalized_name = self._find_normalized_tool_name(name)
+        if record_direct_tool_call is not None:
+            should_record_direct_tool_call = record_direct_tool_call
+        else:
+            should_record_direct_tool_call = self._agent.record_direct_tool_call
 
-        # Create unique tool ID and set up the tool request
-        tool_id = f"tooluse_{name}_{random.randint(100000000, 999999999)}"
-        tool_use: ToolUse = {
-            "toolUseId": tool_id,
-            "name": normalized_name,
-            "input": kwargs.copy(),
-        }
-        tool_results: list[ToolResult] = []
-        invocation_state = kwargs
+        should_lock = should_record_direct_tool_call
 
-        async def acall() -> ToolResult:
-            async for event in ToolExecutor._stream(self._agent, tool_use, tool_results, invocation_state):
-                if isinstance(event, ToolInterruptEvent):
-                    self._agent._interrupt_state.deactivate()
-                    raise RuntimeError("cannot raise interrupt in direct tool call")
+        from ..agent import Agent  # Locally imported to avoid circular reference
 
-            tool_result = tool_results[0]
+        acquired_lock = (
+            should_lock
+            and isinstance(self._agent, Agent)
+            and self._agent._invocation_lock.acquire_lock(blocking=False)
+        )
+        if should_lock and not acquired_lock:
+            raise ConcurrencyException(
+                "Direct tool call cannot be made while the agent is in the middle of an invocation. "
+                "Set record_direct_tool_call=False to allow direct tool calls during agent invocation."
+            )
 
-            if record_direct_tool_call is not None:
-                should_record_direct_tool_call = record_direct_tool_call
-            else:
-                should_record_direct_tool_call = self._agent.record_direct_tool_call
+        try:
+            normalized_name = self._find_normalized_tool_name(name)
 
-            if should_record_direct_tool_call:
-                # Create a record of this tool execution in the message history
-                await self._record_tool_execution(tool_use, tool_result, user_message_override)
+            # Create unique tool ID and set up the tool request
+            tool_id = f"tooluse_{name}_{random.randint(100000000, 999999999)}"
+            tool_use: ToolUse = {
+                "toolUseId": tool_id,
+                "name": normalized_name,
+                "input": kwargs.copy(),
+            }
+            tool_results: list[ToolResult] = []
+            invocation_state = kwargs
 
-            return tool_result
+            async def acall() -> ToolResult:
+                async for event in ToolExecutor._stream(self._agent, tool_use, tool_results, invocation_state):
+                    if isinstance(event, ToolInterruptEvent):
+                        self._agent._interrupt_state.deactivate()
+                        raise RuntimeError("cannot raise interrupt in direct tool call")
+
+                tool_result = tool_results[0]
+
+                if should_record_direct_tool_call:
+                    # Create a record of this tool execution in the message history
+                    await self._record_tool_execution(tool_use, tool_result, user_message_override)
 
-        tool_result = run_async(acall)
+                return tool_result
 
-        # TODO: https://github.com/strands-agents/sdk-python/issues/1311
-        from ..agent import Agent
+            tool_result = run_async(acall)
 
-        if isinstance(self._agent, Agent):
-            self._agent.conversation_manager.apply_management(self._agent)
+            # TODO: https://github.com/strands-agents/sdk-python/issues/1311
+            if isinstance(self._agent, Agent):
+                self._agent.conversation_manager.apply_management(self._agent)
+
+            return tool_result
 
-        return tool_result
+        finally:
+            if acquired_lock and isinstance(self._agent, Agent):
+                self._agent._invocation_lock.release()
 
     return caller
diff --git a/src/strands/types/exceptions.py b/src/strands/types/exceptions.py
index b9c5bc769..1d1983abd 100644
--- a/src/strands/types/exceptions.py
+++ b/src/strands/types/exceptions.py
@@ -94,3 +94,14 @@ def __init__(self, message: str):
         """
         self.message = message
         super().__init__(message)
+
+
+class ConcurrencyException(Exception):
+    """Exception raised when concurrent invocations are attempted on an agent instance.
+
+    Agent instances maintain internal state that cannot be safely accessed concurrently.
+    This exception is raised when an invocation is attempted while another invocation
+    is already in progress on the same agent instance.
+    """
+
+    pass
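# --- Illustrative sketch (editor's addition, not part of the patch) ---
# With the exception type above, callers that share one Agent across threads can
# catch ConcurrencyException and retry or fall back instead of corrupting the
# agent's conversation state. A rough sketch, assuming an already-constructed
# `agent` (model setup omitted) and using only names introduced in this diff:
import threading

from strands.types.exceptions import ConcurrencyException


def ask(agent, prompt: str) -> None:
    try:
        print(agent(prompt))
    except ConcurrencyException:
        # Another thread is mid-invocation on this Agent instance; retry later
        # or give each thread its own Agent.
        print("agent busy, try again later")


# threads = [threading.Thread(target=ask, args=(agent, p)) for p in ("hi", "hello")]
# If the two calls overlap, one succeeds and the other raises ConcurrencyException.
# --- end sketch ---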
+ """ + + pass diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index f133400a8..865d3d2e6 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -1,17 +1,21 @@ +import asyncio import copy import importlib import json import os import textwrap +import threading +import time import unittest.mock import warnings +from typing import Any, AsyncGenerator from uuid import uuid4 import pytest from pydantic import BaseModel import strands -from strands import Agent +from strands import Agent, ToolContext from strands.agent import AgentResult from strands.agent.conversation_manager.null_conversation_manager import NullConversationManager from strands.agent.conversation_manager.sliding_window_conversation_manager import SlidingWindowConversationManager @@ -24,7 +28,7 @@ from strands.telemetry.tracer import serialize from strands.types._events import EventLoopStopEvent, ModelStreamEvent from strands.types.content import Messages -from strands.types.exceptions import ContextWindowOverflowException, EventLoopException +from strands.types.exceptions import ConcurrencyException, ContextWindowOverflowException, EventLoopException from strands.types.session import Session, SessionAgent, SessionMessage, SessionType from tests.fixtures.mock_session_repository import MockedSessionRepository from tests.fixtures.mocked_model_provider import MockedModelProvider @@ -185,6 +189,15 @@ class User(BaseModel): return User(name="Jane Doe", age=30, email="jane@doe.com") +class SlowMockedModel(MockedModelProvider): + async def stream( + self, messages, tool_specs=None, system_prompt=None, tool_choice=None, **kwargs + ) -> AsyncGenerator[Any, None]: + await asyncio.sleep(0.15) # Add async delay to ensure concurrency + async for event in super().stream(messages, tool_specs, system_prompt, tool_choice, **kwargs): + yield event + + def test_agent__init__tool_loader_format(tool_decorated, tool_module, tool_imported, tool_registry): _ = tool_registry @@ -2182,3 +2195,246 @@ def test_agent_skips_fix_for_valid_conversation(mock_model, agenerator): # Should not have added any toolResult messages # Only the new user message and assistant response should be added assert len(agent.messages) == original_length + 2 + + +# ============================================================================ +# Concurrency Exception Tests +# ============================================================================ + + +def test_agent_concurrent_call_raises_exception(): + """Test that concurrent __call__() calls raise ConcurrencyException.""" + model = SlowMockedModel( + [ + {"role": "assistant", "content": [{"text": "hello"}]}, + {"role": "assistant", "content": [{"text": "world"}]}, + ] + ) + agent = Agent(model=model) + + results = [] + errors = [] + lock = threading.Lock() + + def invoke(): + try: + result = agent("test") + with lock: + results.append(result) + except ConcurrencyException as e: + with lock: + errors.append(e) + + # Create two threads that will try to invoke concurrently + t1 = threading.Thread(target=invoke) + t2 = threading.Thread(target=invoke) + + t1.start() + t2.start() + t1.join() + t2.join() + + # One should succeed, one should raise ConcurrencyException + assert len(results) == 1, f"Expected 1 success, got {len(results)}" + assert len(errors) == 1, f"Expected 1 error, got {len(errors)}" + assert "concurrent" in str(errors[0]).lower() and "invocation" in str(errors[0]).lower() + + +def test_agent_concurrent_structured_output_raises_exception(): + 
"""Test that concurrent structured_output() calls raise ConcurrencyException. + + Note: This test validates that the sync invocation path is protected. + The concurrent __call__() test already validates the core functionality. + """ + model = SlowMockedModel( + [ + {"role": "assistant", "content": [{"text": "response1"}]}, + {"role": "assistant", "content": [{"text": "response2"}]}, + ] + ) + agent = Agent(model=model) + + results = [] + errors = [] + lock = threading.Lock() + + def invoke(): + try: + result = agent("test") + with lock: + results.append(result) + except ConcurrencyException as e: + with lock: + errors.append(e) + + # Create two threads that will try to invoke concurrently + t1 = threading.Thread(target=invoke) + t2 = threading.Thread(target=invoke) + + t1.start() + time.sleep(0.05) # Small delay to ensure first thread acquires lock + t2.start() + t1.join() + t2.join() + + # One should succeed, one should raise ConcurrencyException + assert len(results) == 1, f"Expected 1 success, got {len(results)}" + assert len(errors) == 1, f"Expected 1 error, got {len(errors)}" + assert "concurrent" in str(errors[0]).lower() and "invocation" in str(errors[0]).lower() + + +@pytest.mark.asyncio +async def test_agent_sequential_invocations_work(): + """Test that sequential invocations work correctly after lock is released.""" + model = MockedModelProvider( + [ + {"role": "assistant", "content": [{"text": "response1"}]}, + {"role": "assistant", "content": [{"text": "response2"}]}, + {"role": "assistant", "content": [{"text": "response3"}]}, + ] + ) + agent = Agent(model=model) + + # All sequential calls should succeed + result1 = await agent.invoke_async("test1") + assert result1.message["content"][0]["text"] == "response1" + + result2 = await agent.invoke_async("test2") + assert result2.message["content"][0]["text"] == "response2" + + result3 = await agent.invoke_async("test3") + assert result3.message["content"][0]["text"] == "response3" + + +@pytest.mark.asyncio +async def test_agent_lock_released_on_exception(): + """Test that lock is released when an exception occurs during invocation.""" + + # Create a mock model that raises an explicit error + mock_model = unittest.mock.Mock() + + async def failing_stream(*args, **kwargs): + raise RuntimeError("Simulated model failure") + yield # Make this an async generator + + mock_model.stream = failing_stream + + agent = Agent(model=mock_model) + + # First call will fail due to the simulated error + with pytest.raises(RuntimeError, match="Simulated model failure"): + await agent.invoke_async("test") + + # Lock should be released, so this should not raise ConcurrencyException + # It will still raise RuntimeError, but that's expected + with pytest.raises(RuntimeError, match="Simulated model failure"): + await agent.invoke_async("test") + + +def test_agent_direct_tool_call_during_invocation_raises_exception(tool_decorated): + """Test that direct tool call during agent invocation raises ConcurrencyException.""" + + tool_calls = [] + + @strands.tool + def tool_to_invoke(): + tool_calls.append("tool_to_invoke") + return "called" + + @strands.tool(context=True) + def agent_tool(tool_context: ToolContext) -> str: + tool_context.agent.tool.tool_to_invoke(record_direct_tool_call=True) + return "tool result" + + model = MockedModelProvider( + [ + { + "role": "assistant", + "content": [ + { + "toolUse": { + "toolUseId": "test-123", + "name": "agent_tool", + "input": {}, + } + } + ], + }, + {"role": "assistant", "content": [{"text": "Done"}]}, + ] + ) + agent 
+    agent = Agent(model=model, tools=[agent_tool, tool_to_invoke])
+    agent("Hi")
+
+    # Tool call should have not succeeded
+    assert len(tool_calls) == 0
+
+    assert agent.messages[-2] == {
+        "content": [
+            {
+                "toolResult": {
+                    "content": [
+                        {
+                            "text": "Error: ConcurrencyException - Direct tool call cannot be made while the agent is "
+                            "in the middle of an invocation. Set record_direct_tool_call=False to allow direct tool "
+                            "calls during agent invocation."
+                        }
+                    ],
+                    "status": "error",
+                    "toolUseId": "test-123",
+                }
+            }
+        ],
+        "role": "user",
+    }
+
+
+def test_agent_direct_tool_call_during_invocation_succeeds_with_record_false(tool_decorated):
+    """Test that direct tool call during agent invocation succeeds when record_direct_tool_call=False."""
+    tool_calls = []
+
+    @strands.tool
+    def tool_to_invoke():
+        tool_calls.append("tool_to_invoke")
+        return "called"
+
+    @strands.tool(context=True)
+    def agent_tool(tool_context: ToolContext) -> str:
+        tool_context.agent.tool.tool_to_invoke(record_direct_tool_call=False)
+        return "tool result"
+
+    model = MockedModelProvider(
+        [
+            {
+                "role": "assistant",
+                "content": [
+                    {
+                        "toolUse": {
+                            "toolUseId": "test-123",
+                            "name": "agent_tool",
+                            "input": {},
+                        }
+                    }
+                ],
+            },
+            {"role": "assistant", "content": [{"text": "Done"}]},
+        ]
+    )
+    agent = Agent(model=model, tools=[agent_tool, tool_to_invoke])
+    agent("Hi")
+
+    # Tool call should have succeeded
+    assert len(tool_calls) == 1
+
+    assert agent.messages[-2] == {
+        "content": [
+            {
+                "toolResult": {
+                    "content": [{"text": "tool result"}],
+                    "status": "success",
+                    "toolUseId": "test-123",
+                }
+            }
+        ],
+        "role": "user",
+    }
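# --- Illustrative sketch (editor's addition, not part of the patch) ---
# The last two tests pin down the behavior for tools that call back into their own
# agent: a direct tool call that records itself into the message history
# (record_direct_tool_call=True) is rejected mid-invocation, while passing
# record_direct_tool_call=False skips both the recording and the lock. A sketch of
# a tool written with that in mind; `lookup` is a hypothetical tool name:
import strands
from strands import ToolContext


@strands.tool(context=True)
def summarize(tool_context: ToolContext) -> str:
    # Safe while the agent is mid-invocation: no history recording, no lock taken.
    raw = tool_context.agent.tool.lookup(record_direct_tool_call=False)
    return f"summary of {raw}"
# --- end sketch ---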