diff --git a/.env_example b/.env_example index 2d63d6691..d77940bcd 100644 --- a/.env_example +++ b/.env_example @@ -35,6 +35,16 @@ AZURE_OPENAI_GPT4_CHAT_ENDPOINT="https://xxxxx.openai.azure.com/openai/v1" AZURE_OPENAI_GPT4_CHAT_KEY="xxxxx" AZURE_OPENAI_GPT4_CHAT_MODEL="deployment-name" +# Endpoints that host models with fewer safety mechanisms (e.g. via adversarial fine tuning +# or content filters turned off) can be defined below and used in adversarial attack testing scenarios. +AZURE_OPENAI_GPT4O_UNSAFE_CHAT_ENDPOINT="https://xxxxx.openai.azure.com/openai/v1" +AZURE_OPENAI_GPT4O_UNSAFE_CHAT_KEY="xxxxx" +AZURE_OPENAI_GPT4O_UNSAFE_CHAT_MODEL="deployment-name" + +AZURE_OPENAI_GPT4O_UNSAFE_CHAT_ENDPOINT2="https://xxxxx.openai.azure.com/openai/v1" +AZURE_OPENAI_GPT4O_UNSAFE_CHAT_KEY2="xxxxx" +AZURE_OPENAI_GPT4O_UNSAFE_CHAT_MODEL2="deployment-name" + AZURE_FOUNDRY_DEEPSEEK_ENDPOINT="https://xxxxx.eastus2.models.ai.azure.com" AZURE_FOUNDRY_DEEPSEEK_KEY="xxxxx" diff --git a/doc/api.rst b/doc/api.rst index c774deca0..99273143f 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -703,6 +703,7 @@ API Reference PyRITInitializer AIRTInitializer + AIRTTargetInitializer SimpleInitializer LoadDefaultDatasets ScenarioObjectiveListInitializer diff --git a/doc/code/registry/2_instance_registry.ipynb b/doc/code/registry/2_instance_registry.ipynb index 24a8b1bb6..52ce37405 100644 --- a/doc/code/registry/2_instance_registry.ipynb +++ b/doc/code/registry/2_instance_registry.ipynb @@ -35,10 +35,9 @@ "name": "stdout", "output_type": "stream", "text": [ - "Found default environment files: ['C:\\\\Users\\\\rlundeen\\\\.pyrit\\\\.env', 'C:\\\\Users\\\\rlundeen\\\\.pyrit\\\\.env.local']\n", - "Loaded environment file: C:\\Users\\rlundeen\\.pyrit\\.env\n", - "Loaded environment file: C:\\Users\\rlundeen\\.pyrit\\.env.local\n", - "Registered scorers: ['self_ask_refusal_d9007ba2']\n" + "Found default environment files: ['C:\\\\Users\\\\songjustin\\\\.pyrit\\\\.env']\n", + "Loaded environment file: 
C:\\Users\\songjustin\\.pyrit\\.env\n", + "Registered scorers: ['self_ask_refusal_scorer::94a582f5']\n" ] } ], @@ -83,7 +82,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "Retrieved scorer: \n", + "Retrieved scorer: \n", "Scorer type: SelfAskRefusalScorer\n" ] } @@ -118,7 +117,7 @@ "output_type": "stream", "text": [ "\n", - "self_ask_refusal_d9007ba2:\n", + "self_ask_refusal_scorer::94a582f5:\n", " Class: SelfAskRefusalScorer\n", " Type: true_false\n", " Description: A self-ask scorer that detects refusal in AI responses. This...\n", @@ -126,7 +125,7 @@ "\u001b[1m 📊 Scorer Information\u001b[0m\n", "\u001b[37m ▸ Scorer Identifier\u001b[0m\n", "\u001b[36m • Scorer Type: SelfAskRefusalScorer\u001b[0m\n", - "\u001b[36m • Target Model: gpt-40\u001b[0m\n", + "\u001b[36m • Target Model: gpt-4o\u001b[0m\n", "\u001b[36m • Temperature: None\u001b[0m\n", "\u001b[36m • Score Aggregator: OR_\u001b[0m\n", "\n", @@ -141,12 +140,12 @@ "# Get metadata for all registered scorers\n", "metadata = registry.list_metadata()\n", "for item in metadata:\n", - " print(f\"\\n{item.name}:\")\n", + " print(f\"\\n{item.unique_name}:\")\n", " print(f\" Class: {item.class_name}\")\n", " print(f\" Type: {item.scorer_type}\")\n", - " print(f\" Description: {item.description[:60]}...\")\n", + " print(f\" Description: {item.class_description[:60]}...\")\n", "\n", - " ConsoleScorerPrinter().print_objective_scorer(scorer_identifier=item.scorer_identifier)" + " ConsoleScorerPrinter().print_objective_scorer(scorer_identifier=item)" ] }, { @@ -169,26 +168,69 @@ "name": "stdout", "output_type": "stream", "text": [ - "True/False scorers: ['self_ask_refusal_d9007ba2']\n", - "Refusal scorers: ['self_ask_refusal_d9007ba2']\n", - "True/False refusal scorers: ['self_ask_refusal_d9007ba2']\n" + "True/False scorers: ['self_ask_refusal_scorer::94a582f5']\n", + "Refusal scorers: ['self_ask_refusal_scorer::94a582f5']\n", + "True/False refusal scorers: ['self_ask_refusal_scorer::94a582f5']\n" ] } ], 
"source": [ "# Filter by scorer_type (based on isinstance check against TrueFalseScorer/FloatScaleScorer)\n", "true_false_scorers = registry.list_metadata(include_filters={\"scorer_type\": \"true_false\"})\n", - "print(f\"True/False scorers: {[m.name for m in true_false_scorers]}\")\n", + "print(f\"True/False scorers: {[m.unique_name for m in true_false_scorers]}\")\n", "\n", "# Filter by class_name\n", "refusal_scorers = registry.list_metadata(include_filters={\"class_name\": \"SelfAskRefusalScorer\"})\n", - "print(f\"Refusal scorers: {[m.name for m in refusal_scorers]}\")\n", + "print(f\"Refusal scorers: {[m.unique_name for m in refusal_scorers]}\")\n", "\n", "# Combine multiple filters (AND logic)\n", "specific_scorers = registry.list_metadata(\n", " include_filters={\"scorer_type\": \"true_false\", \"class_name\": \"SelfAskRefusalScorer\"}\n", ")\n", - "print(f\"True/False refusal scorers: {[m.name for m in specific_scorers]}\")" + "print(f\"True/False refusal scorers: {[m.unique_name for m in specific_scorers]}\")" + ] + }, + { + "cell_type": "markdown", + "id": "9", + "metadata": {}, + "source": [ + "## Using Target Initializer\n", + "\n", + "You can optionally use the `AIRTTargetInitializer` to automatically configure and register targets that use commonly used environment variables (from `.env_example`). This initializer does not strictly require any environment variables - it simply registers whatever endpoints are available." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "10", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Found default environment files: ['C:\\\\Users\\\\songjustin\\\\.pyrit\\\\.env']\n", + "Loaded environment file: C:\\Users\\songjustin\\.pyrit\\.env\n", + "Registered targets after AIRT initialization: ['azure_content_safety', 'azure_gpt4o_unsafe_chat', 'azure_gpt4o_unsafe_chat2', 'default_openai_frontend', 'openai_chat', 'openai_image', 'openai_realtime', 'openai_responses', 'openai_tts', 'openai_video']\n" + ] + } + ], + "source": [ + "from pyrit.registry import TargetRegistry\n", + "from pyrit.setup import initialize_pyrit_async\n", + "from pyrit.setup.initializers import AIRTTargetInitializer\n", + "\n", + "# Using built-in initializer\n", + "await initialize_pyrit_async( # type: ignore\n", + " memory_db_type=\"InMemory\", initializers=[AIRTTargetInitializer()]\n", + ")\n", + "\n", + "# Get the registry singleton\n", + "registry = TargetRegistry.get_registry_singleton()\n", + "# List registered targets\n", + "target_names = registry.get_names()\n", + "print(f\"Registered targets after AIRT initialization: {target_names}\")" ] } ], @@ -203,7 +245,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.13.5" + "version": "3.11.9" } }, "nbformat": 4, diff --git a/doc/code/registry/2_instance_registry.py b/doc/code/registry/2_instance_registry.py index c20755730..d645529f2 100644 --- a/doc/code/registry/2_instance_registry.py +++ b/doc/code/registry/2_instance_registry.py @@ -5,11 +5,15 @@ # extension: .py # format_name: percent # format_version: '1.3' -# jupytext_version: 1.18.1 +# jupytext_version: 1.17.2 +# kernelspec: +# display_name: pyrit-dev +# language: python +# name: python3 # --- # %% [markdown] -# ## Why Instance Registries? +# # Why Instance Registries? 
# # Some components need configuration that can't easily be passed at instantiation time. For example, scorers often need: # - A configured `chat_target` for LLM-based scoring @@ -19,7 +23,7 @@ # Instance registries let initializers register fully-configured instances that are ready to use. # %% [markdown] -# # Listing Available Instances +# ## Listing Available Instances # # Use `get_names()` to see registered instances, or `list_metadata()` for details. @@ -67,12 +71,12 @@ # Get metadata for all registered scorers metadata = registry.list_metadata() for item in metadata: - print(f"\n{item.name}:") + print(f"\n{item.unique_name}:") print(f" Class: {item.class_name}") print(f" Type: {item.scorer_type}") - print(f" Description: {item.description[:60]}...") + print(f" Description: {item.class_description[:60]}...") - ConsoleScorerPrinter().print_objective_scorer(scorer_identifier=item.scorer_identifier) + ConsoleScorerPrinter().print_objective_scorer(scorer_identifier=item) # %% [markdown] # ## Filtering @@ -82,14 +86,35 @@ # %% # Filter by scorer_type (based on isinstance check against TrueFalseScorer/FloatScaleScorer) true_false_scorers = registry.list_metadata(include_filters={"scorer_type": "true_false"}) -print(f"True/False scorers: {[m.name for m in true_false_scorers]}") +print(f"True/False scorers: {[m.unique_name for m in true_false_scorers]}") # Filter by class_name refusal_scorers = registry.list_metadata(include_filters={"class_name": "SelfAskRefusalScorer"}) -print(f"Refusal scorers: {[m.name for m in refusal_scorers]}") +print(f"Refusal scorers: {[m.unique_name for m in refusal_scorers]}") # Combine multiple filters (AND logic) specific_scorers = registry.list_metadata( include_filters={"scorer_type": "true_false", "class_name": "SelfAskRefusalScorer"} ) -print(f"True/False refusal scorers: {[m.name for m in specific_scorers]}") +print(f"True/False refusal scorers: {[m.unique_name for m in specific_scorers]}") + +# %% [markdown] +# ## Using Target 
Initializer +# +# You can optionally use the `AIRTTargetInitializer` to automatically configure and register targets that use commonly used environment variables (from `.env_example`). This initializer does not strictly require any environment variables - it simply registers whatever endpoints are available. + +# %% +from pyrit.registry import TargetRegistry +from pyrit.setup import initialize_pyrit_async +from pyrit.setup.initializers import AIRTTargetInitializer + +# Using built-in initializer +await initialize_pyrit_async( # type: ignore + memory_db_type="InMemory", initializers=[AIRTTargetInitializer()] +) + +# Get the registry singleton +registry = TargetRegistry.get_registry_singleton() +# List registered targets +target_names = registry.get_names() +print(f"Registered targets after AIRT initialization: {target_names}") diff --git a/pyrit/identifiers/target_identifier.py b/pyrit/identifiers/target_identifier.py index f08ad8709..b8924fb0c 100644 --- a/pyrit/identifiers/target_identifier.py +++ b/pyrit/identifiers/target_identifier.py @@ -34,6 +34,9 @@ class TargetIdentifier(Identifier): max_requests_per_minute: Optional[int] = None """Maximum number of requests per minute.""" + supports_conversation_history: bool = False + """Whether the target supports explicit setting of conversation history (is a PromptChatTarget).""" + target_specific_params: Optional[Dict[str, Any]] = None """Additional target-specific parameters.""" diff --git a/pyrit/prompt_target/common/prompt_target.py b/pyrit/prompt_target/common/prompt_target.py index 8cd80f47d..653d008e6 100644 --- a/pyrit/prompt_target/common/prompt_target.py +++ b/pyrit/prompt_target/common/prompt_target.py @@ -122,6 +122,9 @@ def _create_identifier( elif self._model_name: model_name = self._model_name + # Late import to avoid circular dependency + from pyrit.prompt_target.common.prompt_chat_target import PromptChatTarget + return TargetIdentifier( class_name=self.__class__.__name__, 
class_module=self.__class__.__module__, @@ -132,6 +135,7 @@ def _create_identifier( temperature=temperature, top_p=top_p, max_requests_per_minute=self._max_requests_per_minute, + supports_conversation_history=isinstance(self, PromptChatTarget), target_specific_params=target_specific_params, ) diff --git a/pyrit/registry/__init__.py b/pyrit/registry/__init__.py index 209ec6c14..5f2fe7536 100644 --- a/pyrit/registry/__init__.py +++ b/pyrit/registry/__init__.py @@ -21,6 +21,7 @@ from pyrit.registry.instance_registries import ( BaseInstanceRegistry, ScorerRegistry, + TargetRegistry, ) __all__ = [ @@ -39,4 +40,5 @@ "ScenarioMetadata", "ScenarioRegistry", "ScorerRegistry", + "TargetRegistry", ] diff --git a/pyrit/registry/instance_registries/__init__.py b/pyrit/registry/instance_registries/__init__.py index eab870f0e..2cf50693c 100644 --- a/pyrit/registry/instance_registries/__init__.py +++ b/pyrit/registry/instance_registries/__init__.py @@ -17,10 +17,14 @@ from pyrit.registry.instance_registries.scorer_registry import ( ScorerRegistry, ) +from pyrit.registry.instance_registries.target_registry import ( + TargetRegistry, +) __all__ = [ # Base class "BaseInstanceRegistry", # Concrete registries "ScorerRegistry", + "TargetRegistry", ] diff --git a/pyrit/registry/instance_registries/target_registry.py b/pyrit/registry/instance_registries/target_registry.py new file mode 100644 index 000000000..3fcdbb316 --- /dev/null +++ b/pyrit/registry/instance_registries/target_registry.py @@ -0,0 +1,97 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Target registry for discovering and managing PyRIT prompt targets. + +Targets are registered explicitly via initializers as pre-configured instances. 
+""" + +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING, Optional + +from pyrit.identifiers import TargetIdentifier +from pyrit.registry.instance_registries.base_instance_registry import ( + BaseInstanceRegistry, +) + +if TYPE_CHECKING: + from pyrit.prompt_target import PromptTarget + +logger = logging.getLogger(__name__) + + +class TargetRegistry(BaseInstanceRegistry["PromptTarget", TargetIdentifier]): + """ + Registry for managing available prompt target instances. + + This registry stores pre-configured PromptTarget instances (not classes). + Targets are registered explicitly via initializers after being instantiated + with their required parameters (e.g., endpoint, API keys). + + Targets are identified by the unique_name of their identifier by default, + or by a custom name provided during registration. + """ + + @classmethod + def get_registry_singleton(cls) -> "TargetRegistry": + """ + Get the singleton instance of the TargetRegistry. + + Returns: + The singleton TargetRegistry instance. + """ + return super().get_registry_singleton() # type: ignore[return-value] + + def register_instance( + self, + target: "PromptTarget", + *, + name: Optional[str] = None, + ) -> None: + """ + Register a target instance. + + Note: Unlike ScenarioRegistry and InitializerRegistry which register classes, + TargetRegistry registers pre-configured instances. + + Args: + target: The pre-configured target instance (not a class). + name: Optional custom registry name. If not provided, + the unique_name of the target's identifier is used + (the class-derived name with the identifier hash appended). + """ + if name is None: + name = target.get_identifier().unique_name + + self.register(target, name=name) + logger.debug(f"Registered target instance: {name} ({target.__class__.__name__})") + + def get_instance_by_name(self, name: str) -> Optional["PromptTarget"]: + """ + Get a registered target instance by name. 
+ + Note: This returns an already-instantiated target, not a class. + + Args: + name: The registry name of the target. + + Returns: + The target instance, or None if not found. + """ + return self.get(name) + + def _build_metadata(self, name: str, instance: "PromptTarget") -> TargetIdentifier: + """ + Build metadata for a target instance. + + Args: + name: The registry name of the target. + instance: The target instance. + + Returns: + TargetIdentifier describing the target. + """ + return instance.get_identifier() diff --git a/pyrit/setup/initializers/__init__.py b/pyrit/setup/initializers/__init__.py index 1c0cbd468..6b1c63c48 100644 --- a/pyrit/setup/initializers/__init__.py +++ b/pyrit/setup/initializers/__init__.py @@ -4,6 +4,7 @@ """PyRIT initializers package.""" from pyrit.setup.initializers.airt import AIRTInitializer +from pyrit.setup.initializers.airt_targets import AIRTTargetInitializer from pyrit.setup.initializers.pyrit_initializer import PyRITInitializer from pyrit.setup.initializers.scenarios.load_default_datasets import LoadDefaultDatasets from pyrit.setup.initializers.scenarios.objective_list import ScenarioObjectiveListInitializer @@ -13,6 +14,7 @@ __all__ = [ "PyRITInitializer", "AIRTInitializer", + "AIRTTargetInitializer", "SimpleInitializer", "LoadDefaultDatasets", "ScenarioObjectiveListInitializer", diff --git a/pyrit/setup/initializers/airt_targets.py b/pyrit/setup/initializers/airt_targets.py new file mode 100644 index 000000000..be98380ae --- /dev/null +++ b/pyrit/setup/initializers/airt_targets.py @@ -0,0 +1,223 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +AIRT Target Initializer for registering pre-configured targets from environment variables. + +This module provides the AIRTTargetInitializer class that registers available +targets into the TargetRegistry based on environment variable configuration. 
+""" + +import logging +import os +from dataclasses import dataclass +from typing import Any, List, Optional, Type + +from pyrit.prompt_target import ( + OpenAIChatTarget, + OpenAIImageTarget, + OpenAIResponseTarget, + OpenAITTSTarget, + OpenAIVideoTarget, + PromptShieldTarget, + PromptTarget, + RealtimeTarget, +) +from pyrit.registry import TargetRegistry +from pyrit.setup.initializers.pyrit_initializer import PyRITInitializer + +logger = logging.getLogger(__name__) + + +@dataclass +class TargetConfig: + """Configuration for a target to be registered.""" + + registry_name: str + target_class: Type[PromptTarget] + endpoint_var: str + key_var: str + model_var: Optional[str] = None + underlying_model_var: Optional[str] = None + + +# Define all supported target configurations +TARGET_CONFIGS: List[TargetConfig] = [ + TargetConfig( + registry_name="default_openai_frontend", + target_class=OpenAIChatTarget, + endpoint_var="DEFAULT_OPENAI_FRONTEND_ENDPOINT", + key_var="DEFAULT_OPENAI_FRONTEND_KEY", + model_var="DEFAULT_OPENAI_FRONTEND_MODEL", + underlying_model_var="DEFAULT_OPENAI_FRONTEND_UNDERLYING_MODEL", + ), + TargetConfig( + registry_name="openai_chat", + target_class=OpenAIChatTarget, + endpoint_var="OPENAI_CHAT_ENDPOINT", + key_var="OPENAI_CHAT_KEY", + model_var="OPENAI_CHAT_MODEL", + underlying_model_var="OPENAI_CHAT_UNDERLYING_MODEL", + ), + TargetConfig( + registry_name="openai_responses", + target_class=OpenAIResponseTarget, + endpoint_var="OPENAI_RESPONSES_ENDPOINT", + key_var="OPENAI_RESPONSES_KEY", + model_var="OPENAI_RESPONSES_MODEL", + underlying_model_var="OPENAI_RESPONSES_UNDERLYING_MODEL", + ), + TargetConfig( + registry_name="azure_gpt4o_unsafe_chat", + target_class=OpenAIChatTarget, + endpoint_var="AZURE_OPENAI_GPT4O_UNSAFE_CHAT_ENDPOINT", + key_var="AZURE_OPENAI_GPT4O_UNSAFE_CHAT_KEY", + model_var="AZURE_OPENAI_GPT4O_UNSAFE_CHAT_MODEL", + underlying_model_var="AZURE_OPENAI_GPT4O_UNSAFE_CHAT_UNDERLYING_MODEL", + ), + TargetConfig( + 
registry_name="azure_gpt4o_unsafe_chat2", + target_class=OpenAIChatTarget, + endpoint_var="AZURE_OPENAI_GPT4O_UNSAFE_CHAT_ENDPOINT2", + key_var="AZURE_OPENAI_GPT4O_UNSAFE_CHAT_KEY2", + model_var="AZURE_OPENAI_GPT4O_UNSAFE_CHAT_MODEL2", + underlying_model_var="AZURE_OPENAI_GPT4O_UNSAFE_CHAT_UNDERLYING_MODEL2", + ), + TargetConfig( + registry_name="openai_realtime", + target_class=RealtimeTarget, + endpoint_var="OPENAI_REALTIME_ENDPOINT", + key_var="OPENAI_REALTIME_API_KEY", + model_var="OPENAI_REALTIME_MODEL", + underlying_model_var="OPENAI_REALTIME_UNDERLYING_MODEL", + ), + TargetConfig( + registry_name="openai_image", + target_class=OpenAIImageTarget, + endpoint_var="OPENAI_IMAGE_ENDPOINT", + key_var="OPENAI_IMAGE_API_KEY", + model_var="OPENAI_IMAGE_MODEL", + underlying_model_var="OPENAI_IMAGE_UNDERLYING_MODEL", + ), + TargetConfig( + registry_name="openai_tts", + target_class=OpenAITTSTarget, + endpoint_var="OPENAI_TTS_ENDPOINT", + key_var="OPENAI_TTS_KEY", + model_var="OPENAI_TTS_MODEL", + underlying_model_var="OPENAI_TTS_UNDERLYING_MODEL", + ), + TargetConfig( + registry_name="openai_video", + target_class=OpenAIVideoTarget, + endpoint_var="OPENAI_VIDEO_ENDPOINT", + key_var="OPENAI_VIDEO_KEY", + model_var="OPENAI_VIDEO_MODEL", + underlying_model_var="OPENAI_VIDEO_UNDERLYING_MODEL", + ), + TargetConfig( + registry_name="azure_content_safety", + target_class=PromptShieldTarget, + endpoint_var="AZURE_CONTENT_SAFETY_API_ENDPOINT", + key_var="AZURE_CONTENT_SAFETY_API_KEY", + ), +] + + +class AIRTTargetInitializer(PyRITInitializer): + """ + AIRT Target Initializer for registering pre-configured targets. + + This initializer scans for known endpoint environment variables and registers + the corresponding targets into the TargetRegistry. Unlike AIRTInitializer, + this initializer does not require any environment variables - it simply + registers whatever endpoints are available. 
+ + Supported Endpoints: + - DEFAULT_OPENAI_FRONTEND_ENDPOINT: Default OpenAI frontend (OpenAIChatTarget) + - OPENAI_CHAT_ENDPOINT: OpenAI Chat API (OpenAIChatTarget) + - OPENAI_RESPONSES_ENDPOINT: OpenAI Responses API (OpenAIResponseTarget) + - AZURE_OPENAI_GPT4O_UNSAFE_CHAT_ENDPOINT: Azure OpenAI GPT-4o unsafe (OpenAIChatTarget) + - AZURE_OPENAI_GPT4O_UNSAFE_CHAT_ENDPOINT2: Azure OpenAI GPT-4o unsafe secondary (OpenAIChatTarget) + - OPENAI_REALTIME_ENDPOINT: OpenAI Realtime API (RealtimeTarget) + - OPENAI_IMAGE_ENDPOINT: OpenAI Image Generation (OpenAIImageTarget) + - OPENAI_TTS_ENDPOINT: OpenAI Text-to-Speech (OpenAITTSTarget) + - OPENAI_VIDEO_ENDPOINT: OpenAI Video Generation (OpenAIVideoTarget) + - AZURE_CONTENT_SAFETY_API_ENDPOINT: Azure Content Safety (PromptShieldTarget) + + Example: + initializer = AIRTTargetInitializer() + await initializer.initialize_async() + """ + + def __init__(self) -> None: + """Initialize the AIRT Target Initializer.""" + super().__init__() + + @property + def name(self) -> str: + """Get the name of this initializer.""" + return "AIRT Target Initializer" + + @property + def description(self) -> str: + """Get the description of this initializer.""" + return ( + "Instantiates a collection of (AI Red Team suggested) targets from " + "available environment variables and adds them to the TargetRegistry" + ) + + @property + def required_env_vars(self) -> List[str]: + """ + Get list of required environment variables. + + Returns empty list since this initializer is optional - it registers + whatever endpoints are available without requiring any. + """ + return [] + + async def initialize_async(self) -> None: + """ + Register available targets based on environment variables. + + Scans for known endpoint environment variables and registers the + corresponding targets into the TargetRegistry. 
+ """ + for config in TARGET_CONFIGS: + self._register_target(config) + + def _register_target(self, config: TargetConfig) -> None: + """ + Register a target if its required environment variables are set. + + Args: + config: The target configuration specifying env vars and target class. + """ + endpoint = os.getenv(config.endpoint_var) + api_key = os.getenv(config.key_var) + + if not endpoint or not api_key: + return + + model_name = os.getenv(config.model_var) if config.model_var else None + underlying_model = os.getenv(config.underlying_model_var) if config.underlying_model_var else None + + # Build kwargs for the target constructor + kwargs: dict[str, Any] = { + "endpoint": endpoint, + "api_key": api_key, + } + + # Only add model_name if the target supports it (PromptShieldTarget doesn't) + if model_name is not None: + kwargs["model_name"] = model_name + + # Add underlying_model if specified (for Azure deployments where name differs from model) + if underlying_model is not None: + kwargs["underlying_model"] = underlying_model + + target = config.target_class(**kwargs) + registry = TargetRegistry.get_registry_singleton() + registry.register_instance(target, name=config.registry_name) + logger.info(f"Registered target: {config.registry_name}") diff --git a/tests/unit/identifiers/test_target_identifier.py b/tests/unit/identifiers/test_target_identifier.py index 0541b36be..148c60983 100644 --- a/tests/unit/identifiers/test_target_identifier.py +++ b/tests/unit/identifiers/test_target_identifier.py @@ -500,6 +500,63 @@ def test_can_use_as_dict_key(self): assert d[identifier] == "value" +class TestTargetIdentifierSupportsConversationHistory: + """Test the supports_conversation_history field in TargetIdentifier.""" + + def test_supports_conversation_history_defaults_to_false(self): + """Test that supports_conversation_history defaults to False.""" + identifier = TargetIdentifier( + class_name="TestTarget", + class_module="pyrit.prompt_target.test_target", + 
class_description="A test target", + identifier_type="instance", + ) + + assert identifier.supports_conversation_history is False + + def test_supports_conversation_history_included_in_hash(self): + """Test that supports_conversation_history affects the hash.""" + base_args = { + "class_name": "TestTarget", + "class_module": "pyrit.prompt_target.test_target", + "class_description": "A test target", + "identifier_type": "instance", + } + + identifier1 = TargetIdentifier(supports_conversation_history=False, **base_args) + identifier2 = TargetIdentifier(supports_conversation_history=True, **base_args) + + assert identifier1.hash != identifier2.hash + + def test_supports_conversation_history_in_to_dict(self): + """Test that supports_conversation_history is included in to_dict.""" + identifier = TargetIdentifier( + class_name="TestChatTarget", + class_module="pyrit.prompt_target.test_chat_target", + class_description="A test chat target", + identifier_type="instance", + supports_conversation_history=True, + ) + + result = identifier.to_dict() + + assert result["supports_conversation_history"] is True + + def test_supports_conversation_history_from_dict(self): + """Test that supports_conversation_history is restored from dict.""" + data = { + "class_name": "TestChatTarget", + "class_module": "pyrit.prompt_target.test_chat_target", + "class_description": "A test chat target", + "identifier_type": "instance", + "supports_conversation_history": True, + } + + identifier = TargetIdentifier.from_dict(data) + + assert identifier.supports_conversation_history is True + + class TestTargetIdentifierNormalize: """Test the normalize class method for TargetIdentifier.""" diff --git a/tests/unit/registry/test_target_registry.py b/tests/unit/registry/test_target_registry.py new file mode 100644 index 000000000..8e32411b8 --- /dev/null +++ b/tests/unit/registry/test_target_registry.py @@ -0,0 +1,277 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. 
+
+
+import pytest
+
+from pyrit.identifiers import TargetIdentifier
+from pyrit.models import Message, MessagePiece
+from pyrit.prompt_target import PromptTarget
+from pyrit.prompt_target.common.prompt_chat_target import PromptChatTarget
+from pyrit.registry.instance_registries.target_registry import TargetRegistry
+
+
+class MockPromptTarget(PromptTarget):
+    """Mock PromptTarget for testing."""
+
+    # NOTE: the docstring above is asserted on by the class_description test in this
+    # module, so its wording is part of the test data and must not be reworded.
+
+    def __init__(self, *, model_name: str = "mock_model") -> None:
+        super().__init__(model_name=model_name)
+
+    async def send_prompt_async(self, *, message: Message) -> list[Message]:
+        # The incoming message is ignored; a fixed assistant reply is always returned.
+        reply_piece = MessagePiece(role="assistant", original_value="mock response")
+        return [reply_piece.to_message()]
+
+    def _validate_request(self, *, message: Message) -> None:
+        # The mock accepts every request.
+        return None
+
+
+class MockPromptChatTarget(PromptChatTarget):
+    """Mock PromptChatTarget for testing conversation history support."""
+
+    def __init__(self, *, model_name: str = "mock_chat_model", endpoint: str = "http://chat-test") -> None:
+        super().__init__(model_name=model_name, endpoint=endpoint)
+
+    async def send_prompt_async(self, *, message: Message) -> list[Message]:
+        # The incoming message is ignored; a fixed assistant reply is always returned.
+        reply_piece = MessagePiece(role="assistant", original_value="chat response")
+        return [reply_piece.to_message()]
+
+    def _validate_request(self, *, message: Message) -> None:
+        # The mock accepts every request.
+        return None
+
+    def is_json_response_supported(self) -> bool:
+        return False
+
+
+class TestTargetRegistrySingleton:
+    """Singleton behaviour of TargetRegistry."""
+
+    def setup_method(self):
+        """Make sure no registry instance is cached when a test starts."""
+        TargetRegistry.reset_instance()
+
+    def teardown_method(self):
+        """Discard whatever registry instance the test created."""
+        TargetRegistry.reset_instance()
+
+    def test_get_registry_singleton_returns_same_instance(self):
+        """Two consecutive lookups hand back the very same object."""
+        first = TargetRegistry.get_registry_singleton()
+        second = TargetRegistry.get_registry_singleton()
+
+        assert first is second
+
+    def test_get_registry_singleton_returns_target_registry_type(self):
+        """The singleton accessor produces a TargetRegistry."""
+        assert isinstance(TargetRegistry.get_registry_singleton(), TargetRegistry)
+
+    def test_reset_instance_clears_singleton(self):
+        """After reset_instance a lookup yields a different object than before."""
+        before_reset = TargetRegistry.get_registry_singleton()
+        TargetRegistry.reset_instance()
+        after_reset = TargetRegistry.get_registry_singleton()
+
+        assert after_reset is not before_reset
+
+
+@pytest.mark.usefixtures("patch_central_database")
+class TestTargetRegistryRegisterInstance:
+    """register_instance behaviour of TargetRegistry."""
+
+    def setup_method(self):
+        """Give each test its own freshly created registry."""
+        TargetRegistry.reset_instance()
+        self.registry = TargetRegistry.get_registry_singleton()
+
+    def teardown_method(self):
+        """Discard the registry used by the test."""
+        TargetRegistry.reset_instance()
+
+    def test_register_instance_with_custom_name(self):
+        """An explicitly named target is stored and retrievable under that name."""
+        mock_target = MockPromptTarget()
+        self.registry.register_instance(mock_target, name="custom_target")
+
+        assert "custom_target" in self.registry
+        assert self.registry.get("custom_target") is mock_target
+
+    def test_register_instance_generates_name_from_class(self):
+        """Without an explicit name, one is derived from the target's class name."""
+        self.registry.register_instance(MockPromptTarget())
+
+        # Expect exactly one entry whose generated name carries the class-derived prefix
+        generated = self.registry.get_names()
+        assert len(generated) == 1
+        assert generated[0].startswith("mock_prompt_")
+
+    def test_register_instance_multiple_targets_unique_names(self):
+        """Targets of different classes registered without names both end up in the registry."""
+        for unnamed_target in (MockPromptTarget(), MockPromptChatTarget()):
+            self.registry.register_instance(unnamed_target)
+
+        assert len(self.registry) == 2
+
+    def test_register_instance_same_target_type_different_config(self):
+        """Two differently configured instances of one class can coexist."""
+        targets_by_name = {
+            "target_1": MockPromptTarget(model_name="model_a"),
+            "target_2": MockPromptTarget(model_name="model_b"),
+        }
+
+        # Each instance is registered under its own explicit name
+        for explicit_name, configured_target in targets_by_name.items():
+            self.registry.register_instance(configured_target, name=explicit_name)
+
+        assert len(self.registry) == 2
+
+
+@pytest.mark.usefixtures("patch_central_database")
+class TestTargetRegistryGetInstanceByName:
+    """get_instance_by_name behaviour of TargetRegistry."""
+
+    def setup_method(self):
+        """Create a fresh registry holding one target named "test_target"."""
+        TargetRegistry.reset_instance()
+        self.registry = TargetRegistry.get_registry_singleton()
+        self.target = MockPromptTarget()
+        self.registry.register_instance(self.target, name="test_target")
+
+    def teardown_method(self):
+        """Discard the registry used by the test."""
+        TargetRegistry.reset_instance()
+
+    def test_get_instance_by_name_returns_target(self):
+        """Looking up a registered name returns the registered object itself."""
+        assert self.registry.get_instance_by_name("test_target") is self.target
+
+    def test_get_instance_by_name_nonexistent_returns_none(self):
+        """Looking up an unknown name yields None."""
+        assert self.registry.get_instance_by_name("nonexistent") is None
+
+
+@pytest.mark.usefixtures("patch_central_database")
+class TestTargetRegistryBuildMetadata:
+    """Contents of the metadata the registry reports for a registered target."""
+
+    def setup_method(self):
+        """Give each test its own freshly created registry."""
+        TargetRegistry.reset_instance()
+        self.registry = TargetRegistry.get_registry_singleton()
+
+    def teardown_method(self):
+        """Discard the registry used by the test."""
+        TargetRegistry.reset_instance()
+
+    def test_build_metadata_includes_class_name(self):
+        """The metadata entry is a TargetIdentifier carrying the target's class name."""
+        self.registry.register_instance(MockPromptTarget(), name="mock_target")
+
+        entries = self.registry.list_metadata()
+        assert len(entries) == 1
+        identifier = entries[0]
+        assert isinstance(identifier, TargetIdentifier)
+        assert identifier.class_name == "MockPromptTarget"
+
+    def test_build_metadata_includes_model_name(self):
+        """The metadata entry carries the model_name the target was built with."""
+        self.registry.register_instance(MockPromptTarget(model_name="test_model"), name="mock_target")
+
+        entries = self.registry.list_metadata()
+        assert entries[0].model_name == "test_model"
+
+    def test_build_metadata_description_from_docstring(self):
+        """class_description contains the text of the target class's docstring."""
+        self.registry.register_instance(MockPromptTarget(), name="mock_target")
+
+        entries = self.registry.list_metadata()
+        # The expected text is the docstring of MockPromptTarget defined in this module
+        assert "Mock PromptTarget for testing" in entries[0].class_description
+
+
+@pytest.mark.usefixtures("patch_central_database")
+class TestTargetRegistryListMetadata:
+    """list_metadata behaviour of TargetRegistry."""
+
+    def setup_method(self):
+        """Create a fresh registry holding two MockPromptTargets and one MockPromptChatTarget."""
+        TargetRegistry.reset_instance()
+        self.registry = TargetRegistry.get_registry_singleton()
+
+        self.target1 = MockPromptTarget(model_name="model_a")
+        self.target2 = MockPromptTarget(model_name="model_b")
+        self.chat_target = MockPromptChatTarget()
+
+        named_targets = (
+            ("target_1", self.target1),
+            ("target_2", self.target2),
+            ("chat_target", self.chat_target),
+        )
+        for registry_name, registered_target in named_targets:
+            self.registry.register_instance(registered_target, name=registry_name)
+
+    def teardown_method(self):
+        """Discard the registry used by the test."""
+        TargetRegistry.reset_instance()
+
+    def test_list_metadata_returns_all_registered(self):
+        """An unfiltered listing has one entry per registered target."""
+        assert len(self.registry.list_metadata()) == 3
+
+    def test_list_metadata_filter_by_class_name(self):
+        """A class_name include-filter keeps only the entries of that class."""
+        matches = self.registry.list_metadata(include_filters={"class_name": "MockPromptTarget"})
+
+        assert len(matches) == 2
+        assert all(entry.class_name == "MockPromptTarget" for entry in matches)
+
+
+@pytest.mark.usefixtures("patch_central_database")
+class TestTargetRegistrySupportsConversationHistory:
+    """The supports_conversation_history field reported in target metadata."""
+
+    def setup_method(self):
+        """Give each test its own freshly created registry."""
+        TargetRegistry.reset_instance()
+        self.registry = TargetRegistry.get_registry_singleton()
+
+    def teardown_method(self):
+        """Discard the registry used by the test."""
+        TargetRegistry.reset_instance()
+
+    def test_registered_chat_target_has_supports_conversation_history_true(self):
+        """A registered chat target is reported with supports_conversation_history=True."""
+        self.registry.register_instance(MockPromptChatTarget(), name="chat_target")
+
+        entries = self.registry.list_metadata()
+        assert len(entries) == 1
+        assert entries[0].supports_conversation_history is True
+
+    def test_registered_non_chat_target_has_supports_conversation_history_false(self):
+        """A registered non-chat target is reported with supports_conversation_history=False."""
+        self.registry.register_instance(MockPromptTarget(), name="prompt_target")
+
+        entries = self.registry.list_metadata()
+        assert len(entries) == 1
+        assert entries[0].supports_conversation_history is False
diff --git a/tests/unit/setup/test_airt_targets_initializer.py b/tests/unit/setup/test_airt_targets_initializer.py
new file mode 100644
index 000000000..48a537313
--- /dev/null
+++ b/tests/unit/setup/test_airt_targets_initializer.py
@@ -0,0 +1,202 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+import os
+from unittest.mock import patch
+
+import pytest
+
+from pyrit.registry import TargetRegistry
+from pyrit.setup.initializers import AIRTTargetInitializer
+from pyrit.setup.initializers.airt_targets import TARGET_CONFIGS
+
+
+class TestAIRTTargetInitializerBasic:
+    """Tests for AIRTTargetInitializer class - basic functionality."""
+
+    def test_can_be_created(self):
+        """Test that AIRTTargetInitializer can be instantiated."""
+        init = AIRTTargetInitializer()
+        assert init is not None
+        assert init.name == "AIRT Target Initializer"
+        assert init.execution_order == 1
+
+    def test_required_env_vars_is_empty(self):
+        """Test that no env vars are required (initializer is optional)."""
+        init = AIRTTargetInitializer()
+        assert init.required_env_vars == []
+
+
+@pytest.mark.usefixtures("patch_central_database")
+class TestAIRTTargetInitializerInitialize:
+    """Tests for AIRTTargetInitializer.initialize_async method.
+
+    Every test runs against a snapshot of ``os.environ`` that is rolled back in
+    ``teardown_method``, so tests may assign to ``os.environ`` directly without
+    leaking values into, or deleting real values from, the rest of the session.
+    """
+
+    def setup_method(self) -> None:
+        """Reset the registry and isolate the environment before each test."""
+        TargetRegistry.reset_instance()
+        # patch.dict snapshots os.environ on start() and restores that exact snapshot on stop(),
+        # which also undoes the deletions performed by _clear_env_vars below.
+        self._env_patcher = patch.dict(os.environ)
+        self._env_patcher.start()
+        # Clear all target-related env vars so each test starts from an unconfigured state
+        self._clear_env_vars()
+
+    def teardown_method(self) -> None:
+        """Reset the registry and restore the environment after each test."""
+        TargetRegistry.reset_instance()
+        self._env_patcher.stop()
+
+    def _clear_env_vars(self) -> None:
+        """Clear all environment variables used by TARGET_CONFIGS."""
+        for config in TARGET_CONFIGS:
+            for var in (config.endpoint_var, config.key_var, config.model_var, config.underlying_model_var):
+                # Skip fields whose variable name is unset (falsy) on this config
+                if var:
+                    os.environ.pop(var, None)
+
+    @pytest.mark.asyncio
+    async def test_initialize_runs_without_error_no_env_vars(self):
+        """Test that initialize runs without errors when no env vars are set."""
+        init = AIRTTargetInitializer()
+        await init.initialize_async()
+
+        # No targets should be registered
+        registry = TargetRegistry.get_registry_singleton()
+        assert len(registry) == 0
+
+    @pytest.mark.asyncio
+    async def test_registers_target_when_env_vars_set(self):
+        """Test that a target is registered when its env vars are set."""
+        os.environ["OPENAI_CHAT_ENDPOINT"] = "https://api.openai.com/v1"
+        os.environ["OPENAI_CHAT_KEY"] = "test_key"
+        os.environ["OPENAI_CHAT_MODEL"] = "gpt-4o"
+
+        init = AIRTTargetInitializer()
+        await init.initialize_async()
+
+        registry = TargetRegistry.get_registry_singleton()
+        assert "openai_chat" in registry
+        target = registry.get_instance_by_name("openai_chat")
+        assert target is not None
+        assert target._model_name == "gpt-4o"
+
+    @pytest.mark.asyncio
+    async def test_does_not_register_target_without_endpoint(self):
+        """Test that target is not registered if endpoint is missing."""
+        # Only set key, not endpoint
+        os.environ["OPENAI_CHAT_KEY"] = "test_key"
+        os.environ["OPENAI_CHAT_MODEL"] = "gpt-4o"
+
+        init = AIRTTargetInitializer()
+        await init.initialize_async()
+
+        registry = TargetRegistry.get_registry_singleton()
+        assert "openai_chat" not in registry
+
+    @pytest.mark.asyncio
+    async def test_does_not_register_target_without_api_key(self):
+        """Test that target is not registered if api_key is missing."""
+        # Only set endpoint, not key
+        os.environ["OPENAI_CHAT_ENDPOINT"] = "https://api.openai.com/v1"
+        os.environ["OPENAI_CHAT_MODEL"] = "gpt-4o"
+
+        init = AIRTTargetInitializer()
+        await init.initialize_async()
+
+        registry = TargetRegistry.get_registry_singleton()
+        assert "openai_chat" not in registry
+
+    @pytest.mark.asyncio
+    async def test_registers_multiple_targets(self):
+        """Test that multiple targets are registered when their env vars are set."""
+        # Set up openai_chat
+        os.environ["OPENAI_CHAT_ENDPOINT"] = "https://api.openai.com/v1"
+        os.environ["OPENAI_CHAT_KEY"] = "test_key"
+        os.environ["OPENAI_CHAT_MODEL"] = "gpt-4o"
+
+        # Set up openai_image
+        os.environ["OPENAI_IMAGE_ENDPOINT"] = "https://api.openai.com/v1"
+        os.environ["OPENAI_IMAGE_API_KEY"] = "test_image_key"
+        os.environ["OPENAI_IMAGE_MODEL"] = "dall-e-3"
+
+        init = AIRTTargetInitializer()
+        await init.initialize_async()
+
+        registry = TargetRegistry.get_registry_singleton()
+        assert len(registry) == 2
+        assert "openai_chat" in registry
+        assert "openai_image" in registry
+
+    @pytest.mark.asyncio
+    async def test_registers_azure_content_safety_without_model(self):
+        """Test that PromptShieldTarget is registered without model_name (it doesn't use one)."""
+        os.environ["AZURE_CONTENT_SAFETY_API_ENDPOINT"] = "https://test.cognitiveservices.azure.com"
+        os.environ["AZURE_CONTENT_SAFETY_API_KEY"] = "test_safety_key"
+
+        init = AIRTTargetInitializer()
+        await init.initialize_async()
+
+        registry = TargetRegistry.get_registry_singleton()
+        assert "azure_content_safety" in registry
+
+    @pytest.mark.asyncio
+    async def test_underlying_model_passed_when_set(self):
+        """Test that underlying_model is passed to target when env var is set."""
+        os.environ["OPENAI_CHAT_ENDPOINT"] = "https://my-deployment.openai.azure.com"
+        os.environ["OPENAI_CHAT_KEY"] = "test_key"
+        os.environ["OPENAI_CHAT_MODEL"] = "my-deployment-name"
+        os.environ["OPENAI_CHAT_UNDERLYING_MODEL"] = "gpt-4o"
+
+        init = AIRTTargetInitializer()
+        await init.initialize_async()
+
+        registry = TargetRegistry.get_registry_singleton()
+        target = registry.get_instance_by_name("openai_chat")
+        assert target is not None
+        assert target._model_name == "my-deployment-name"
+        assert target._underlying_model == "gpt-4o"
+
+
+@pytest.mark.usefixtures("patch_central_database")
+class TestAIRTTargetInitializerTargetConfigs:
+    """Tests verifying TARGET_CONFIGS covers expected targets."""
+
+    def test_target_configs_not_empty(self):
+        """Test that TARGET_CONFIGS has configurations defined."""
+        assert len(TARGET_CONFIGS) > 0
+
+    def test_all_configs_have_required_fields(self):
+        """Test that all TARGET_CONFIGS have required fields."""
+        for config in TARGET_CONFIGS:
+            # Plain string: there is no name to interpolate when registry_name itself is missing
+            assert config.registry_name, "Config missing registry_name"
+            assert config.target_class, f"Config {config.registry_name} missing target_class"
+            assert config.endpoint_var, f"Config {config.registry_name} missing endpoint_var"
+            assert config.key_var, f"Config {config.registry_name} missing key_var"
+
+    def test_expected_targets_in_configs(self):
+        """Test that expected target names are in TARGET_CONFIGS."""
+        registry_names = [config.registry_name for config in TARGET_CONFIGS]
+
+        # Verify key targets are configured
+        assert "openai_chat" in registry_names
+        assert "openai_image" in registry_names
+        assert "openai_tts" in registry_names
+        assert "azure_content_safety" in registry_names
+
+
+class TestAIRTTargetInitializerGetInfo:
+    """Tests for AIRTTargetInitializer.get_info_async method."""
+
+    @pytest.mark.asyncio
+    async def test_get_info_returns_expected_structure(self):
+        """Test that get_info_async returns expected structure."""
+        info = await AIRTTargetInitializer.get_info_async()
+
+        assert isinstance(info, dict)
+        assert info["name"] == "AIRT Target Initializer"
+        assert info["class"] == "AIRTTargetInitializer"
+        assert "description" in info
+        assert isinstance(info["description"], str)
+
+    @pytest.mark.asyncio
+    async def test_get_info_required_env_vars_empty_or_not_present(self):
+        """Test that get_info has empty or no required_env_vars (since none are required)."""
+        info = await AIRTTargetInitializer.get_info_async()
+
+        # required_env_vars may be omitted or empty since this initializer has no requirements;
+        # defaulting to [] makes a missing key compare equal just like an empty list does
+        assert info.get("required_env_vars", []) == []