diff --git a/src/strands/multiagent/graph.py b/src/strands/multiagent/graph.py index 97435ad4a..990ffd8a8 100644 --- a/src/strands/multiagent/graph.py +++ b/src/strands/multiagent/graph.py @@ -166,6 +166,7 @@ class GraphNode: execution_status: Status = Status.PENDING result: NodeResult | None = None execution_time: int = 0 + graph: "Graph | None" = field(default=None, repr=False) _initial_messages: Messages = field(default_factory=list, init=False) _initial_state: AgentState = field(default_factory=AgentState, init=False) @@ -181,9 +182,22 @@ def __post_init__(self) -> None: def reset_executor_state(self) -> None: """Reset GraphNode executor state to initial state when graph was created. - This is useful when nodes are executed multiple times and need to start - fresh on each execution, providing stateless behavior. + If Graph is resuming from an agent interrupt, we reset the executor state from the interrupt context. + Otherwise, nodes reset to their initial state for stateless behavior. """ + # Check if resuming from agent interrupt + if ( + self.graph + and self.graph._interrupt_state.activated + and self.node_id in self.graph._interrupt_state.context + and self.graph._interrupt_state.context[self.node_id].get("activated") + ): + context = self.graph._interrupt_state.context[self.node_id] + self.executor.messages = context["messages"] + self.executor.state = AgentState(context["state"]) + self.executor._interrupt_state = _InterruptState.from_dict(context["interrupt_state"]) + return + if hasattr(self.executor, "messages"): self.executor.messages = copy.deepcopy(self._initial_messages) @@ -464,6 +478,10 @@ def __init__( self._resume_from_session = False self.id = id + # Set graph reference on all nodes for interrupt state restoration + for node in self.nodes.values(): + node.graph = self + run_async(lambda: self.hooks.invoke_callbacks_async(MultiAgentInitializedEvent(self))) def __call__( @@ -603,12 +621,18 @@ def _validate_graph(self, nodes: dict[str, GraphNode]) -> None: # Validate Agent-specific constraints for each node _validate_node_executor(node.executor) - def _activate_interrupt(self, node: GraphNode, interrupts: list[Interrupt]) -> MultiAgentNodeInterruptEvent: + def _activate_interrupt( + self, node: GraphNode, interrupts: list[Interrupt], from_agent: bool = False + ) -> MultiAgentNodeInterruptEvent: """Activate the interrupt state. + Note, a Graph may be interrupted either from a BeforeNodeCallEvent hook or from within an agent node. In either + case, we must manage the interrupt state of both the Graph and the individual agent nodes. + Args: node: The interrupted node. interrupts: The interrupts raised by the user. + from_agent: Whether the interrupt was raised from within an agent node. Returns: MultiAgentNodeInterruptEvent @@ -620,6 +644,15 @@ def _activate_interrupt(self, node: GraphNode, interrupts: list[Interrupt]) -> M self.state.status = Status.INTERRUPTED self.state.interrupted_nodes.add(node) + # Store agent state if interrupt is from an agent node + if from_agent and isinstance(node.executor, Agent): + self._interrupt_state.context[node.node_id] = { + "activated": node.executor._interrupt_state.activated, + "interrupt_state": node.executor._interrupt_state.to_dict(), + "state": node.executor.state.get(), + "messages": node.executor.messages, + } + self._interrupt_state.interrupts.update({interrupt.id: interrupt for interrupt in interrupts}) self._interrupt_state.activate() @@ -908,6 +941,19 @@ async def _execute_node(self, node: GraphNode, invocation_state: dict[str, Any]) elif isinstance(node.executor, Agent): # For agents, stream their events and collect result agent_response = None + + # Determine node input - use interrupt responses if resuming from interrupt + if ( + self._interrupt_state.activated + and node.node_id in self._interrupt_state.context + and self._interrupt_state.context[node.node_id].get("activated") + ): + # Reset executor state from interrupt context before streaming + node.reset_executor_state() + node_input = self._interrupt_state.context["responses"] + else: + node_input = self._build_node_input(node) + async for event in node.executor.stream_async(node_input, invocation_state=invocation_state): # Forward agent events with node context wrapped_event = MultiAgentNodeStreamEvent(node.node_id, event) @@ -920,15 +966,41 @@ async def _execute_node(self, node: GraphNode, invocation_state: dict[str, Any]) if agent_response is None: raise ValueError(f"Node '{node.node_id}' did not produce a result event") + execution_time = round((time.time() - start_time) * 1000) + if agent_response.stop_reason == "interrupt": - node.executor.messages.pop() # remove interrupted tool use message - node.executor._interrupt_state.deactivate() + # Extract metrics with defaults + response_metrics = getattr(agent_response, "metrics", None) + usage = getattr( + response_metrics, "accumulated_usage", Usage(inputTokens=0, outputTokens=0, totalTokens=0) + ) + metrics = getattr(response_metrics, "accumulated_metrics", Metrics(latencyMs=execution_time)) + + node_result = NodeResult( + result=agent_response, + execution_time=execution_time, + status=Status.INTERRUPTED, + accumulated_usage=usage, + accumulated_metrics=metrics, + execution_count=1, + interrupts=agent_response.interrupts or [], + ) - raise NotImplementedError( - f"node_id=<{node.node_id}>, " - "issue= " - "| user raised interrupt from an agent node" + # Store result in state + node.result = node_result + node.execution_time = execution_time + self.state.results[node.node_id] = node_result + + # Emit node stop event + complete_event = MultiAgentNodeStopEvent( + node_id=node.node_id, + node_result=node_result, ) + yield complete_event + + # Activate interrupt state with agent context + yield self._activate_interrupt(node, agent_response.interrupts or [], from_agent=True) + return # Extract metrics with defaults response_metrics = getattr(agent_response, "metrics", None) @@ -939,7 +1011,7 @@ async def _execute_node(self, node: GraphNode, invocation_state: dict[str, Any]) node_result = NodeResult( result=agent_response, - execution_time=round((time.time() - start_time) * 1000), + execution_time=execution_time, status=Status.COMPLETED, accumulated_usage=usage, accumulated_metrics=metrics, diff --git a/tests/strands/multiagent/test_graph.py b/tests/strands/multiagent/test_graph.py index ab2d86e70..9552fe1dc 100644 --- a/tests/strands/multiagent/test_graph.py +++ b/tests/strands/multiagent/test_graph.py @@ -2154,3 +2154,245 @@ def test_graph_interrupt_on_before_node_call_event(interrupt_hook): assert tru_message == exp_message assert multiagent_result.execution_time >= first_execution_time + + +# ============================================================================ +# Agent Interrupt Tests (Issue #1526) +# ============================================================================ + + +@pytest.fixture +def agenerator(): + """Async generator fixture for mocking stream_async.""" + + async def _agenerator(items): + for item in items: + yield item + + return _agenerator + + +def test_graph_interrupt_on_agent(agenerator): + """Test that an agent node can raise an interrupt.""" + exp_interrupts = [ + Interrupt( + id="test_id", + name="test_name", + reason="test_reason", + ), + ] + + agent = create_mock_agent("test_agent", "Task completed") + # Add required state attributes for interrupt handling + agent.messages = [] + agent.state = AgentState() + agent._interrupt_state = _InterruptState() + + builder = GraphBuilder() + builder.add_node(agent, "test_agent") + graph = builder.build() + + # First invocation - agent returns interrupt + agent.stream_async = Mock() + agent.stream_async.return_value = agenerator( + [ + { + "result": AgentResult( + message={}, + stop_reason="interrupt", + state={}, + metrics=None, + interrupts=exp_interrupts, + ), + }, + ], + ) + + multiagent_result = graph("Test task") + + tru_status = multiagent_result.status + exp_status = Status.INTERRUPTED + assert tru_status == exp_status + + tru_interrupts = multiagent_result.interrupts + assert tru_interrupts == exp_interrupts + + # Verify interrupted node is tracked + tru_node_ids = [node.node_id for node in graph.state.interrupted_nodes] + exp_node_ids = ["test_agent"] + assert tru_node_ids == exp_node_ids + + # Resume with response + agent.stream_async = Mock() + agent.stream_async.return_value = agenerator( + [ + { + "result": AgentResult( + message={"role": "assistant", "content": [{"text": "Task completed"}]}, + stop_reason="end_turn", + state={}, + metrics=None, + ), + }, + ], + ) + graph._interrupt_state.context["test_agent"]["activated"] = True + + interrupt = multiagent_result.interrupts[0] + responses = [ + { + "interruptResponse": { + "interruptId": interrupt.id, + "response": "test_response", + }, + }, + ] + multiagent_result = graph(responses) + + tru_status = multiagent_result.status + exp_status = Status.COMPLETED + assert tru_status == exp_status + + # Verify response was passed to agent + agent.stream_async.assert_called_once_with(responses, invocation_state={}) + + +def test_graph_interrupt_on_agent_parallel_execution(agenerator): + """Test that when multiple nodes run in parallel, non-interrupted nodes complete.""" + interrupt = Interrupt( + id="test_id", + name="test_name", + reason="test_reason", + ) + + # Create two agents - one will interrupt, one will complete + agent1 = create_mock_agent("agent1", "Agent 1 completed") + agent1.messages = [] + agent1.state = AgentState() + agent1._interrupt_state = _InterruptState() + + agent2 = create_mock_agent("agent2", "Agent 2 completed") + agent2.messages = [] + agent2.state = AgentState() + agent2._interrupt_state = _InterruptState() + + builder = GraphBuilder() + builder.add_node(agent1, "agent1") + builder.add_node(agent2, "agent2") + # Both are entry points, so they execute in parallel + graph = builder.build() + + # Agent1 will interrupt, Agent2 will complete + agent1.stream_async = Mock() + agent1.stream_async.return_value = agenerator( + [ + { + "result": AgentResult( + message={}, + stop_reason="interrupt", + state={}, + metrics=None, + interrupts=[interrupt], + ), + }, + ], + ) + + agent2.stream_async = Mock() + agent2.stream_async.return_value = agenerator( + [ + { + "result": AgentResult( + message={"role": "assistant", "content": [{"text": "Agent 2 completed"}]}, + stop_reason="end_turn", + state={}, + metrics=None, + ), + }, + ], + ) + + multiagent_result = graph("Test task") + + # Graph should be interrupted + assert multiagent_result.status == Status.INTERRUPTED + assert len(multiagent_result.interrupts) == 1 + + # Agent2 should have completed and be tracked in completed_nodes context + assert "completed_nodes" in graph._interrupt_state.context + assert "agent2" in graph._interrupt_state.context["completed_nodes"] + + +def test_graph_interrupt_on_agent_multiple_interrupts(agenerator): + """Test that multiple agent nodes can raise interrupts simultaneously.""" + interrupt1 = Interrupt(id="int1", name="interrupt1", reason="reason1") + interrupt2 = Interrupt(id="int2", name="interrupt2", reason="reason2") + + # Create two agents, both will interrupt + agent1 = create_mock_agent("agent1", "Agent 1 result") + agent1.messages = [] + agent1.state = AgentState() + agent1._interrupt_state = _InterruptState() + + agent2 = create_mock_agent("agent2", "Agent 2 result") + agent2.messages = [] + agent2.state = AgentState() + agent2._interrupt_state = _InterruptState() + + builder = GraphBuilder() + builder.add_node(agent1, "agent1") + builder.add_node(agent2, "agent2") + graph = builder.build() + + # Both agents interrupt + agent1.stream_async = Mock() + agent1.stream_async.return_value = agenerator( + [{"result": AgentResult(message={}, stop_reason="interrupt", state={}, metrics=None, interrupts=[interrupt1])}] + ) + + agent2.stream_async = Mock() + agent2.stream_async.return_value = agenerator( + [{"result": AgentResult(message={}, stop_reason="interrupt", state={}, metrics=None, interrupts=[interrupt2])}] + ) + + multiagent_result = graph("Test task") + + assert multiagent_result.status == Status.INTERRUPTED + # Both interrupts should be collected + assert len(multiagent_result.interrupts) == 2 + interrupt_ids = {i.id for i in multiagent_result.interrupts} + assert "int1" in interrupt_ids + assert "int2" in interrupt_ids + + +def test_graph_interrupt_on_agent_state_serialization(agenerator): + """Test that interrupt state is properly serialized for session management.""" + interrupt = Interrupt(id="test_id", name="test_name", reason="test_reason") + + agent = create_mock_agent("test_agent", "Task completed") + agent.messages = [{"role": "user", "content": [{"text": "Hello"}]}] + agent.state = AgentState({"key": "value"}) + agent._interrupt_state = _InterruptState() + + builder = GraphBuilder() + builder.add_node(agent, "test_agent") + graph = builder.build() + + agent.stream_async = Mock() + agent.stream_async.return_value = agenerator( + [{"result": AgentResult(message={}, stop_reason="interrupt", state={}, metrics=None, interrupts=[interrupt])}] + ) + + multiagent_result = graph("Test task") + assert multiagent_result.status == Status.INTERRUPTED + + # Serialize state + serialized = graph.serialize_state() + + # Verify interrupt state is included + assert "_internal_state" in serialized + assert "interrupt_state" in serialized["_internal_state"] + assert serialized["_internal_state"]["interrupt_state"]["activated"] is True + + # Verify agent state is stored in context + assert serialized["_internal_state"]["interrupt_state"]["context"].get("test_agent") is not None diff --git a/tests_integ/interrupts/multiagent/test_agent.py b/tests_integ/interrupts/multiagent/test_agent.py index 36fcfef27..9ae8e784a 100644 --- a/tests_integ/interrupts/multiagent/test_agent.py +++ b/tests_integ/interrupts/multiagent/test_agent.py @@ -5,7 +5,7 @@ from strands import Agent, tool from strands.interrupt import Interrupt -from strands.multiagent import Swarm +from strands.multiagent import GraphBuilder, Swarm from strands.multiagent.base import Status from strands.types.tools import ToolContext @@ -65,3 +65,111 @@ def test_swarm_interrupt_agent(swarm): weather_message = json.dumps(weather_result.result.message).lower() assert "sunny" in weather_message + + +# ============================================================================ +# Graph Agent Interrupt Integration Tests (Issue #1526) +# ============================================================================ + + +@pytest.fixture +def graph(weather_tool): + """Create a graph with an agent that can raise interrupts.""" + weather_agent = Agent(name="weather", tools=[weather_tool]) + + builder = GraphBuilder() + builder.add_node(weather_agent, "weather_agent") + return builder.build() + + +def test_graph_interrupt_agent(graph): + """Test that an agent node in a Graph can raise an interrupt and resume.""" + multiagent_result = graph("What is the weather?") + + tru_status = multiagent_result.status + exp_status = Status.INTERRUPTED + assert tru_status == exp_status + + tru_interrupts = multiagent_result.interrupts + exp_interrupts = [ + Interrupt( + id=ANY, + name="test_interrupt", + reason="need weather", + ), + ] + assert tru_interrupts == exp_interrupts + + interrupt = multiagent_result.interrupts[0] + + responses = [ + { + "interruptResponse": { + "interruptId": interrupt.id, + "response": "sunny", + }, + }, + ] + multiagent_result = graph(responses) + + tru_status = multiagent_result.status + exp_status = Status.COMPLETED + assert tru_status == exp_status + + assert len(multiagent_result.results) == 1 + weather_result = multiagent_result.results["weather_agent"] + + weather_message = json.dumps(weather_result.result.message).lower() + assert "sunny" in weather_message + + +def test_graph_interrupt_agent_parallel(): + """Test Graph with parallel agent nodes where one raises an interrupt.""" + + @tool(name="interrupt_tool", context=True) + def interrupt_tool(tool_context: ToolContext) -> str: + response = tool_context.interrupt("approval", reason="need approval") + return f"Approved: {response}" + + @tool(name="non_interrupt_tool") + def non_interrupt_tool() -> str: + return "Non-interrupt task completed" + + # Create two agents: one that will interrupt, one that won't + interrupt_agent = Agent(name="interrupt_agent", tools=[interrupt_tool]) + non_interrupt_agent = Agent(name="non_interrupt_agent", tools=[non_interrupt_tool]) + + builder = GraphBuilder() + builder.add_node(interrupt_agent, "interrupt_agent") + builder.add_node(non_interrupt_agent, "non_interrupt_agent") + # Both are entry points, so they execute in parallel + graph = builder.build() + + # First invocation - both agents start, interrupt_agent raises interrupt + multiagent_result = graph("Execute tasks") + + assert multiagent_result.status == Status.INTERRUPTED + assert len(multiagent_result.interrupts) == 1 + assert multiagent_result.interrupts[0].name == "approval" + + # non_interrupt_agent should have completed + assert "completed_nodes" in graph._interrupt_state.context + assert "non_interrupt_agent" in graph._interrupt_state.context["completed_nodes"] + + # Resume with response + interrupt = multiagent_result.interrupts[0] + responses = [ + { + "interruptResponse": { + "interruptId": interrupt.id, + "response": "yes", + }, + }, + ] + multiagent_result = graph(responses) + + assert multiagent_result.status == Status.COMPLETED + # Both agents should have results now + assert len(multiagent_result.results) == 2 + assert "interrupt_agent" in multiagent_result.results + assert "non_interrupt_agent" in multiagent_result.results