Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions doc/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -332,7 +332,9 @@ API Reference
group_conversation_message_pieces_by_sequence
group_message_pieces_into_conversations
HarmDefinition
Identifiable
Identifier
IdentifierType
ImagePathDataTypeSerializer
AllowedCategories
AttackOutcome
Expand Down
8 changes: 4 additions & 4 deletions pyrit/cli/frontend_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,7 +387,7 @@ def format_scenario_metadata(*, scenario_metadata: ScenarioMetadata) -> None:
_print_header(text=scenario_metadata.name)
print(f" Class: {scenario_metadata.class_name}")

description = scenario_metadata.description
description = scenario_metadata.class_description
if description:
print(" Description:")
print(_format_wrapped_text(text=description, indent=" "))
Expand Down Expand Up @@ -428,7 +428,7 @@ def format_initializer_metadata(*, initializer_metadata: "InitializerMetadata")
"""
_print_header(text=initializer_metadata.name)
print(f" Class: {initializer_metadata.class_name}")
print(f" Name: {initializer_metadata.initializer_name}")
print(f" Name: {initializer_metadata.display_name}")
print(f" Execution Order: {initializer_metadata.execution_order}")

if initializer_metadata.required_env_vars:
Expand All @@ -438,9 +438,9 @@ def format_initializer_metadata(*, initializer_metadata: "InitializerMetadata")
else:
print(" Required Environment Variables: None")

if initializer_metadata.description:
if initializer_metadata.class_description:
print(" Description:")
print(_format_wrapped_text(text=initializer_metadata.description, indent=" "))
print(_format_wrapped_text(text=initializer_metadata.class_description, indent=" "))


def validate_database(*, database: str) -> str:
Expand Down
4 changes: 3 additions & 1 deletion pyrit/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
)
from pyrit.models.embeddings import EmbeddingData, EmbeddingResponse, EmbeddingSupport, EmbeddingUsageInformation
from pyrit.models.harm_definition import HarmDefinition, ScaleDescription, get_all_harm_definitions
from pyrit.models.identifiers import Identifier
from pyrit.models.identifiers import Identifiable, Identifier, IdentifierType
from pyrit.models.literals import ChatMessageRole, PromptDataType, PromptResponseError, SeedType
from pyrit.models.message import (
Message,
Expand Down Expand Up @@ -82,7 +82,9 @@
"group_conversation_message_pieces_by_sequence",
"group_message_pieces_into_conversations",
"HarmDefinition",
"Identifiable",
"Identifier",
"IdentifierType",
"ImagePathDataTypeSerializer",
"Message",
"MessagePiece",
Expand Down
78 changes: 77 additions & 1 deletion pyrit/models/identifiers.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,89 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import hashlib
import json
from abc import abstractmethod
from dataclasses import asdict, dataclass, field, fields, is_dataclass
from typing import Any, Literal

IdentifierType = Literal["class", "instance"]


class Identifiable:
    """
    Base class for objects that can provide an identifier dictionary.

    This is a legacy interface that will eventually be replaced by the Identifier dataclass.
    Classes implementing this interface should return a dict describing their identity.

    Note:
        This class intentionally does not derive from ``abc.ABC``. Implementers such as
        ``PromptConverter`` and ``PromptTarget`` list ``abc.ABC`` *before* ``Identifiable``
        in their bases; deriving from ``ABC`` here would make that base order impossible
        to linearize. ``@abstractmethod`` is therefore only enforced on subclasses that
        bring in ``ABCMeta`` themselves.
    """

    @abstractmethod
    def get_identifier(self) -> dict[str, str]:
        """
        Return a dictionary describing this object's identity.

        Returns:
            dict[str, str]: The identifier of this object.
        """
        pass

    def __str__(self) -> str:
        """
        Return the string form of the identifier dictionary.

        Returns:
            str: ``str()`` of the dict produced by ``get_identifier()``.
        """
        # The method must be called: formatting the bare attribute would render
        # the bound-method repr instead of the identifier contents.
        return f"{self.get_identifier()}"


@dataclass(frozen=True)
class Identifier:
    """
    Frozen base dataclass used to identify PyRIT components.

    It gives registry items, targets, scorers, attacks, converters, and other
    components a stable identifier. The ``hash`` field is derived from the
    storable fields when the instance is created and is not recomputed afterwards.

    This class serves as:
    1. The base for registry metadata (replacing RegistryItemMetadata)
    2. The future replacement for get_identifier() dict patterns

    Component-specific identifier types should subclass this and add their own fields.
    """

    # The snake_case identifier name (e.g., "self_ask_refusal")
    name: str
    # The actual class name, equivalent to __type__ (e.g., "SelfAskRefusalScorer")
    class_name: str
    # The module path, equivalent to __module__ (e.g., "pyrit.score.self_ask_refusal_scorer")
    class_module: str

    # Human-readable description; excluded from storage and from the hash
    class_description: str = field(metadata={"exclude_from_storage": True})

    # Whether this identifies a "class" or an "instance"; excluded from storage and from the hash
    identifier_type: IdentifierType = field(metadata={"exclude_from_storage": True})
    # Filled in by __post_init__; not an __init__ argument and ignored by comparisons
    hash: str = field(init=False, compare=False)

    def __post_init__(self) -> None:
        """Populate ``hash`` once all other fields have been set."""
        digest = self._compute_hash()
        # A frozen dataclass rejects normal attribute assignment, so go through object.__setattr__
        object.__setattr__(self, "hash", digest)

    def _compute_hash(self) -> str:
        """
        Compute a stable SHA256 hash over the storable identifier fields.

        The ``hash`` field itself and every field whose metadata sets
        ``exclude_from_storage`` to a truthy value are left out.

        Returns:
            str: Hex digest of the SHA256 hash.
        """
        payload: dict[str, Any] = {}
        for spec in fields(self):
            if spec.name == "hash" or spec.metadata.get("exclude_from_storage", False):
                continue
            payload[spec.name] = getattr(self, spec.name)

        # Sorted keys and compact separators keep the serialized form deterministic
        serialized = json.dumps(payload, sort_keys=True, separators=(",", ":"), default=_dataclass_encoder)
        return hashlib.sha256(serialized.encode("utf-8")).hexdigest()

    def to_storage_dict(self) -> dict[str, Any]:
        """
        Return only the fields suitable for DB storage.

        Returns:
            dict[str, Any]: Field name to value for every field not marked
            ``exclude_from_storage`` (this includes ``hash``).
        """
        storable: dict[str, Any] = {}
        for spec in fields(self):
            if not spec.metadata.get("exclude_from_storage", False):
                storable[spec.name] = getattr(self, spec.name)
        return storable


def _dataclass_encoder(obj: Any) -> Any:
    """
    Fallback JSON encoder that turns dataclass instances into dicts.

    Args:
        obj: The object that ``json.dumps`` could not serialize on its own.

    Returns:
        Any: ``asdict(obj)`` when ``obj`` is a dataclass instance.

    Raises:
        TypeError: If ``obj`` is a dataclass type (not an instance) or not a dataclass at all.
    """
    # is_dataclass() is also true for dataclass *types*; only instances can be converted
    if isinstance(obj, type) or not is_dataclass(obj):
        raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")
    return asdict(obj)
4 changes: 2 additions & 2 deletions pyrit/prompt_converter/prompt_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from typing import get_args

from pyrit import prompt_converter
from pyrit.models import Identifier, PromptDataType
from pyrit.models import Identifiable, PromptDataType


@dataclass
Expand All @@ -31,7 +31,7 @@ def __str__(self) -> str:
return f"{self.output_type}: {self.output_text}"


class PromptConverter(abc.ABC, Identifier):
class PromptConverter(abc.ABC, Identifiable):
"""
Base class for converters that transform prompts into a different representation or format.
Expand Down
4 changes: 2 additions & 2 deletions pyrit/prompt_target/common/prompt_target.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,12 @@
from typing import Any, Dict, List, Optional

from pyrit.memory import CentralMemory, MemoryInterface
from pyrit.models import Identifier, Message
from pyrit.models import Identifiable, Message

logger = logging.getLogger(__name__)


class PromptTarget(abc.ABC, Identifier):
class PromptTarget(abc.ABC, Identifiable):
"""
Abstract base class for prompt targets.

Expand Down
5 changes: 3 additions & 2 deletions pyrit/registry/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@

"""Registry module for PyRIT class and instance registries."""

from pyrit.registry.base import RegistryItemMetadata, RegistryProtocol
from pyrit.models.identifiers import Identifier
from pyrit.registry.base import RegistryProtocol
from pyrit.registry.class_registries import (
BaseClassRegistry,
ClassEntry,
Expand Down Expand Up @@ -32,9 +33,9 @@
"discover_in_directory",
"discover_in_package",
"discover_subclasses_in_loaded_modules",
"Identifier",
"InitializerMetadata",
"InitializerRegistry",
"RegistryItemMetadata",
"RegistryProtocol",
"registry_name_to_class_name",
"ScenarioMetadata",
Expand Down
18 changes: 0 additions & 18 deletions pyrit/registry/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
and instance registries (which store T instances).
"""

from dataclasses import dataclass
from typing import Any, Dict, Iterator, List, Optional, Protocol, TypeVar, runtime_checkable

# Type variable for metadata (invariant for Protocol compatibility)
Expand Down Expand Up @@ -77,23 +76,6 @@ def __iter__(self) -> Iterator[str]:
...


@dataclass(frozen=True)
class RegistryItemMetadata:
    """
    Common descriptive metadata for an item held in a registry.

    An instance describes a registered item (a class or an instance); it is
    a structured description of that item, NOT the item itself.

    Registry-specific metadata types should subclass this and add their own fields.
    """

    # The snake_case registry name (e.g., "self_ask_refusal")
    name: str
    # The actual class name (e.g., "SelfAskRefusalScorer")
    class_name: str
    # Description from docstring or manual override
    description: str


def _matches_filters(
metadata: Any,
*,
Expand Down
15 changes: 9 additions & 6 deletions pyrit/registry/class_registries/base_class_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@
from abc import ABC, abstractmethod
from typing import Callable, Dict, Generic, Iterator, List, Optional, Type, TypeVar

from pyrit.registry.base import RegistryItemMetadata, RegistryProtocol
from pyrit.models.identifiers import Identifier
from pyrit.registry.base import RegistryProtocol
from pyrit.registry.name_utils import class_name_to_registry_name

# Type variable for the registered class type
Expand Down Expand Up @@ -182,19 +183,19 @@ def _build_metadata(self, name: str, entry: ClassEntry[T]) -> MetadataT:
"""
pass

def _build_base_metadata(self, name: str, entry: ClassEntry[T]) -> RegistryItemMetadata:
def _build_base_metadata(self, name: str, entry: ClassEntry[T]) -> Identifier:
"""
Build the common base metadata for a registered class.

This helper extracts fields common to all registries: name, class_name, description.
This helper extracts fields common to all registries: name, class_name, class_description.
Subclasses can use this for building common fields if needed.

Args:
name: The registry name (snake_case identifier).
entry: The ClassEntry containing the registered class.

Returns:
A RegistryItemMetadata dataclass with common fields.
An Identifier dataclass with common fields.
"""
registered_class = entry.registered_class

Expand All @@ -205,10 +206,12 @@ def _build_base_metadata(self, name: str, entry: ClassEntry[T]) -> RegistryItemM
else:
description = entry.description or "No description available"

return RegistryItemMetadata(
return Identifier(
identifier_type="class",
name=name,
class_name=registered_class.__name__,
description=description,
class_module=registered_class.__module__,
class_description=description,
)

def get_class(self, name: str) -> Type[T]:
Expand Down
18 changes: 11 additions & 7 deletions pyrit/registry/class_registries/initializer_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from pathlib import Path
from typing import TYPE_CHECKING, Dict, Optional

from pyrit.registry.base import RegistryItemMetadata
from pyrit.models.identifiers import Identifier
from pyrit.registry.class_registries.base_class_registry import (
BaseClassRegistry,
ClassEntry,
Expand All @@ -34,14 +34,14 @@


@dataclass(frozen=True)
class InitializerMetadata(RegistryItemMetadata):
class InitializerMetadata(Identifier):
"""
Metadata describing a registered PyRITInitializer class.

Use get_class() to get the actual class.
"""

initializer_name: str
display_name: str
required_env_vars: tuple[str, ...]
execution_order: int

Expand Down Expand Up @@ -208,20 +208,24 @@ def _build_metadata(self, name: str, entry: ClassEntry["PyRITInitializer"]) -> I
try:
instance = initializer_class()
return InitializerMetadata(
identifier_type="class",
name=name,
class_name=initializer_class.__name__,
description=instance.description,
initializer_name=instance.name,
class_module=initializer_class.__module__,
class_description=instance.description,
display_name=instance.name,
required_env_vars=tuple(instance.required_env_vars),
execution_order=instance.execution_order,
)
except Exception as e:
logger.warning(f"Failed to get metadata for {name}: {e}")
return InitializerMetadata(
identifier_type="class",
name=name,
class_name=initializer_class.__name__,
description="Error loading initializer metadata",
initializer_name=name,
class_module=initializer_class.__module__,
class_description="Error loading initializer metadata",
display_name=name,
required_env_vars=(),
execution_order=100,
)
Expand Down
8 changes: 5 additions & 3 deletions pyrit/registry/class_registries/scenario_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from pathlib import Path
from typing import TYPE_CHECKING, Optional

from pyrit.registry.base import RegistryItemMetadata
from pyrit.models.identifiers import Identifier
from pyrit.registry.class_registries.base_class_registry import (
BaseClassRegistry,
ClassEntry,
Expand All @@ -33,7 +33,7 @@


@dataclass(frozen=True)
class ScenarioMetadata(RegistryItemMetadata):
class ScenarioMetadata(Identifier):
"""
Metadata describing a registered Scenario class.

Expand Down Expand Up @@ -170,9 +170,11 @@ def _build_metadata(self, name: str, entry: ClassEntry["Scenario"]) -> ScenarioM
max_dataset_size = dataset_config.max_dataset_size

return ScenarioMetadata(
identifier_type="class",
name=name,
class_name=scenario_class.__name__,
description=description,
class_module=scenario_class.__module__,
class_description=description,
default_strategy=scenario_class.get_default_strategy().value,
all_strategies=tuple(s.value for s in strategy_class.get_all_strategies()),
aggregate_strategies=tuple(s.value for s in strategy_class.get_aggregate_strategies()),
Expand Down
5 changes: 3 additions & 2 deletions pyrit/registry/instance_registries/base_instance_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,11 @@
from abc import ABC, abstractmethod
from typing import Any, Dict, Generic, Iterator, List, Optional, TypeVar

from pyrit.registry.base import RegistryItemMetadata, RegistryProtocol
from pyrit.models.identifiers import Identifier
from pyrit.registry.base import RegistryProtocol

T = TypeVar("T") # The type of instances stored
MetadataT = TypeVar("MetadataT", bound=RegistryItemMetadata)
MetadataT = TypeVar("MetadataT", bound=Identifier)


class BaseInstanceRegistry(ABC, RegistryProtocol[MetadataT], Generic[T, MetadataT]):
Expand Down
Loading