Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 82 additions & 10 deletions src/strands/multiagent/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,7 @@ class GraphNode:
execution_status: Status = Status.PENDING
result: NodeResult | None = None
execution_time: int = 0
graph: "Graph | None" = field(default=None, repr=False)
_initial_messages: Messages = field(default_factory=list, init=False)
_initial_state: AgentState = field(default_factory=AgentState, init=False)

Expand All @@ -181,9 +182,22 @@ def __post_init__(self) -> None:
def reset_executor_state(self) -> None:
"""Reset GraphNode executor state to initial state when graph was created.

This is useful when nodes are executed multiple times and need to start
fresh on each execution, providing stateless behavior.
If Graph is resuming from an agent interrupt, we reset the executor state from the interrupt context.
Otherwise, nodes reset to their initial state for stateless behavior.
"""
# Check if resuming from agent interrupt
if (
self.graph
and self.graph._interrupt_state.activated
and self.node_id in self.graph._interrupt_state.context
and self.graph._interrupt_state.context[self.node_id].get("activated")
):
context = self.graph._interrupt_state.context[self.node_id]
self.executor.messages = context["messages"]
self.executor.state = AgentState(context["state"])
self.executor._interrupt_state = _InterruptState.from_dict(context["interrupt_state"])
return

if hasattr(self.executor, "messages"):
self.executor.messages = copy.deepcopy(self._initial_messages)

Expand Down Expand Up @@ -464,6 +478,10 @@ def __init__(
self._resume_from_session = False
self.id = id

# Set graph reference on all nodes for interrupt state restoration
for node in self.nodes.values():
node.graph = self

run_async(lambda: self.hooks.invoke_callbacks_async(MultiAgentInitializedEvent(self)))

def __call__(
Expand Down Expand Up @@ -603,12 +621,18 @@ def _validate_graph(self, nodes: dict[str, GraphNode]) -> None:
# Validate Agent-specific constraints for each node
_validate_node_executor(node.executor)

def _activate_interrupt(self, node: GraphNode, interrupts: list[Interrupt]) -> MultiAgentNodeInterruptEvent:
def _activate_interrupt(
self, node: GraphNode, interrupts: list[Interrupt], from_agent: bool = False
) -> MultiAgentNodeInterruptEvent:
"""Activate the interrupt state.

Note, a Graph may be interrupted either from a BeforeNodeCallEvent hook or from within an agent node. In either
case, we must manage the interrupt state of both the Graph and the individual agent nodes.

Args:
node: The interrupted node.
interrupts: The interrupts raised by the user.
from_agent: Whether the interrupt was raised from within an agent node.

Returns:
MultiAgentNodeInterruptEvent
Expand All @@ -620,6 +644,15 @@ def _activate_interrupt(self, node: GraphNode, interrupts: list[Interrupt]) -> M
self.state.status = Status.INTERRUPTED
self.state.interrupted_nodes.add(node)

# Store agent state if interrupt is from an agent node
if from_agent and isinstance(node.executor, Agent):
self._interrupt_state.context[node.node_id] = {
"activated": node.executor._interrupt_state.activated,
"interrupt_state": node.executor._interrupt_state.to_dict(),
"state": node.executor.state.get(),
"messages": node.executor.messages,
}

self._interrupt_state.interrupts.update({interrupt.id: interrupt for interrupt in interrupts})
self._interrupt_state.activate()

Expand Down Expand Up @@ -908,6 +941,19 @@ async def _execute_node(self, node: GraphNode, invocation_state: dict[str, Any])
elif isinstance(node.executor, Agent):
# For agents, stream their events and collect result
agent_response = None

# Determine node input - use interrupt responses if resuming from interrupt
if (
self._interrupt_state.activated
and node.node_id in self._interrupt_state.context
and self._interrupt_state.context[node.node_id].get("activated")
):
# Reset executor state from interrupt context before streaming
node.reset_executor_state()
node_input = self._interrupt_state.context["responses"]
else:
node_input = self._build_node_input(node)

async for event in node.executor.stream_async(node_input, invocation_state=invocation_state):
# Forward agent events with node context
wrapped_event = MultiAgentNodeStreamEvent(node.node_id, event)
Expand All @@ -920,15 +966,41 @@ async def _execute_node(self, node: GraphNode, invocation_state: dict[str, Any])
if agent_response is None:
raise ValueError(f"Node '{node.node_id}' did not produce a result event")

execution_time = round((time.time() - start_time) * 1000)

if agent_response.stop_reason == "interrupt":
node.executor.messages.pop() # remove interrupted tool use message
node.executor._interrupt_state.deactivate()
# Extract metrics with defaults
response_metrics = getattr(agent_response, "metrics", None)
usage = getattr(
response_metrics, "accumulated_usage", Usage(inputTokens=0, outputTokens=0, totalTokens=0)
)
metrics = getattr(response_metrics, "accumulated_metrics", Metrics(latencyMs=execution_time))

node_result = NodeResult(
result=agent_response,
execution_time=execution_time,
status=Status.INTERRUPTED,
accumulated_usage=usage,
accumulated_metrics=metrics,
execution_count=1,
interrupts=agent_response.interrupts or [],
)

raise NotImplementedError(
f"node_id=<{node.node_id}>, "
"issue=<https://github.com/strands-agents/sdk-python/issues/204> "
"| user raised interrupt from an agent node"
# Store result in state
node.result = node_result
node.execution_time = execution_time
self.state.results[node.node_id] = node_result

# Emit node stop event
complete_event = MultiAgentNodeStopEvent(
node_id=node.node_id,
node_result=node_result,
)
yield complete_event

# Activate interrupt state with agent context
yield self._activate_interrupt(node, agent_response.interrupts or [], from_agent=True)
return

# Extract metrics with defaults
response_metrics = getattr(agent_response, "metrics", None)
Expand All @@ -939,7 +1011,7 @@ async def _execute_node(self, node: GraphNode, invocation_state: dict[str, Any])

node_result = NodeResult(
result=agent_response,
execution_time=round((time.time() - start_time) * 1000),
execution_time=execution_time,
status=Status.COMPLETED,
accumulated_usage=usage,
accumulated_metrics=metrics,
Expand Down
Loading
Loading