From c7e950b71fa513b294914a64dd9a6950c47e7882 Mon Sep 17 00:00:00 2001 From: Patrick Gray Date: Wed, 21 Jan 2026 16:50:58 -0500 Subject: [PATCH 1/2] interrupts - multiagent - do not emit AfterNodeCallEvent on interrupt --- src/strands/multiagent/graph.py | 3 ++- src/strands/multiagent/swarm.py | 7 ++++--- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/src/strands/multiagent/graph.py b/src/strands/multiagent/graph.py index 32eca00ff..bad7eede9 100644 --- a/src/strands/multiagent/graph.py +++ b/src/strands/multiagent/graph.py @@ -1005,7 +1005,8 @@ async def _execute_node(self, node: GraphNode, invocation_state: dict[str, Any]) raise finally: - await self.hooks.invoke_callbacks_async(AfterNodeCallEvent(self, node.node_id, invocation_state)) + if node.execution_status != Status.INTERRUPTED: + await self.hooks.invoke_callbacks_async(AfterNodeCallEvent(self, node.node_id, invocation_state)) def _accumulate_metrics(self, node_result: NodeResult) -> None: """Accumulate metrics from a node result.""" diff --git a/src/strands/multiagent/swarm.py b/src/strands/multiagent/swarm.py index 9a4ce5494..10e0da515 100644 --- a/src/strands/multiagent/swarm.py +++ b/src/strands/multiagent/swarm.py @@ -782,9 +782,10 @@ async def _execute_swarm(self, invocation_state: dict[str, Any]) -> AsyncIterato break finally: - await self.hooks.invoke_callbacks_async( - AfterNodeCallEvent(self, current_node.node_id, invocation_state) - ) + if self.state.completion_status != Status.INTERRUPTED: + await self.hooks.invoke_callbacks_async( + AfterNodeCallEvent(self, current_node.node_id, invocation_state) + ) logger.debug("node=<%s> | node execution completed", current_node.node_id) From 30b0b8dc6e2623911ee83cc8e5c9e091f38b6c9a Mon Sep 17 00:00:00 2001 From: Patrick Gray Date: Wed, 21 Jan 2026 18:11:24 -0500 Subject: [PATCH 2/2] add unit tests --- tests/strands/multiagent/conftest.py | 9 ++++++++- tests/strands/multiagent/test_graph.py | 8 ++++++++ tests/strands/multiagent/test_swarm.py | 8 ++++++++ 3 files changed, 24 insertions(+), 1 deletion(-) diff --git a/tests/strands/multiagent/conftest.py b/tests/strands/multiagent/conftest.py index e5dd1b4f9..190dc4a91 100644 --- a/tests/strands/multiagent/conftest.py +++ b/tests/strands/multiagent/conftest.py @@ -1,15 +1,22 @@ import pytest -from strands.hooks import BeforeNodeCallEvent, HookProvider +from strands.hooks import AfterNodeCallEvent, BeforeNodeCallEvent, HookProvider @pytest.fixture def interrupt_hook(): class Hook(HookProvider): + def __init__(self): + self.after_count = 0 + def register_hooks(self, registry): registry.add_callback(BeforeNodeCallEvent, self.interrupt) + registry.add_callback(AfterNodeCallEvent, self.cleanup) def interrupt(self, event): return event.interrupt("test_name", reason="test_reason") + def cleanup(self, event): + self.after_count += 1 + return Hook() diff --git a/tests/strands/multiagent/test_graph.py b/tests/strands/multiagent/test_graph.py index cd750865e..75482939d 100644 --- a/tests/strands/multiagent/test_graph.py +++ b/tests/strands/multiagent/test_graph.py @@ -2126,6 +2126,10 @@ def test_graph_interrupt_on_before_node_call_event(interrupt_hook): ] assert tru_interrupts == exp_interrupts + tru_after_count = interrupt_hook.after_count + exp_after_count = 0 + assert tru_after_count == exp_after_count + interrupt = multiagent_result.interrupts[0] responses = [ { @@ -2152,4 +2156,8 @@ def test_graph_interrupt_on_before_node_call_event(interrupt_hook): exp_message = "Task completed" assert tru_message == exp_message + tru_after_count = interrupt_hook.after_count + exp_after_count = 1 + assert tru_after_count == exp_after_count + assert multiagent_result.execution_time >= first_execution_time diff --git a/tests/strands/multiagent/test_swarm.py b/tests/strands/multiagent/test_swarm.py index 75ef97a25..491adc7c3 100644 --- a/tests/strands/multiagent/test_swarm.py +++ b/tests/strands/multiagent/test_swarm.py @@ -1259,6 +1259,10 @@ def test_swarm_interrupt_on_before_node_call_event(interrupt_hook): ] assert tru_interrupts == exp_interrupts + tru_after_count = interrupt_hook.after_count + exp_after_count = 0 + assert tru_after_count == exp_after_count + interrupt = multiagent_result.interrupts[0] responses = [ { @@ -1281,6 +1285,10 @@ def test_swarm_interrupt_on_before_node_call_event(interrupt_hook): exp_message = "Task completed" assert tru_message == exp_message + tru_after_count = interrupt_hook.after_count + exp_after_count = 1 + assert tru_after_count == exp_after_count + assert multiagent_result.execution_time >= first_execution_time