Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 48 additions & 3 deletions pyrit/prompt_target/openai/openai_target.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import asyncio
import json
import logging
import re
Expand Down Expand Up @@ -37,6 +38,45 @@
logger = logging.getLogger(__name__)


def _ensure_async_token_provider(
api_key: Optional[str | Callable[[], str | Awaitable[str]]],
) -> Optional[str | Callable[[], Awaitable[str]]]:
"""
Ensure the api_key is either a string or an async callable.

If a synchronous callable token provider is provided, it's automatically wrapped
in an async function to make it compatible with AsyncOpenAI.

Args:
api_key: Either a string API key or a callable that returns a token (sync or async).

Returns:
Either a string API key or an async callable that returns a token.
"""
if api_key is None or isinstance(api_key, str) or not callable(api_key):
return api_key

# Check if the callable is already async
if asyncio.iscoroutinefunction(api_key):
return api_key

# Wrap synchronous token provider in async function
logger.info(
"Detected synchronous token provider. Automatically wrapping in async function for compatibility with AsyncOpenAI."
)

async def async_token_provider() -> str:
"""
Async wrapper for synchronous token provider.

Returns:
str: The token string from the synchronous provider.
"""
return api_key() # type: ignore

return async_token_provider


class OpenAITarget(PromptChatTarget):
"""
Abstract base class for OpenAI-based prompt targets.
Expand Down Expand Up @@ -75,9 +115,11 @@ def __init__(
model_name (str, Optional): The name of the model (or name of deployment in Azure).
If no value is provided, the environment variable will be used (set by subclass).
endpoint (str, Optional): The target URL for the OpenAI service.
api_key (str | Callable[[], str], Optional): The API key for accessing the OpenAI service,
or a callable that returns an access token. For Azure endpoints with Entra authentication,
pass a token provider from pyrit.auth (e.g., get_azure_openai_auth(endpoint)).
api_key (str | Callable[[], str | Awaitable[str]], Optional): The API key for accessing the
OpenAI service, or a callable that returns an access token (sync or async).
For Azure endpoints with Entra authentication, pass a token provider from pyrit.auth
(e.g., get_azure_openai_auth(endpoint) for async, or get_azure_token_provider(scope) for sync).
Synchronous token providers are automatically wrapped to work with async clients.
Defaults to the target-specific API key environment variable.
headers (str, Optional): Extra headers of the endpoint (JSON).
max_requests_per_minute (int, Optional): Number of requests the target can handle per
Expand Down Expand Up @@ -129,6 +171,9 @@ def __init__(
env_var_name=self.api_key_environment_variable, passed_value=api_key
)

# Ensure api_key is async-compatible (wrap sync token providers if needed)
self._api_key = _ensure_async_token_provider(self._api_key)

self._initialize_openai_client()

def _extract_deployment_from_azure_url(self, url: str) -> str:
Expand Down
27 changes: 27 additions & 0 deletions tests/integration/targets/test_entra_auth_targets.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,3 +186,30 @@ async def test_prompt_shield_target_entra_auth(sqlite_instance):
result = await attack.execute_async(objective="test")
assert result is not None
assert result.last_response is not None


@pytest.mark.asyncio
async def test_openai_chat_target_with_sync_token_provider(sqlite_instance):
"""Test that OpenAIChatTarget works with synchronous token providers (auto-wrapped)."""
from azure.identity import DefaultAzureCredential, get_bearer_token_provider

endpoint = os.environ["AZURE_OPENAI_GPT4O_ENDPOINT"]
model_name = os.environ["AZURE_OPENAI_GPT4O_MODEL"]

# Use synchronous token provider - this should be auto-wrapped
sync_token_provider = get_bearer_token_provider(
DefaultAzureCredential(), "https://cognitiveservices.azure.com/.default"
)

target = OpenAIChatTarget(
endpoint=endpoint,
model_name=model_name,
api_key=sync_token_provider,
temperature=0.0,
seed=42,
)

attack = PromptSendingAttack(objective_target=target)
result = await attack.execute_async(objective="Hello, how are you?")
assert result is not None
assert result.last_response is not None
8 changes: 6 additions & 2 deletions tests/unit/target/test_openai_chat_target.py
Original file line number Diff line number Diff line change
Expand Up @@ -715,9 +715,13 @@ def mock_token_provider():
api_key=mock_token_provider,
)

# Verify token provider was stored as api_key
# Verify token provider was stored as api_key (wrapped in async)
assert callable(target._api_key)
assert target._api_key() == "mock-entra-token"
# Since sync provider is wrapped, _api_key is now async
import asyncio

assert asyncio.iscoroutinefunction(target._api_key)
assert asyncio.run(target._api_key()) == "mock-entra-token"


def test_set_auth_with_api_key(patch_central_database):
Expand Down
248 changes: 248 additions & 0 deletions tests/unit/target/test_token_provider_wrapping.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,248 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import asyncio
from unittest.mock import AsyncMock, patch

import pytest

from pyrit.prompt_target.openai.openai_target import _ensure_async_token_provider


class TestTokenProviderWrapping:
"""Test suite for synchronous token provider auto-wrapping functionality."""

def test_string_api_key_unchanged(self):
"""Test that string API keys are returned unchanged."""
api_key = "sk-test-key-12345"
result = _ensure_async_token_provider(api_key)
assert result == api_key
assert isinstance(result, str)

def test_none_api_key_unchanged(self):
"""Test that None is returned unchanged."""
result = _ensure_async_token_provider(None)
assert result is None

def test_async_token_provider_unchanged(self):
"""Test that async token providers are returned unchanged."""

async def async_token_provider():
return "async-token"

result = _ensure_async_token_provider(async_token_provider)
assert result is async_token_provider
assert asyncio.iscoroutinefunction(result)

def test_sync_token_provider_wrapped(self):
"""Test that synchronous token providers are automatically wrapped."""

def sync_token_provider():
return "sync-token"

result = _ensure_async_token_provider(sync_token_provider)

# Should return a different callable (the wrapper)
assert result is not sync_token_provider
assert callable(result)
assert asyncio.iscoroutinefunction(result)

@pytest.mark.asyncio
async def test_wrapped_sync_provider_returns_correct_token(self):
"""Test that wrapped synchronous token provider returns the correct token."""

def sync_token_provider():
return "my-sync-token"

wrapped = _ensure_async_token_provider(sync_token_provider)

# Call the wrapped provider
token = await wrapped()
assert token == "my-sync-token"

@pytest.mark.asyncio
async def test_async_provider_returns_correct_token(self):
"""Test that async token providers work correctly."""

async def async_token_provider():
return "my-async-token"

result = _ensure_async_token_provider(async_token_provider)

# Should be the same function
assert result is async_token_provider

# Call the provider
token = await result()
assert token == "my-async-token"

@pytest.mark.asyncio
async def test_wrapped_sync_provider_called_correctly(self):
"""Test that the wrapped sync provider calls the original function."""
call_count = 0

def sync_token_provider():
nonlocal call_count
call_count += 1
return f"token-{call_count}"

wrapped = _ensure_async_token_provider(sync_token_provider)

# Call multiple times
token1 = await wrapped()
token2 = await wrapped()

assert token1 == "token-1"
assert token2 == "token-2"
assert call_count == 2

def test_sync_provider_wrapping_logs_info(self):
"""Test that wrapping a sync provider logs an info message."""

def sync_token_provider():
return "token"

with patch("pyrit.prompt_target.openai.openai_target.logger") as mock_logger:
_ensure_async_token_provider(sync_token_provider)
mock_logger.info.assert_called_once()
call_args = mock_logger.info.call_args[0][0]
assert "synchronous token provider" in call_args.lower()
assert "wrapping" in call_args.lower()


@pytest.mark.usefixtures("patch_central_database")
class TestOpenAITargetWithTokenProviders:
"""Test OpenAITarget initialization with different token provider types."""

@pytest.mark.asyncio
async def test_openai_target_with_sync_token_provider(self):
"""Test that OpenAIChatTarget works with synchronous token providers."""
from pyrit.prompt_target import OpenAIChatTarget

def sync_token_provider():
return "sync-token-value"

with (
patch("pyrit.prompt_target.openai.openai_target.AsyncOpenAI") as mock_openai,
patch("pyrit.prompt_target.openai.openai_target.logger") as mock_logger,
):
mock_client = AsyncMock()
mock_openai.return_value = mock_client

target = OpenAIChatTarget(
endpoint="https://api.openai.com/v1",
model_name="gpt-4o",
api_key=sync_token_provider,
)

# Verify that info log was called about wrapping
mock_logger.info.assert_called()
info_call_found = False
for call in mock_logger.info.call_args_list:
if "synchronous token provider" in str(call).lower():
info_call_found = True
break
assert info_call_found, "Expected info log about wrapping sync token provider"

# Verify AsyncOpenAI was initialized
mock_openai.assert_called_once()
call_kwargs = mock_openai.call_args[1]

# The api_key should be a callable
api_key_arg = call_kwargs["api_key"]
assert callable(api_key_arg)
assert asyncio.iscoroutinefunction(api_key_arg)

# Verify the wrapped token provider returns correct value
token = await api_key_arg()
assert token == "sync-token-value"

@pytest.mark.asyncio
async def test_openai_target_with_async_token_provider(self):
"""Test that OpenAIChatTarget works with async token providers."""
from pyrit.prompt_target import OpenAIChatTarget

async def async_token_provider():
return "async-token-value"

with patch("pyrit.prompt_target.openai.openai_target.AsyncOpenAI") as mock_openai:
mock_client = AsyncMock()
mock_openai.return_value = mock_client

target = OpenAIChatTarget(
endpoint="https://api.openai.com/v1",
model_name="gpt-4o",
api_key=async_token_provider,
)

# Verify AsyncOpenAI was initialized
mock_openai.assert_called_once()
call_kwargs = mock_openai.call_args[1]

# The api_key should be the same async callable
api_key_arg = call_kwargs["api_key"]
assert api_key_arg is async_token_provider

# Verify the token provider returns correct value
token = await api_key_arg()
assert token == "async-token-value"

@pytest.mark.asyncio
async def test_openai_target_with_string_api_key(self):
"""Test that OpenAIChatTarget works with string API keys."""
from pyrit.prompt_target import OpenAIChatTarget

with patch("pyrit.prompt_target.openai.openai_target.AsyncOpenAI") as mock_openai:
mock_client = AsyncMock()
mock_openai.return_value = mock_client

target = OpenAIChatTarget(
endpoint="https://api.openai.com/v1",
model_name="gpt-4o",
api_key="sk-string-api-key",
)

# Verify AsyncOpenAI was initialized
mock_openai.assert_called_once()
call_kwargs = mock_openai.call_args[1]

# The api_key should be a string
api_key_arg = call_kwargs["api_key"]
assert api_key_arg == "sk-string-api-key"
assert isinstance(api_key_arg, str)

@pytest.mark.asyncio
async def test_azure_bearer_token_provider_integration(self):
"""Test integration with Azure bearer token provider pattern."""
from pyrit.prompt_target import OpenAIChatTarget

# Simulate get_bearer_token_provider from azure.identity (sync version)
def mock_sync_bearer_token_provider():
"""Mock synchronous bearer token provider."""
return "Bearer sync-azure-token"

with (
patch("pyrit.prompt_target.openai.openai_target.AsyncOpenAI") as mock_openai,
patch("pyrit.prompt_target.openai.openai_target.logger") as mock_logger,
):
mock_client = AsyncMock()
mock_openai.return_value = mock_client

target = OpenAIChatTarget(
endpoint="https://myresource.openai.azure.com/openai/v1",
model_name="gpt-4o",
api_key=mock_sync_bearer_token_provider,
)

# Verify that sync provider was wrapped
mock_logger.info.assert_called()

# Get the wrapped api_key
call_kwargs = mock_openai.call_args[1]
wrapped_provider = call_kwargs["api_key"]

assert asyncio.iscoroutinefunction(wrapped_provider)

# Verify it returns the correct token
token = await wrapped_provider()
assert token == "Bearer sync-azure-token"