From 3f86a4e16a75713613f07d84d5ccab95e8c1f412 Mon Sep 17 00:00:00 2001 From: Mackenzie Zastrow Date: Thu, 15 Jan 2026 14:28:58 -0500 Subject: [PATCH] fix: Swap sleeps with explicit signaling So that unit tests are determistic --- tests/strands/agent/test_agent.py | 59 +++++++++++++++++++++++-------- 1 file changed, 44 insertions(+), 15 deletions(-) diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index 81ce65989..eb039185c 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -1,14 +1,13 @@ -import asyncio import copy import importlib import json import os import textwrap import threading -import time import unittest.mock import warnings -from typing import Any, AsyncGenerator +from collections.abc import AsyncGenerator +from typing import Any from uuid import uuid4 import pytest @@ -193,11 +192,25 @@ class User(BaseModel): return User(name="Jane Doe", age=30, email="jane@doe.com") -class SlowMockedModel(MockedModelProvider): +class SyncEventMockedModel(MockedModelProvider): + """A mock model that uses events to synchronize concurrent threads. + + This model signals when it starts streaming and waits for a proceed signal, + allowing deterministic testing of concurrent behavior without relying on sleeps. + """ + + def __init__(self, agent_responses): + super().__init__(agent_responses) + self.started_event = threading.Event() + self.proceed_event = threading.Event() + async def stream( self, messages, tool_specs=None, system_prompt=None, tool_choice=None, **kwargs ) -> AsyncGenerator[Any, None]: - await asyncio.sleep(0.15) # Add async delay to ensure concurrency + # Signal that streaming has started + self.started_event.set() + # Wait for signal to proceed + self.proceed_event.wait() async for event in super().stream(messages, tool_specs, system_prompt, tool_choice, **kwargs): yield event @@ -2212,7 +2225,7 @@ def test_agent_skips_fix_for_valid_conversation(mock_model, agenerator): def test_agent_concurrent_call_raises_exception(): """Test that concurrent __call__() calls raise ConcurrencyException.""" - model = SlowMockedModel( + model = SyncEventMockedModel( [ {"role": "assistant", "content": [{"text": "hello"}]}, {"role": "assistant", "content": [{"text": "world"}]}, @@ -2233,12 +2246,20 @@ def invoke(): with lock: errors.append(e) - # Create two threads that will try to invoke concurrently + # Start first thread and wait for it to begin streaming t1 = threading.Thread(target=invoke) - t2 = threading.Thread(target=invoke) - t1.start() + model.started_event.wait() # Wait until first thread is in the model.stream() + + # Start second thread while first is still running + t2 = threading.Thread(target=invoke) t2.start() + + # Give second thread time to attempt invocation and fail + t2.join(timeout=1.0) + + # Now let first thread complete + model.proceed_event.set() t1.join() t2.join() @@ -2254,11 +2275,12 @@ def test_agent_concurrent_structured_output_raises_exception(): Note: This test validates that the sync invocation path is protected. The concurrent __call__() test already validates the core functionality. """ - model = SlowMockedModel( + # Events for synchronization + model = SyncEventMockedModel( [ {"role": "assistant", "content": [{"text": "response1"}]}, {"role": "assistant", "content": [{"text": "response2"}]}, - ] + ], ) agent = Agent(model=model) @@ -2275,13 +2297,20 @@ def invoke(): with lock: errors.append(e) - # Create two threads that will try to invoke concurrently + # Start first thread and wait for it to begin streaming t1 = threading.Thread(target=invoke) - t2 = threading.Thread(target=invoke) - t1.start() - time.sleep(0.05) # Small delay to ensure first thread acquires lock + model.started_event.wait() # Wait until first thread is in the model.stream() + + # Start second thread while first is still running + t2 = threading.Thread(target=invoke) t2.start() + + # Give second thread time to attempt invocation and fail + t2.join(timeout=1.0) + + # Now let first thread complete + model.proceed_event.set() t1.join() t2.join()