diff --git a/pyrit/prompt_target/openai/openai_target.py b/pyrit/prompt_target/openai/openai_target.py index fb2c3936c..677ede76d 100644 --- a/pyrit/prompt_target/openai/openai_target.py +++ b/pyrit/prompt_target/openai/openai_target.py @@ -1,6 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. +import asyncio import json import logging import re @@ -37,6 +38,45 @@ logger = logging.getLogger(__name__) +def _ensure_async_token_provider( + api_key: Optional[str | Callable[[], str | Awaitable[str]]], +) -> Optional[str | Callable[[], Awaitable[str]]]: + """ + Ensure the api_key is either a string or an async callable. + + If a synchronous callable token provider is provided, it's automatically wrapped + in an async function to make it compatible with AsyncOpenAI. + + Args: + api_key: Either a string API key or a callable that returns a token (sync or async). + + Returns: + Either a string API key or an async callable that returns a token. + """ + if api_key is None or isinstance(api_key, str) or not callable(api_key): + return api_key + + # Check if the callable is already async + if asyncio.iscoroutinefunction(api_key): + return api_key + + # Wrap synchronous token provider in async function + logger.info( + "Detected synchronous token provider. Automatically wrapping in async function for compatibility with AsyncOpenAI." + ) + + async def async_token_provider() -> str: + """ + Async wrapper for synchronous token provider. + + Returns: + str: The token string from the synchronous provider. + """ + return api_key() # type: ignore + + return async_token_provider + + class OpenAITarget(PromptChatTarget): """ Abstract base class for OpenAI-based prompt targets. @@ -75,9 +115,11 @@ def __init__( model_name (str, Optional): The name of the model (or name of deployment in Azure). If no value is provided, the environment variable will be used (set by subclass). endpoint (str, Optional): The target URL for the OpenAI service. - api_key (str | Callable[[], str], Optional): The API key for accessing the OpenAI service, - or a callable that returns an access token. For Azure endpoints with Entra authentication, - pass a token provider from pyrit.auth (e.g., get_azure_openai_auth(endpoint)). + api_key (str | Callable[[], str | Awaitable[str]], Optional): The API key for accessing the + OpenAI service, or a callable that returns an access token (sync or async). + For Azure endpoints with Entra authentication, pass a token provider from pyrit.auth + (e.g., get_azure_openai_auth(endpoint) for async, or get_azure_token_provider(scope) for sync). + Synchronous token providers are automatically wrapped to work with async clients. Defaults to the target-specific API key environment variable. headers (str, Optional): Extra headers of the endpoint (JSON). max_requests_per_minute (int, Optional): Number of requests the target can handle per @@ -129,6 +171,9 @@ def __init__( env_var_name=self.api_key_environment_variable, passed_value=api_key ) + # Ensure api_key is async-compatible (wrap sync token providers if needed) + self._api_key = _ensure_async_token_provider(self._api_key) + self._initialize_openai_client() def _extract_deployment_from_azure_url(self, url: str) -> str: diff --git a/tests/integration/targets/test_entra_auth_targets.py b/tests/integration/targets/test_entra_auth_targets.py index 838845ef2..4dc94be30 100644 --- a/tests/integration/targets/test_entra_auth_targets.py +++ b/tests/integration/targets/test_entra_auth_targets.py @@ -186,3 +186,30 @@ async def test_prompt_shield_target_entra_auth(sqlite_instance): result = await attack.execute_async(objective="test") assert result is not None assert result.last_response is not None + + +@pytest.mark.asyncio +async def test_openai_chat_target_with_sync_token_provider(sqlite_instance): + """Test that OpenAIChatTarget works with synchronous token providers (auto-wrapped).""" + from azure.identity import DefaultAzureCredential, get_bearer_token_provider + + endpoint = os.environ["AZURE_OPENAI_GPT4O_ENDPOINT"] + model_name = os.environ["AZURE_OPENAI_GPT4O_MODEL"] + + # Use synchronous token provider - this should be auto-wrapped + sync_token_provider = get_bearer_token_provider( + DefaultAzureCredential(), "https://cognitiveservices.azure.com/.default" + ) + + target = OpenAIChatTarget( + endpoint=endpoint, + model_name=model_name, + api_key=sync_token_provider, + temperature=0.0, + seed=42, + ) + + attack = PromptSendingAttack(objective_target=target) + result = await attack.execute_async(objective="Hello, how are you?") + assert result is not None + assert result.last_response is not None diff --git a/tests/unit/target/test_openai_chat_target.py b/tests/unit/target/test_openai_chat_target.py index 7e798c06c..58a10cfe3 100644 --- a/tests/unit/target/test_openai_chat_target.py +++ b/tests/unit/target/test_openai_chat_target.py @@ -715,9 +715,13 @@ def mock_token_provider(): api_key=mock_token_provider, ) - # Verify token provider was stored as api_key + # Verify token provider was stored as api_key (wrapped in async) assert callable(target._api_key) - assert target._api_key() == "mock-entra-token" + # Since sync provider is wrapped, _api_key is now async + import asyncio + + assert asyncio.iscoroutinefunction(target._api_key) + assert asyncio.run(target._api_key()) == "mock-entra-token" def test_set_auth_with_api_key(patch_central_database): diff --git a/tests/unit/target/test_token_provider_wrapping.py b/tests/unit/target/test_token_provider_wrapping.py new file mode 100644 index 000000000..8ca874ce0 --- /dev/null +++ b/tests/unit/target/test_token_provider_wrapping.py @@ -0,0 +1,248 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import asyncio +from unittest.mock import AsyncMock, patch + +import pytest + +from pyrit.prompt_target.openai.openai_target import _ensure_async_token_provider + + +class TestTokenProviderWrapping: + """Test suite for synchronous token provider auto-wrapping functionality.""" + + def test_string_api_key_unchanged(self): + """Test that string API keys are returned unchanged.""" + api_key = "sk-test-key-12345" + result = _ensure_async_token_provider(api_key) + assert result == api_key + assert isinstance(result, str) + + def test_none_api_key_unchanged(self): + """Test that None is returned unchanged.""" + result = _ensure_async_token_provider(None) + assert result is None + + def test_async_token_provider_unchanged(self): + """Test that async token providers are returned unchanged.""" + + async def async_token_provider(): + return "async-token" + + result = _ensure_async_token_provider(async_token_provider) + assert result is async_token_provider + assert asyncio.iscoroutinefunction(result) + + def test_sync_token_provider_wrapped(self): + """Test that synchronous token providers are automatically wrapped.""" + + def sync_token_provider(): + return "sync-token" + + result = _ensure_async_token_provider(sync_token_provider) + + # Should return a different callable (the wrapper) + assert result is not sync_token_provider + assert callable(result) + assert asyncio.iscoroutinefunction(result) + + @pytest.mark.asyncio + async def test_wrapped_sync_provider_returns_correct_token(self): + """Test that wrapped synchronous token provider returns the correct token.""" + + def sync_token_provider(): + return "my-sync-token" + + wrapped = _ensure_async_token_provider(sync_token_provider) + + # Call the wrapped provider + token = await wrapped() + assert token == "my-sync-token" + + @pytest.mark.asyncio + async def test_async_provider_returns_correct_token(self): + """Test that async token providers work correctly.""" + + async def async_token_provider(): + return "my-async-token" + + result = _ensure_async_token_provider(async_token_provider) + + # Should be the same function + assert result is async_token_provider + + # Call the provider + token = await result() + assert token == "my-async-token" + + @pytest.mark.asyncio + async def test_wrapped_sync_provider_called_correctly(self): + """Test that the wrapped sync provider calls the original function.""" + call_count = 0 + + def sync_token_provider(): + nonlocal call_count + call_count += 1 + return f"token-{call_count}" + + wrapped = _ensure_async_token_provider(sync_token_provider) + + # Call multiple times + token1 = await wrapped() + token2 = await wrapped() + + assert token1 == "token-1" + assert token2 == "token-2" + assert call_count == 2 + + def test_sync_provider_wrapping_logs_info(self): + """Test that wrapping a sync provider logs an info message.""" + + def sync_token_provider(): + return "token" + + with patch("pyrit.prompt_target.openai.openai_target.logger") as mock_logger: + _ensure_async_token_provider(sync_token_provider) + mock_logger.info.assert_called_once() + call_args = mock_logger.info.call_args[0][0] + assert "synchronous token provider" in call_args.lower() + assert "wrapping" in call_args.lower() + + +@pytest.mark.usefixtures("patch_central_database") +class TestOpenAITargetWithTokenProviders: + """Test OpenAITarget initialization with different token provider types.""" + + @pytest.mark.asyncio + async def test_openai_target_with_sync_token_provider(self): + """Test that OpenAIChatTarget works with synchronous token providers.""" + from pyrit.prompt_target import OpenAIChatTarget + + def sync_token_provider(): + return "sync-token-value" + + with ( + patch("pyrit.prompt_target.openai.openai_target.AsyncOpenAI") as mock_openai, + patch("pyrit.prompt_target.openai.openai_target.logger") as mock_logger, + ): + mock_client = AsyncMock() + mock_openai.return_value = mock_client + + target = OpenAIChatTarget( + endpoint="https://api.openai.com/v1", + model_name="gpt-4o", + api_key=sync_token_provider, + ) + + # Verify that info log was called about wrapping + mock_logger.info.assert_called() + info_call_found = False + for call in mock_logger.info.call_args_list: + if "synchronous token provider" in str(call).lower(): + info_call_found = True + break + assert info_call_found, "Expected info log about wrapping sync token provider" + + # Verify AsyncOpenAI was initialized + mock_openai.assert_called_once() + call_kwargs = mock_openai.call_args[1] + + # The api_key should be a callable + api_key_arg = call_kwargs["api_key"] + assert callable(api_key_arg) + assert asyncio.iscoroutinefunction(api_key_arg) + + # Verify the wrapped token provider returns correct value + token = await api_key_arg() + assert token == "sync-token-value" + + @pytest.mark.asyncio + async def test_openai_target_with_async_token_provider(self): + """Test that OpenAIChatTarget works with async token providers.""" + from pyrit.prompt_target import OpenAIChatTarget + + async def async_token_provider(): + return "async-token-value" + + with patch("pyrit.prompt_target.openai.openai_target.AsyncOpenAI") as mock_openai: + mock_client = AsyncMock() + mock_openai.return_value = mock_client + + target = OpenAIChatTarget( + endpoint="https://api.openai.com/v1", + model_name="gpt-4o", + api_key=async_token_provider, + ) + + # Verify AsyncOpenAI was initialized + mock_openai.assert_called_once() + call_kwargs = mock_openai.call_args[1] + + # The api_key should be the same async callable + api_key_arg = call_kwargs["api_key"] + assert api_key_arg is async_token_provider + + # Verify the token provider returns correct value + token = await api_key_arg() + assert token == "async-token-value" + + @pytest.mark.asyncio + async def test_openai_target_with_string_api_key(self): + """Test that OpenAIChatTarget works with string API keys.""" + from pyrit.prompt_target import OpenAIChatTarget + + with patch("pyrit.prompt_target.openai.openai_target.AsyncOpenAI") as mock_openai: + mock_client = AsyncMock() + mock_openai.return_value = mock_client + + target = OpenAIChatTarget( + endpoint="https://api.openai.com/v1", + model_name="gpt-4o", + api_key="sk-string-api-key", + ) + + # Verify AsyncOpenAI was initialized + mock_openai.assert_called_once() + call_kwargs = mock_openai.call_args[1] + + # The api_key should be a string + api_key_arg = call_kwargs["api_key"] + assert api_key_arg == "sk-string-api-key" + assert isinstance(api_key_arg, str) + + @pytest.mark.asyncio + async def test_azure_bearer_token_provider_integration(self): + """Test integration with Azure bearer token provider pattern.""" + from pyrit.prompt_target import OpenAIChatTarget + + # Simulate get_bearer_token_provider from azure.identity (sync version) + def mock_sync_bearer_token_provider(): + """Mock synchronous bearer token provider.""" + return "Bearer sync-azure-token" + + with ( + patch("pyrit.prompt_target.openai.openai_target.AsyncOpenAI") as mock_openai, + patch("pyrit.prompt_target.openai.openai_target.logger") as mock_logger, + ): + mock_client = AsyncMock() + mock_openai.return_value = mock_client + + target = OpenAIChatTarget( + endpoint="https://myresource.openai.azure.com/openai/v1", + model_name="gpt-4o", + api_key=mock_sync_bearer_token_provider, + ) + + # Verify that sync provider was wrapped + mock_logger.info.assert_called() + + # Get the wrapped api_key + call_kwargs = mock_openai.call_args[1] + wrapped_provider = call_kwargs["api_key"] + + assert asyncio.iscoroutinefunction(wrapped_provider) + + # Verify it returns the correct token + token = await wrapped_provider() + assert token == "Bearer sync-azure-token"