Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 54 additions & 34 deletions src/strands/agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
"""

import logging
import threading
import warnings
from typing import (
TYPE_CHECKING,
Expand Down Expand Up @@ -59,7 +60,7 @@
from ..types._events import AgentResultEvent, EventLoopStopEvent, InitEventLoopEvent, ModelStreamChunkEvent, TypedEvent
from ..types.agent import AgentInput
from ..types.content import ContentBlock, Message, Messages, SystemContentBlock
from ..types.exceptions import ContextWindowOverflowException
from ..types.exceptions import ConcurrencyException, ContextWindowOverflowException
from ..types.traces import AttributeValue
from .agent_result import AgentResult
from .conversation_manager import (
Expand Down Expand Up @@ -245,6 +246,11 @@ def __init__(

self._interrupt_state = _InterruptState()

# Initialize lock for guarding concurrent invocations
# Using threading.Lock instead of asyncio.Lock because run_async() creates
# separate event loops in different threads, so asyncio.Lock wouldn't work
self._invocation_lock = threading.Lock()

# Initialize session management functionality
self._session_manager = session_manager
if self._session_manager:
Expand Down Expand Up @@ -554,6 +560,7 @@ async def stream_async(
- And other event data provided by the callback handler

Raises:
ConcurrencyException: If another invocation is already in progress on this agent instance.
Exception: Any exceptions from the agent invocation will be propagated to the caller.

Example:
Expand All @@ -563,50 +570,63 @@ async def stream_async(
yield event["data"]
```
"""
self._interrupt_state.resume(prompt)
# Acquire lock to prevent concurrent invocations
# Using threading.Lock instead of asyncio.Lock because run_async() creates
# separate event loops in different threads
acquired = self._invocation_lock.acquire(blocking=False)
if not acquired:
raise ConcurrencyException(
"Agent is already processing a request. Concurrent invocations are not supported."
)

self.event_loop_metrics.reset_usage_metrics()
try:
self._interrupt_state.resume(prompt)

merged_state = {}
if kwargs:
warnings.warn("`**kwargs` parameter is deprecating, use `invocation_state` instead.", stacklevel=2)
merged_state.update(kwargs)
if invocation_state is not None:
merged_state["invocation_state"] = invocation_state
else:
if invocation_state is not None:
merged_state = invocation_state
self.event_loop_metrics.reset_usage_metrics()

callback_handler = self.callback_handler
if kwargs:
callback_handler = kwargs.get("callback_handler", self.callback_handler)
merged_state = {}
if kwargs:
warnings.warn("`**kwargs` parameter is deprecating, use `invocation_state` instead.", stacklevel=2)
merged_state.update(kwargs)
if invocation_state is not None:
merged_state["invocation_state"] = invocation_state
else:
if invocation_state is not None:
merged_state = invocation_state

# Process input and get message to add (if any)
messages = await self._convert_prompt_to_messages(prompt)
callback_handler = self.callback_handler
if kwargs:
callback_handler = kwargs.get("callback_handler", self.callback_handler)

self.trace_span = self._start_agent_trace_span(messages)
# Process input and get message to add (if any)
messages = await self._convert_prompt_to_messages(prompt)

with trace_api.use_span(self.trace_span):
try:
events = self._run_loop(messages, merged_state, structured_output_model)
self.trace_span = self._start_agent_trace_span(messages)

async for event in events:
event.prepare(invocation_state=merged_state)
with trace_api.use_span(self.trace_span):
try:
events = self._run_loop(messages, merged_state, structured_output_model)

async for event in events:
event.prepare(invocation_state=merged_state)

if event.is_callback_event:
as_dict = event.as_dict()
callback_handler(**as_dict)
yield as_dict
if event.is_callback_event:
as_dict = event.as_dict()
callback_handler(**as_dict)
yield as_dict

result = AgentResult(*event["stop"])
callback_handler(result=result)
yield AgentResultEvent(result=result).as_dict()
result = AgentResult(*event["stop"])
callback_handler(result=result)
yield AgentResultEvent(result=result).as_dict()

self._end_agent_trace_span(response=result)
self._end_agent_trace_span(response=result)

except Exception as e:
self._end_agent_trace_span(error=e)
raise
except Exception as e:
self._end_agent_trace_span(error=e)
raise

finally:
self._invocation_lock.release()

async def _run_loop(
self,
Expand Down
79 changes: 49 additions & 30 deletions src/strands/tools/_caller.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from ..tools.executors._executor import ToolExecutor
from ..types._events import ToolInterruptEvent
from ..types.content import ContentBlock, Message
from ..types.exceptions import ConcurrencyException
from ..types.tools import ToolResult, ToolUse

if TYPE_CHECKING:
Expand Down Expand Up @@ -73,46 +74,64 @@ def caller(
if self._agent._interrupt_state.activated:
raise RuntimeError("cannot directly call tool during interrupt")

normalized_name = self._find_normalized_tool_name(name)
if record_direct_tool_call is not None:
should_record_direct_tool_call = record_direct_tool_call
else:
should_record_direct_tool_call = self._agent.record_direct_tool_call

# Create unique tool ID and set up the tool request
tool_id = f"tooluse_{name}_{random.randint(100000000, 999999999)}"
tool_use: ToolUse = {
"toolUseId": tool_id,
"name": normalized_name,
"input": kwargs.copy(),
}
tool_results: list[ToolResult] = []
invocation_state = kwargs
should_lock = should_record_direct_tool_call

async def acall() -> ToolResult:
async for event in ToolExecutor._stream(self._agent, tool_use, tool_results, invocation_state):
if isinstance(event, ToolInterruptEvent):
self._agent._interrupt_state.deactivate()
raise RuntimeError("cannot raise interrupt in direct tool call")
from ..agent import Agent # Locally imported to avoid circular reference

tool_result = tool_results[0]
acquired_lock = (
should_lock
and isinstance(self._agent, Agent)
and self._agent._invocation_lock.acquire_lock(blocking=False)
)
if should_lock and not acquired_lock:
raise ConcurrencyException(
"Direct tool call cannot be made while the agent is in the middle of an invocation. "
"Set record_direct_tool_call=False to allow direct tool calls during agent invocation."
)

if record_direct_tool_call is not None:
should_record_direct_tool_call = record_direct_tool_call
else:
should_record_direct_tool_call = self._agent.record_direct_tool_call
try:
normalized_name = self._find_normalized_tool_name(name)

if should_record_direct_tool_call:
# Create a record of this tool execution in the message history
await self._record_tool_execution(tool_use, tool_result, user_message_override)
# Create unique tool ID and set up the tool request
tool_id = f"tooluse_{name}_{random.randint(100000000, 999999999)}"
tool_use: ToolUse = {
"toolUseId": tool_id,
"name": normalized_name,
"input": kwargs.copy(),
}
tool_results: list[ToolResult] = []
invocation_state = kwargs

return tool_result
async def acall() -> ToolResult:
async for event in ToolExecutor._stream(self._agent, tool_use, tool_results, invocation_state):
if isinstance(event, ToolInterruptEvent):
self._agent._interrupt_state.deactivate()
raise RuntimeError("cannot raise interrupt in direct tool call")

tool_result = tool_results[0]

if should_record_direct_tool_call:
# Create a record of this tool execution in the message history
await self._record_tool_execution(tool_use, tool_result, user_message_override)

tool_result = run_async(acall)
return tool_result

# TODO: https://github.com/strands-agents/sdk-python/issues/1311
from ..agent import Agent
tool_result = run_async(acall)

if isinstance(self._agent, Agent):
self._agent.conversation_manager.apply_management(self._agent)
# TODO: https://github.com/strands-agents/sdk-python/issues/1311
if isinstance(self._agent, Agent):
self._agent.conversation_manager.apply_management(self._agent)

return tool_result

return tool_result
finally:
if acquired_lock and isinstance(self._agent, Agent):
self._agent._invocation_lock.release()

return caller

Expand Down
11 changes: 11 additions & 0 deletions src/strands/types/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,3 +94,14 @@ def __init__(self, message: str):
"""
self.message = message
super().__init__(message)


class ConcurrencyException(Exception):
    """Signal that an agent instance was invoked while it was already busy.

    An agent instance keeps internal state that is not safe to access from
    several invocations at once. This exception is raised when a new
    invocation is attempted while another one is still in progress on the
    same agent instance.
    """
Loading
Loading