Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 44 additions & 15 deletions tests/strands/agent/test_agent.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
import asyncio
import copy
import importlib
import json
import os
import textwrap
import threading
import time
import unittest.mock
import warnings
from typing import Any, AsyncGenerator
from collections.abc import AsyncGenerator
from typing import Any
from uuid import uuid4

import pytest
Expand Down Expand Up @@ -193,11 +192,25 @@ class User(BaseModel):
return User(name="Jane Doe", age=30, email="jane@doe.com")


class SlowMockedModel(MockedModelProvider):
class SyncEventMockedModel(MockedModelProvider):
"""A mock model that uses events to synchronize concurrent threads.

This model signals when it starts streaming and waits for a proceed signal,
allowing deterministic testing of concurrent behavior without relying on sleeps.
"""

def __init__(self, agent_responses):
super().__init__(agent_responses)
self.started_event = threading.Event()
self.proceed_event = threading.Event()

async def stream(
self, messages, tool_specs=None, system_prompt=None, tool_choice=None, **kwargs
) -> AsyncGenerator[Any, None]:
await asyncio.sleep(0.15) # Add async delay to ensure concurrency
# Signal that streaming has started
self.started_event.set()
# Wait for signal to proceed
self.proceed_event.wait()
async for event in super().stream(messages, tool_specs, system_prompt, tool_choice, **kwargs):
yield event

Expand Down Expand Up @@ -2212,7 +2225,7 @@ def test_agent_skips_fix_for_valid_conversation(mock_model, agenerator):

def test_agent_concurrent_call_raises_exception():
"""Test that concurrent __call__() calls raise ConcurrencyException."""
model = SlowMockedModel(
model = SyncEventMockedModel(
[
{"role": "assistant", "content": [{"text": "hello"}]},
{"role": "assistant", "content": [{"text": "world"}]},
Expand All @@ -2233,12 +2246,20 @@ def invoke():
with lock:
errors.append(e)

# Create two threads that will try to invoke concurrently
# Start first thread and wait for it to begin streaming
t1 = threading.Thread(target=invoke)
t2 = threading.Thread(target=invoke)

t1.start()
model.started_event.wait() # Wait until first thread is in the model.stream()

# Start second thread while first is still running
t2 = threading.Thread(target=invoke)
t2.start()

# Give second thread time to attempt invocation and fail
t2.join(timeout=1.0)

# Now let first thread complete
model.proceed_event.set()
t1.join()
t2.join()

Expand All @@ -2254,11 +2275,12 @@ def test_agent_concurrent_structured_output_raises_exception():
Note: This test validates that the sync invocation path is protected.
The concurrent __call__() test already validates the core functionality.
"""
model = SlowMockedModel(
# Events for synchronization
model = SyncEventMockedModel(
[
{"role": "assistant", "content": [{"text": "response1"}]},
{"role": "assistant", "content": [{"text": "response2"}]},
]
],
)
agent = Agent(model=model)

Expand All @@ -2275,13 +2297,20 @@ def invoke():
with lock:
errors.append(e)

# Create two threads that will try to invoke concurrently
# Start first thread and wait for it to begin streaming
t1 = threading.Thread(target=invoke)
t2 = threading.Thread(target=invoke)

t1.start()
time.sleep(0.05) # Small delay to ensure first thread acquires lock
model.started_event.wait() # Wait until first thread is in the model.stream()

# Start second thread while first is still running
t2 = threading.Thread(target=invoke)
t2.start()

# Give second thread time to attempt invocation and fail
t2.join(timeout=1.0)

# Now let first thread complete
model.proceed_event.set()
t1.join()
t2.join()

Expand Down
Loading