Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/strands/multiagent/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -1005,7 +1005,8 @@ async def _execute_node(self, node: GraphNode, invocation_state: dict[str, Any])
raise

finally:
await self.hooks.invoke_callbacks_async(AfterNodeCallEvent(self, node.node_id, invocation_state))
if node.execution_status != Status.INTERRUPTED:
await self.hooks.invoke_callbacks_async(AfterNodeCallEvent(self, node.node_id, invocation_state))

def _accumulate_metrics(self, node_result: NodeResult) -> None:
"""Accumulate metrics from a node result."""
Expand Down
7 changes: 4 additions & 3 deletions src/strands/multiagent/swarm.py
Original file line number Diff line number Diff line change
Expand Up @@ -782,9 +782,10 @@ async def _execute_swarm(self, invocation_state: dict[str, Any]) -> AsyncIterato
break

finally:
await self.hooks.invoke_callbacks_async(
AfterNodeCallEvent(self, current_node.node_id, invocation_state)
)
if self.state.completion_status != Status.INTERRUPTED:
await self.hooks.invoke_callbacks_async(
AfterNodeCallEvent(self, current_node.node_id, invocation_state)
)

logger.debug("node=<%s> | node execution completed", current_node.node_id)

Expand Down
9 changes: 8 additions & 1 deletion tests/strands/multiagent/conftest.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,22 @@
import pytest

from strands.hooks import BeforeNodeCallEvent, HookProvider
from strands.hooks import AfterNodeCallEvent, BeforeNodeCallEvent, HookProvider


@pytest.fixture
def interrupt_hook():
class Hook(HookProvider):
def __init__(self):
self.after_count = 0

def register_hooks(self, registry):
registry.add_callback(BeforeNodeCallEvent, self.interrupt)
registry.add_callback(AfterNodeCallEvent, self.cleanup)

def interrupt(self, event):
return event.interrupt("test_name", reason="test_reason")

def cleanup(self, event):
self.after_count += 1

return Hook()
8 changes: 8 additions & 0 deletions tests/strands/multiagent/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -2126,6 +2126,10 @@ def test_graph_interrupt_on_before_node_call_event(interrupt_hook):
]
assert tru_interrupts == exp_interrupts

tru_after_count = interrupt_hook.after_count
exp_after_count = 0
assert tru_after_count == exp_after_count

interrupt = multiagent_result.interrupts[0]
responses = [
{
Expand All @@ -2152,4 +2156,8 @@ def test_graph_interrupt_on_before_node_call_event(interrupt_hook):
exp_message = "Task completed"
assert tru_message == exp_message

tru_after_count = interrupt_hook.after_count
exp_after_count = 1
assert tru_after_count == exp_after_count

assert multiagent_result.execution_time >= first_execution_time
8 changes: 8 additions & 0 deletions tests/strands/multiagent/test_swarm.py
Original file line number Diff line number Diff line change
Expand Up @@ -1259,6 +1259,10 @@ def test_swarm_interrupt_on_before_node_call_event(interrupt_hook):
]
assert tru_interrupts == exp_interrupts

tru_after_count = interrupt_hook.after_count
exp_after_count = 0
assert tru_after_count == exp_after_count

interrupt = multiagent_result.interrupts[0]
responses = [
{
Expand All @@ -1281,6 +1285,10 @@ def test_swarm_interrupt_on_before_node_call_event(interrupt_hook):
exp_message = "Task completed"
assert tru_message == exp_message

tru_after_count = interrupt_hook.after_count
exp_after_count = 1
assert tru_after_count == exp_after_count

assert multiagent_result.execution_time >= first_execution_time


Expand Down
Loading