diff --git a/doc/api.rst b/doc/api.rst index c475c1923..23919b52e 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -332,7 +332,9 @@ API Reference group_conversation_message_pieces_by_sequence group_message_pieces_into_conversations HarmDefinition + Identifiable Identifier + IdentifierType ImagePathDataTypeSerializer AllowedCategories AttackOutcome diff --git a/pyrit/cli/frontend_core.py b/pyrit/cli/frontend_core.py index ca1d6f2ad..3739b54cd 100644 --- a/pyrit/cli/frontend_core.py +++ b/pyrit/cli/frontend_core.py @@ -387,7 +387,7 @@ def format_scenario_metadata(*, scenario_metadata: ScenarioMetadata) -> None: _print_header(text=scenario_metadata.name) print(f" Class: {scenario_metadata.class_name}") - description = scenario_metadata.description + description = scenario_metadata.class_description if description: print(" Description:") print(_format_wrapped_text(text=description, indent=" ")) @@ -428,7 +428,7 @@ def format_initializer_metadata(*, initializer_metadata: "InitializerMetadata") """ _print_header(text=initializer_metadata.name) print(f" Class: {initializer_metadata.class_name}") - print(f" Name: {initializer_metadata.initializer_name}") + print(f" Name: {initializer_metadata.display_name}") print(f" Execution Order: {initializer_metadata.execution_order}") if initializer_metadata.required_env_vars: @@ -438,9 +438,9 @@ def format_initializer_metadata(*, initializer_metadata: "InitializerMetadata") else: print(" Required Environment Variables: None") - if initializer_metadata.description: + if initializer_metadata.class_description: print(" Description:") - print(_format_wrapped_text(text=initializer_metadata.description, indent=" ")) + print(_format_wrapped_text(text=initializer_metadata.class_description, indent=" ")) def validate_database(*, database: str) -> str: diff --git a/pyrit/models/__init__.py b/pyrit/models/__init__.py index 5d6e30665..2a43439bf 100644 --- a/pyrit/models/__init__.py +++ b/pyrit/models/__init__.py @@ -22,7 +22,7 @@ ) from pyrit.models.embeddings import EmbeddingData, EmbeddingResponse, EmbeddingSupport, EmbeddingUsageInformation from pyrit.models.harm_definition import HarmDefinition, ScaleDescription, get_all_harm_definitions -from pyrit.models.identifiers import Identifier +from pyrit.models.identifiers import Identifiable, Identifier, IdentifierType from pyrit.models.literals import ChatMessageRole, PromptDataType, PromptResponseError, SeedType from pyrit.models.message import ( Message, @@ -82,7 +82,9 @@ "group_conversation_message_pieces_by_sequence", "group_message_pieces_into_conversations", "HarmDefinition", + "Identifiable", "Identifier", + "IdentifierType", "ImagePathDataTypeSerializer", "Message", "MessagePiece", diff --git a/pyrit/models/identifiers.py b/pyrit/models/identifiers.py index c70eabdf2..f4a758b76 100644 --- a/pyrit/models/identifiers.py +++ b/pyrit/models/identifiers.py @@ -1,13 +1,89 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. +import hashlib +import json from abc import abstractmethod +from dataclasses import asdict, dataclass, field, fields, is_dataclass +from typing import Any, Literal +IdentifierType = Literal["class", "instance"] + + +class Identifiable: + """ + Abstract base class for objects that can provide an identifier dictionary. + + This is a legacy interface that will eventually be replaced by Identifier dataclass. + Classes implementing this interface should return a dict describing their identity. + """ -class Identifier: @abstractmethod def get_identifier(self) -> dict[str, str]: pass def __str__(self) -> str: return f"{self.get_identifier}" + + +@dataclass(frozen=True) +class Identifier: + """ + Base dataclass for identifying PyRIT components. + + This frozen dataclass provides a stable identifier for registry items, + targets, scorers, attacks, converters, and other components. The hash is computed at creation + time from the core fields and remains constant. + + This class serves as: + 1. Base for registry metadata (replacing RegistryItemMetadata) + 2. Future replacement for get_identifier() dict patterns + + All component-specific identifier types should extend this with additional fields. + """ + + name: str # The snake_case identifier name (e.g., "self_ask_refusal") + class_name: str # The actual class name, equivalent to __type__ (e.g., "SelfAskRefusalScorer") + class_module: str # The module path, equivalent to __module__ (e.g., "pyrit.score.self_ask_refusal_scorer") + + class_description: str = field(metadata={"exclude_from_storage": True}) + + # Whether this identifies a "class" or "instance" + identifier_type: IdentifierType = field(metadata={"exclude_from_storage": True}) + hash: str = field(init=False, compare=False) + + def __post_init__(self) -> None: + """Compute the identifier hash from core fields.""" + # Use object.__setattr__ since this is a frozen dataclass + object.__setattr__(self, "hash", self._compute_hash()) + + def _compute_hash(self) -> str: + """ + Compute a stable SHA256 hash from storable identifier fields. + + Fields marked with metadata={"exclude_from_storage": True} and 'hash' itself + are excluded from the hash computation. + + Returns: + A hex string of the SHA256 hash. + """ + hashable_dict: dict[str, Any] = { + f.name: getattr(self, f.name) + for f in fields(self) + if f.name != "hash" and not f.metadata.get("exclude_from_storage", False) + } + config_json = json.dumps(hashable_dict, sort_keys=True, separators=(",", ":"), default=_dataclass_encoder) + return hashlib.sha256(config_json.encode("utf-8")).hexdigest() + + def to_storage_dict(self) -> dict[str, Any]: + """Return only fields suitable for DB storage.""" + return { + f.name: getattr(self, f.name) for f in fields(self) if not f.metadata.get("exclude_from_storage", False) + } + + +def _dataclass_encoder(obj: Any) -> Any: + """JSON encoder that handles dataclasses by converting them to dicts.""" + if is_dataclass(obj) and not isinstance(obj, type): + return asdict(obj) + raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable") diff --git a/pyrit/prompt_converter/prompt_converter.py b/pyrit/prompt_converter/prompt_converter.py index db9497b81..55a315a00 100644 --- a/pyrit/prompt_converter/prompt_converter.py +++ b/pyrit/prompt_converter/prompt_converter.py @@ -9,7 +9,7 @@ from typing import get_args from pyrit import prompt_converter -from pyrit.models import Identifier, PromptDataType +from pyrit.models import Identifiable, PromptDataType @dataclass @@ -31,7 +31,7 @@ def __str__(self) -> str: return f"{self.output_type}: {self.output_text}" -class PromptConverter(abc.ABC, Identifier): +class PromptConverter(abc.ABC, Identifiable): """ Base class for converters that transform prompts into a different representation or format. diff --git a/pyrit/prompt_target/common/prompt_target.py b/pyrit/prompt_target/common/prompt_target.py index 78b466e04..f39ba537c 100644 --- a/pyrit/prompt_target/common/prompt_target.py +++ b/pyrit/prompt_target/common/prompt_target.py @@ -6,12 +6,12 @@ from typing import Any, Dict, List, Optional from pyrit.memory import CentralMemory, MemoryInterface -from pyrit.models import Identifier, Message +from pyrit.models import Identifiable, Message logger = logging.getLogger(__name__) -class PromptTarget(abc.ABC, Identifier): +class PromptTarget(abc.ABC, Identifiable): """ Abstract base class for prompt targets. diff --git a/pyrit/registry/__init__.py b/pyrit/registry/__init__.py index ba0566020..7c8c8c1fa 100644 --- a/pyrit/registry/__init__.py +++ b/pyrit/registry/__init__.py @@ -3,7 +3,8 @@ """Registry module for PyRIT class and instance registries.""" -from pyrit.registry.base import RegistryItemMetadata, RegistryProtocol +from pyrit.models.identifiers import Identifier +from pyrit.registry.base import RegistryProtocol from pyrit.registry.class_registries import ( BaseClassRegistry, ClassEntry, @@ -32,9 +33,9 @@ "discover_in_directory", "discover_in_package", "discover_subclasses_in_loaded_modules", + "Identifier", "InitializerMetadata", "InitializerRegistry", - "RegistryItemMetadata", "RegistryProtocol", "registry_name_to_class_name", "ScenarioMetadata", diff --git a/pyrit/registry/base.py b/pyrit/registry/base.py index a9c284561..5f5e37400 100644 --- a/pyrit/registry/base.py +++ b/pyrit/registry/base.py @@ -8,7 +8,6 @@ and instance registries (which store T instances). """ -from dataclasses import dataclass from typing import Any, Dict, Iterator, List, Optional, Protocol, TypeVar, runtime_checkable # Type variable for metadata (invariant for Protocol compatibility) @@ -77,23 +76,6 @@ def __iter__(self) -> Iterator[str]: ... -@dataclass(frozen=True) -class RegistryItemMetadata: - """ - Base dataclass for registry item metadata. - - This dataclass provides descriptive information about a registered item - (either a class or an instance). It is NOT the item itself - it's a - structured object describing the item. - - All registry-specific metadata types should extend this with additional fields. - """ - - name: str # The snake_case registry name (e.g., "self_ask_refusal") - class_name: str # The actual class name (e.g., "SelfAskRefusalScorer") - description: str # Description from docstring or manual override - - def _matches_filters( metadata: Any, *, diff --git a/pyrit/registry/class_registries/base_class_registry.py b/pyrit/registry/class_registries/base_class_registry.py index bd6ae3f27..5f2f5bc10 100644 --- a/pyrit/registry/class_registries/base_class_registry.py +++ b/pyrit/registry/class_registries/base_class_registry.py @@ -19,7 +19,8 @@ from abc import ABC, abstractmethod from typing import Callable, Dict, Generic, Iterator, List, Optional, Type, TypeVar -from pyrit.registry.base import RegistryItemMetadata, RegistryProtocol +from pyrit.models.identifiers import Identifier +from pyrit.registry.base import RegistryProtocol from pyrit.registry.name_utils import class_name_to_registry_name # Type variable for the registered class type @@ -182,11 +183,11 @@ def _build_metadata(self, name: str, entry: ClassEntry[T]) -> MetadataT: """ pass - def _build_base_metadata(self, name: str, entry: ClassEntry[T]) -> RegistryItemMetadata: + def _build_base_metadata(self, name: str, entry: ClassEntry[T]) -> Identifier: """ Build the common base metadata for a registered class. - This helper extracts fields common to all registries: name, class_name, description. + This helper extracts fields common to all registries: name, class_name, class_description. Subclasses can use this for building common fields if needed. Args: @@ -194,7 +195,7 @@ def _build_base_metadata(self, name: str, entry: ClassEntry[T]) -> RegistryItemM entry: The ClassEntry containing the registered class. Returns: - A RegistryItemMetadata dataclass with common fields. + An Identifier dataclass with common fields. """ registered_class = entry.registered_class @@ -205,10 +206,12 @@ def _build_base_metadata(self, name: str, entry: ClassEntry[T]) -> RegistryItemM else: description = entry.description or "No description available" - return RegistryItemMetadata( + return Identifier( + identifier_type="class", name=name, class_name=registered_class.__name__, - description=description, + class_module=registered_class.__module__, + class_description=description, ) def get_class(self, name: str) -> Type[T]: diff --git a/pyrit/registry/class_registries/initializer_registry.py b/pyrit/registry/class_registries/initializer_registry.py index fed6d7b6c..96dedad2f 100644 --- a/pyrit/registry/class_registries/initializer_registry.py +++ b/pyrit/registry/class_registries/initializer_registry.py @@ -16,7 +16,7 @@ from pathlib import Path from typing import TYPE_CHECKING, Dict, Optional -from pyrit.registry.base import RegistryItemMetadata +from pyrit.models.identifiers import Identifier from pyrit.registry.class_registries.base_class_registry import ( BaseClassRegistry, ClassEntry, @@ -34,14 +34,14 @@ @dataclass(frozen=True) -class InitializerMetadata(RegistryItemMetadata): +class InitializerMetadata(Identifier): """ Metadata describing a registered PyRITInitializer class. Use get_class() to get the actual class. """ - initializer_name: str + display_name: str required_env_vars: tuple[str, ...] execution_order: int @@ -208,20 +208,24 @@ def _build_metadata(self, name: str, entry: ClassEntry["PyRITInitializer"]) -> I try: instance = initializer_class() return InitializerMetadata( + identifier_type="class", name=name, class_name=initializer_class.__name__, - description=instance.description, - initializer_name=instance.name, + class_module=initializer_class.__module__, + class_description=instance.description, + display_name=instance.name, required_env_vars=tuple(instance.required_env_vars), execution_order=instance.execution_order, ) except Exception as e: logger.warning(f"Failed to get metadata for {name}: {e}") return InitializerMetadata( + identifier_type="class", name=name, class_name=initializer_class.__name__, - description="Error loading initializer metadata", - initializer_name=name, + class_module=initializer_class.__module__, + class_description="Error loading initializer metadata", + display_name=name, required_env_vars=(), execution_order=100, ) diff --git a/pyrit/registry/class_registries/scenario_registry.py b/pyrit/registry/class_registries/scenario_registry.py index 40693f0cd..01215662f 100644 --- a/pyrit/registry/class_registries/scenario_registry.py +++ b/pyrit/registry/class_registries/scenario_registry.py @@ -15,7 +15,7 @@ from pathlib import Path from typing import TYPE_CHECKING, Optional -from pyrit.registry.base import RegistryItemMetadata +from pyrit.models.identifiers import Identifier from pyrit.registry.class_registries.base_class_registry import ( BaseClassRegistry, ClassEntry, @@ -33,7 +33,7 @@ @dataclass(frozen=True) -class ScenarioMetadata(RegistryItemMetadata): +class ScenarioMetadata(Identifier): """ Metadata describing a registered Scenario class. @@ -170,9 +170,11 @@ def _build_metadata(self, name: str, entry: ClassEntry["Scenario"]) -> ScenarioM max_dataset_size = dataset_config.max_dataset_size return ScenarioMetadata( + identifier_type="class", name=name, class_name=scenario_class.__name__, - description=description, + class_module=scenario_class.__module__, + class_description=description, default_strategy=scenario_class.get_default_strategy().value, all_strategies=tuple(s.value for s in strategy_class.get_all_strategies()), aggregate_strategies=tuple(s.value for s in strategy_class.get_aggregate_strategies()), diff --git a/pyrit/registry/instance_registries/base_instance_registry.py b/pyrit/registry/instance_registries/base_instance_registry.py index 351e927e0..2c12fab99 100644 --- a/pyrit/registry/instance_registries/base_instance_registry.py +++ b/pyrit/registry/instance_registries/base_instance_registry.py @@ -18,10 +18,11 @@ from abc import ABC, abstractmethod from typing import Any, Dict, Generic, Iterator, List, Optional, TypeVar -from pyrit.registry.base import RegistryItemMetadata, RegistryProtocol +from pyrit.models.identifiers import Identifier +from pyrit.registry.base import RegistryProtocol T = TypeVar("T") # The type of instances stored -MetadataT = TypeVar("MetadataT", bound=RegistryItemMetadata) +MetadataT = TypeVar("MetadataT", bound=Identifier) class BaseInstanceRegistry(ABC, RegistryProtocol[MetadataT], Generic[T, MetadataT]): diff --git a/pyrit/registry/instance_registries/scorer_registry.py b/pyrit/registry/instance_registries/scorer_registry.py index 7a9411a8b..72798cf15 100644 --- a/pyrit/registry/instance_registries/scorer_registry.py +++ b/pyrit/registry/instance_registries/scorer_registry.py @@ -13,7 +13,7 @@ from dataclasses import dataclass from typing import TYPE_CHECKING, Optional -from pyrit.registry.base import RegistryItemMetadata +from pyrit.models.identifiers import Identifier from pyrit.registry.instance_registries.base_instance_registry import ( BaseInstanceRegistry, ) @@ -29,7 +29,7 @@ @dataclass(frozen=True) -class ScorerMetadata(RegistryItemMetadata): +class ScorerMetadata(Identifier): """ Metadata describing a registered scorer instance. @@ -130,9 +130,11 @@ def _build_metadata(self, name: str, instance: "Scorer") -> ScorerMetadata: scorer_type = "unknown" return ScorerMetadata( + identifier_type="instance", name=name, class_name=instance.__class__.__name__, - description=description, + class_module=instance.__class__.__module__, + class_description=description, scorer_type=scorer_type, scorer_identifier=instance.scorer_identifier, ) diff --git a/tests/unit/cli/test_frontend_core.py b/tests/unit/cli/test_frontend_core.py index 22b51294e..764a036aa 100644 --- a/tests/unit/cli/test_frontend_core.py +++ b/tests/unit/cli/test_frontend_core.py @@ -328,9 +328,11 @@ async def test_print_scenarios_list_with_scenarios(self, capsys): mock_registry = MagicMock() mock_registry.list_metadata.return_value = [ ScenarioMetadata( + identifier_type="class", name="test_scenario", class_name="TestScenario", - description="Test description", + class_module="test.scenarios", + class_description="Test description", default_strategy="default", all_strategies=(), aggregate_strategies=(), @@ -368,10 +370,12 @@ async def test_print_initializers_list_with_initializers(self, capsys): mock_registry = MagicMock() mock_registry.list_metadata.return_value = [ InitializerMetadata( + identifier_type="class", name="test_init", class_name="TestInit", - description="Test initializer", - initializer_name="test", + class_module="test.initializers", + class_description="Test initializer", + display_name="test", execution_order=100, required_env_vars=(), ) @@ -408,9 +412,11 @@ def test_format_scenario_metadata_basic(self, capsys): """Test format_scenario_metadata with basic metadata.""" scenario_metadata = ScenarioMetadata( + identifier_type="class", name="test_scenario", class_name="TestScenario", - description="", + class_module="test.scenarios", + class_description="", default_strategy="", all_strategies=(), aggregate_strategies=(), @@ -428,9 +434,11 @@ def test_format_scenario_metadata_with_description(self, capsys): """Test format_scenario_metadata with description.""" scenario_metadata = ScenarioMetadata( + identifier_type="class", name="test_scenario", class_name="TestScenario", - description="This is a test scenario", + class_module="test.scenarios", + class_description="This is a test scenario", default_strategy="", all_strategies=(), aggregate_strategies=(), @@ -446,9 +454,11 @@ def test_format_scenario_metadata_with_description(self, capsys): def test_format_scenario_metadata_with_strategies(self, capsys): """Test format_scenario_metadata with strategies.""" scenario_metadata = ScenarioMetadata( + identifier_type="class", name="test_scenario", class_name="TestScenario", - description="", + class_module="test.scenarios", + class_description="", default_strategy="strategy1", all_strategies=("strategy1", "strategy2"), aggregate_strategies=(), @@ -466,10 +476,12 @@ def test_format_scenario_metadata_with_strategies(self, capsys): def test_format_initializer_metadata_basic(self, capsys) -> None: """Test format_initializer_metadata with basic metadata.""" initializer_metadata = InitializerMetadata( + identifier_type="class", name="test_init", class_name="TestInit", - description="", - initializer_name="test", + class_module="test.initializers", + class_description="", + display_name="test", required_env_vars=(), execution_order=100, ) @@ -484,10 +496,12 @@ def test_format_initializer_metadata_basic(self, capsys) -> None: def test_format_initializer_metadata_with_env_vars(self, capsys) -> None: """Test format_initializer_metadata with environment variables.""" initializer_metadata = InitializerMetadata( + identifier_type="class", name="test_init", class_name="TestInit", - description="", - initializer_name="test", + class_module="test.initializers", + class_description="", + display_name="test", required_env_vars=("VAR1", "VAR2"), execution_order=100, ) @@ -501,10 +515,12 @@ def test_format_initializer_metadata_with_env_vars(self, capsys) -> None: def test_format_initializer_metadata_with_description(self, capsys) -> None: """Test format_initializer_metadata with description.""" initializer_metadata = InitializerMetadata( + identifier_type="class", name="test_init", class_name="TestInit", - description="Test description", - initializer_name="test", + class_module="test.initializers", + class_description="Test description", + display_name="test", required_env_vars=(), execution_order=100, ) diff --git a/tests/unit/models/test_identifiers.py b/tests/unit/models/test_identifiers.py new file mode 100644 index 000000000..845a65dce --- /dev/null +++ b/tests/unit/models/test_identifiers.py @@ -0,0 +1,331 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +from dataclasses import dataclass, field + +import pytest + +from pyrit.models.identifiers import Identifiable, Identifier + + +class TestIdentifiable: + """Tests for the Identifiable abstract base class.""" + + def test_identifiable_get_identifier_is_abstract(self): + """Test that get_identifier is an abstract method that must be implemented.""" + + class ConcreteIdentifiable(Identifiable): + def get_identifier(self) -> dict[str, str]: + return {"type": "test", "name": "example"} + + obj = ConcreteIdentifiable() + result = obj.get_identifier() + assert result == {"type": "test", "name": "example"} + + def test_identifiable_str_returns_identifier(self): + """Test that __str__ returns the get_identifier method reference.""" + + class ConcreteIdentifiable(Identifiable): + def get_identifier(self) -> dict[str, str]: + return {"type": "test"} + + obj = ConcreteIdentifiable() + # __str__ returns the method reference string + assert "get_identifier" in str(obj) + + +class TestIdentifier: + """Tests for the Identifier dataclass.""" + + def test_identifier_creation(self): + """Test creating an Identifier instance with all required fields.""" + identifier = Identifier( + identifier_type="class", + name="test_scorer", + class_name="TestScorer", + class_module="pyrit.test.scorer", + class_description="A test scorer for testing", + ) + assert identifier.identifier_type == "class" + assert identifier.name == "test_scorer" + assert identifier.class_name == "TestScorer" + assert identifier.class_module == "pyrit.test.scorer" + assert identifier.class_description == "A test scorer for testing" + + def test_identifier_is_frozen(self): + """Test that Identifier is immutable.""" + identifier = Identifier( + identifier_type="instance", + name="test_item", + class_name="TestClass", + class_module="test.module", + class_description="Description here", + ) + + with pytest.raises(AttributeError): + identifier.name = "new_name" # type: ignore[misc] + + def test_identifier_type_literal_class(self): + """Test identifier_type with 'class' value.""" + identifier = Identifier( + identifier_type="class", + name="test", + class_name="Test", + class_module="test", + class_description="", + ) + assert identifier.identifier_type == "class" + + def test_identifier_type_literal_instance(self): + """Test identifier_type with 'instance' value.""" + identifier = Identifier( + identifier_type="instance", + name="test", + class_name="Test", + class_module="test", + class_description="", + ) + assert identifier.identifier_type == "instance" + + +class TestIdentifierHash: + """Tests for Identifier hash computation.""" + + def test_hash_computed_at_creation(self): + """Test that hash is computed when the Identifier is created.""" + identifier = Identifier( + identifier_type="class", + name="test_item", + class_name="TestClass", + class_module="test.module", + class_description="A test description", + ) + assert identifier.hash is not None + assert len(identifier.hash) == 64 # SHA256 hex length + + def test_hash_is_deterministic(self): + """Test that the same inputs produce the same hash.""" + identifier1 = Identifier( + identifier_type="class", + name="test_item", + class_name="TestClass", + class_module="test.module", + class_description="A test description", + ) + identifier2 = Identifier( + identifier_type="class", + name="test_item", + class_name="TestClass", + class_module="test.module", + class_description="A test description", + ) + assert identifier1.hash == identifier2.hash + + def test_hash_differs_for_different_storable_fields(self): + """Test that different storable field values produce different hashes.""" + identifier1 = Identifier( + identifier_type="class", + name="test_item", + class_name="TestClass", + class_module="test.module", + class_description="Description", + ) + identifier2 = Identifier( + identifier_type="class", + name="different_item", + class_name="TestClass", + class_module="test.module", + class_description="Description", + ) + assert identifier1.hash != identifier2.hash + + def test_hash_excludes_class_description(self): + """Test that class_description is excluded from hash computation.""" + identifier1 = Identifier( + identifier_type="class", + name="test_item", + class_name="TestClass", + class_module="test.module", + class_description="First description", + ) + identifier2 = Identifier( + identifier_type="class", + name="test_item", + class_name="TestClass", + class_module="test.module", + class_description="Completely different description", + ) + # Hash should be the same since class_description is excluded + assert identifier1.hash == identifier2.hash + + def test_hash_excludes_identifier_type(self): + """Test that identifier_type is excluded from hash computation.""" + identifier1 = Identifier( + identifier_type="class", + name="test_item", + class_name="TestClass", + class_module="test.module", + class_description="Description", + ) + identifier2 = Identifier( + identifier_type="instance", + name="test_item", + class_name="TestClass", + class_module="test.module", + class_description="Description", + ) + # Hash should be the same since identifier_type is excluded + assert identifier1.hash == identifier2.hash + + def test_hash_is_immutable(self): + """Test that the hash cannot be modified.""" + identifier = Identifier( + identifier_type="class", + name="test_item", + class_name="TestClass", + class_module="test.module", + class_description="A test description", + ) + with pytest.raises(AttributeError): + identifier.hash = "new_hash" # type: ignore[misc] + + +class TestIdentifierStorage: + """Tests for Identifier storage functionality.""" + + def test_to_storage_dict_excludes_marked_fields(self): + """Test that to_storage_dict excludes fields marked with exclude_from_storage.""" + identifier = Identifier( + identifier_type="class", + name="test_item", + class_name="TestClass", + class_module="test.module", + class_description="A test description", + ) + storage_dict = identifier.to_storage_dict() + + # Should include storable fields + assert "name" in storage_dict + assert "class_name" in storage_dict + assert "class_module" in storage_dict + assert "hash" in storage_dict + + # Should exclude non-storable fields + assert "class_description" not in storage_dict + assert "identifier_type" not in storage_dict + + def test_to_storage_dict_values_match(self): + """Test that to_storage_dict values match the original identifier.""" + identifier = Identifier( + identifier_type="instance", + name="my_scorer", + class_name="MyScorer", + class_module="pyrit.score.my_scorer", + class_description="My custom scorer", + ) + storage_dict = identifier.to_storage_dict() + + assert storage_dict["name"] == "my_scorer" + assert storage_dict["class_name"] == "MyScorer" + assert storage_dict["class_module"] == "pyrit.score.my_scorer" + assert storage_dict["hash"] == identifier.hash + + +class TestIdentifierSubclass: + """Tests for Identifier subclassing behavior.""" + + def test_subclass_inherits_hash_computation(self): + """Test that subclasses of Identifier also get a computed hash.""" + + @dataclass(frozen=True) + class ExtendedIdentifier(Identifier): + extra_field: str + + extended = ExtendedIdentifier( + identifier_type="class", + name="test_item", + class_name="TestClass", + class_module="test.module", + class_description="A test description", + extra_field="extra_value", + ) + assert extended.hash is not None + assert len(extended.hash) == 64 + + def test_subclass_extra_fields_included_in_hash(self): + """Test that subclass extra fields (not marked) are included in hash.""" + + @dataclass(frozen=True) + class ExtendedIdentifier(Identifier): + extra_field: str + + extended1 = ExtendedIdentifier( + identifier_type="class", + name="test_item", + class_name="TestClass", + class_module="test.module", + class_description="Description", + extra_field="value1", + ) + extended2 = ExtendedIdentifier( + identifier_type="class", + name="test_item", + class_name="TestClass", + class_module="test.module", + class_description="Description", + extra_field="value2", + ) + # Different extra_field values should produce different hashes + assert extended1.hash != extended2.hash + + def test_subclass_excluded_fields_not_in_hash(self): + """Test that subclass fields marked exclude_from_storage are excluded from hash.""" + + @dataclass(frozen=True) + class ExtendedIdentifier(Identifier): + display_only: str = field(metadata={"exclude_from_storage": True}) + + extended1 = ExtendedIdentifier( + identifier_type="class", + name="test_item", + class_name="TestClass", + class_module="test.module", + class_description="Description", + display_only="display1", + ) + extended2 = ExtendedIdentifier( + identifier_type="class", + name="test_item", + class_name="TestClass", + class_module="test.module", + class_description="Description", + display_only="display2", + ) + # display_only is excluded, so hashes should match + assert extended1.hash == extended2.hash + + def test_subclass_to_storage_dict_includes_extra_storable_fields(self): + """Test that to_storage_dict includes subclass storable fields.""" + + @dataclass(frozen=True) + class ExtendedIdentifier(Identifier): + extra_field: str + display_only: str = field(metadata={"exclude_from_storage": True}) + + extended = ExtendedIdentifier( + identifier_type="class", + name="test_item", + class_name="TestClass", + class_module="test.module", + class_description="Description", + extra_field="extra_value", + display_only="display_value", + ) + storage_dict = extended.to_storage_dict() + + # Extra storable field should be included + assert "extra_field" in storage_dict + assert storage_dict["extra_field"] == "extra_value" + + # Display-only field should be excluded + assert "display_only" not in storage_dict diff --git a/tests/unit/registry/test_base.py b/tests/unit/registry/test_base.py index e96d695ec..feb8e326a 100644 --- a/tests/unit/registry/test_base.py +++ b/tests/unit/registry/test_base.py @@ -5,11 +5,12 @@ import pytest -from pyrit.registry.base import RegistryItemMetadata, _matches_filters +from pyrit.models.identifiers import Identifier +from pyrit.registry.base import _matches_filters @dataclass(frozen=True) -class MetadataWithTags(RegistryItemMetadata): +class MetadataWithTags(Identifier): """Test metadata with a tags field for list filtering tests.""" tags: tuple[str, ...] @@ -20,66 +21,80 @@ class TestMatchesFilters: def test_matches_filters_exact_match_string(self): """Test that exact string matches work.""" - metadata = RegistryItemMetadata( + metadata = Identifier( + identifier_type="class", name="test_item", class_name="TestClass", - description="A test item", + class_module="test.module", + class_description="A test item", ) assert _matches_filters(metadata, include_filters={"name": "test_item"}) is True assert _matches_filters(metadata, include_filters={"class_name": "TestClass"}) is True def test_matches_filters_no_match_string(self): """Test that non-matching strings return False.""" - metadata = RegistryItemMetadata( + metadata = Identifier( + identifier_type="class", name="test_item", class_name="TestClass", - description="A test item", + class_module="test.module", + class_description="A test item", ) assert _matches_filters(metadata, include_filters={"name": "other_item"}) is False assert _matches_filters(metadata, include_filters={"class_name": "OtherClass"}) is False def test_matches_filters_multiple_filters_all_match(self): """Test that all filters must match.""" - metadata = RegistryItemMetadata( + metadata = Identifier( + identifier_type="class", name="test_item", class_name="TestClass", - description="A test item", + class_module="test.module", + class_description="A test item", ) assert _matches_filters(metadata, include_filters={"name": "test_item", "class_name": "TestClass"}) is True def test_matches_filters_multiple_filters_partial_match(self): """Test that partial matches return False when not all filters match.""" - metadata = RegistryItemMetadata( + metadata = Identifier( + identifier_type="class", name="test_item", class_name="TestClass", - description="A test item", + class_module="test.module", + class_description="A test item", ) assert _matches_filters(metadata, include_filters={"name": "test_item", "class_name": "OtherClass"}) is False def test_matches_filters_key_not_in_metadata(self): """Test that filtering on a non-existent key returns False.""" - metadata = RegistryItemMetadata( + metadata = Identifier( + identifier_type="class", name="test_item", class_name="TestClass", - description="A test item", + class_module="test.module", + class_description="A test item", ) assert _matches_filters(metadata, include_filters={"nonexistent_key": "value"}) is False def test_matches_filters_empty_filters(self): """Test that empty filters return True.""" - metadata = RegistryItemMetadata( + metadata = Identifier( + identifier_type="class", name="test_item", class_name="TestClass", - description="A test item", + class_module="test.module", + class_description="A test item", ) assert _matches_filters(metadata) is True def test_matches_filters_list_value_contains_filter(self): """Test filtering when metadata value is a list and filter value is in the list.""" metadata = MetadataWithTags( + identifier_type="class", name="test_item", class_name="TestClass", - description="A test item", + class_module="test.module", + class_description="A test item", tags=("tag1", "tag2", "tag3"), ) assert _matches_filters(metadata, include_filters={"tags": "tag1"}) is True @@ -88,19 +103,23 @@ def test_matches_filters_list_value_contains_filter(self): def test_matches_filters_list_value_not_contains_filter(self): """Test filtering when metadata value is a list and filter value is not in the list.""" metadata = MetadataWithTags( + identifier_type="class", name="test_item", class_name="TestClass", - description="A test item", + class_module="test.module", + class_description="A test item", tags=("tag1", "tag2", "tag3"), ) assert _matches_filters(metadata, include_filters={"tags": "missing_tag"}) is False def test_matches_filters_exclude_exact_match(self): """Test that exclude filters work for exact matches.""" - metadata = RegistryItemMetadata( + metadata = Identifier( + identifier_type="class", name="test_item", class_name="TestClass", - description="A test item", + class_module="test.module", + class_description="A test item", ) assert _matches_filters(metadata, exclude_filters={"name": "test_item"}) is False assert _matches_filters(metadata, exclude_filters={"name": "other_item"}) is True @@ -108,9 +127,11 @@ def test_matches_filters_exclude_exact_match(self): def test_matches_filters_exclude_list_value(self): """Test exclude filters work for list values.""" metadata = MetadataWithTags( + identifier_type="class", name="test_item", class_name="TestClass", - description="A test item", + class_module="test.module", + class_description="A test item", tags=("tag1", "tag2", "tag3"), ) assert _matches_filters(metadata, exclude_filters={"tags": "tag1"}) is False @@ -118,20 +139,24 @@ def test_matches_filters_exclude_list_value(self): def test_matches_filters_exclude_nonexistent_key(self): """Test that exclude filters for non-existent keys don't exclude the item.""" - metadata = RegistryItemMetadata( + metadata = Identifier( + identifier_type="class", name="test_item", class_name="TestClass", - description="A test item", + class_module="test.module", + class_description="A test item", ) # Non-existent key in exclude filter should not exclude the item assert _matches_filters(metadata, exclude_filters={"nonexistent_key": "value"}) is True def test_matches_filters_combined_include_and_exclude(self): """Test combined include and exclude filters.""" - metadata = RegistryItemMetadata( + metadata = Identifier( + identifier_type="class", name="test_item", class_name="TestClass", - description="A test item", + class_module="test.module", + class_description="A test item", ) # Include matches, exclude doesn't -> should pass assert ( @@ -156,27 +181,106 @@ def test_matches_filters_combined_include_and_exclude(self): ) -class TestRegistryItemMetadata: - """Tests for the RegistryItemMetadata dataclass.""" +class TestIdentifier: + """Tests for the Identifier dataclass and hash computation.""" - def test_registry_item_metadata_creation(self): - """Test creating a RegistryItemMetadata instance.""" - metadata = RegistryItemMetadata( + def test_identifier_creation(self): + """Test creating an Identifier instance.""" + metadata = Identifier( + identifier_type="class", name="test_scorer", class_name="TestScorer", - description="A test scorer for testing", + class_module="pyrit.test.scorer", + class_description="A test scorer for testing", ) + assert metadata.identifier_type == "class" assert metadata.name == "test_scorer" assert metadata.class_name == "TestScorer" - assert metadata.description == "A test scorer for testing" + assert metadata.class_module == "pyrit.test.scorer" + assert metadata.class_description == "A test scorer for testing" - def test_registry_item_metadata_is_frozen(self): - """Test that RegistryItemMetadata is immutable.""" - metadata = RegistryItemMetadata( + def test_identifier_is_frozen(self): + """Test that Identifier is immutable.""" + metadata = Identifier( + identifier_type="class", name="test_item", class_name="TestClass", - description="Description here", + class_module="test.module", + class_description="Description here", ) with pytest.raises(AttributeError): metadata.name = "new_name" # type: ignore[misc] + + def test_identifier_hash_computed_at_creation(self): + """Test that hash is computed when the Identifier is created.""" + identifier = Identifier( + identifier_type="instance", + name="test_item", + class_name="TestClass", + class_module="test.module", + class_description="A test description", + ) + assert identifier.hash is not None + assert len(identifier.hash) == 64 # SHA256 hex length + + def test_identifier_hash_is_deterministic(self): + """Test that the same inputs produce the same hash.""" + identifier1 = Identifier( + identifier_type="class", + name="test_item", + class_name="TestClass", + class_module="test.module", + class_description="A test description", + ) + identifier2 = Identifier( + identifier_type="class", + name="test_item", + class_name="TestClass", + class_module="test.module", + class_description="A test description", + ) + assert identifier1.hash == identifier2.hash + + def test_identifier_hash_differs_for_different_inputs(self): + """Test that different inputs produce different hashes.""" + identifier1 = Identifier( + identifier_type="class", + name="test_item", + class_name="TestClass", + class_module="test.module", + class_description="A test description", + ) + identifier2 = Identifier( + identifier_type="class", + name="different_item", + class_name="TestClass", + class_module="test.module", + class_description="A test description", + ) + assert identifier1.hash != identifier2.hash + + def test_identifier_hash_is_immutable(self): + """Test that the hash cannot be modified.""" + identifier = Identifier( + identifier_type="class", + name="test_item", + class_name="TestClass", + class_module="test.module", + class_description="A test description", + ) + with pytest.raises(AttributeError): + identifier.hash = "new_hash" # type: ignore[misc] + + def test_identifier_subclass_inherits_hash(self): + """Test that subclasses of Identifier also get a computed hash.""" + metadata = MetadataWithTags( + identifier_type="class", + name="test_item", + class_name="TestClass", + class_module="test.module", + class_description="A test description", + tags=("tag1", "tag2"), + ) + assert metadata.hash is not None + assert len(metadata.hash) == 64 diff --git a/tests/unit/registry/test_base_instance_registry.py b/tests/unit/registry/test_base_instance_registry.py index 886de0028..075c6756a 100644 --- a/tests/unit/registry/test_base_instance_registry.py +++ b/tests/unit/registry/test_base_instance_registry.py @@ -3,12 +3,12 @@ from dataclasses import dataclass -from pyrit.registry.base import RegistryItemMetadata +from pyrit.models.identifiers import Identifier from pyrit.registry.instance_registries.base_instance_registry import BaseInstanceRegistry @dataclass(frozen=True) -class SampleItemMetadata(RegistryItemMetadata): +class SampleItemMetadata(Identifier): """Sample metadata with an extra field.""" category: str @@ -20,9 +20,11 @@ class ConcreteTestRegistry(BaseInstanceRegistry[str, SampleItemMetadata]): def _build_metadata(self, name: str, instance: str) -> SampleItemMetadata: """Build test metadata from a string instance.""" return SampleItemMetadata( + identifier_type="instance", name=name, class_name="str", - description=f"Description for {instance}", + class_module="builtins", + class_description=f"Description for {instance}", category="test" if "test" in instance.lower() else "other", ) diff --git a/tests/unit/registry/test_scorer_registry.py b/tests/unit/registry/test_scorer_registry.py index 638b52684..31feabc2a 100644 --- a/tests/unit/registry/test_scorer_registry.py +++ b/tests/unit/registry/test_scorer_registry.py @@ -246,13 +246,13 @@ def test_build_metadata_includes_scorer_identifier(self): assert isinstance(metadata[0].scorer_identifier, ScorerIdentifier) def test_build_metadata_description_from_docstring(self): - """Test that description is derived from the scorer's docstring.""" + """Test that class_description is derived from the scorer's docstring.""" scorer = MockTrueFalseScorer() self.registry.register_instance(scorer, name="tf_scorer") metadata = self.registry.list_metadata() # MockTrueFalseScorer has a docstring - assert "Mock TrueFalseScorer for testing" in metadata[0].description + assert "Mock TrueFalseScorer for testing" in metadata[0].class_description class TestScorerRegistryListMetadataFiltering: @@ -360,15 +360,19 @@ def test_scorer_metadata_has_required_fields(self): mock_identifier = ScorerIdentifier(type="test_type") metadata = ScorerMetadata( + identifier_type="instance", name="test_scorer", class_name="TestScorer", - description="A test scorer", + class_module="test.module", + class_description="A test scorer", scorer_type="true_false", scorer_identifier=mock_identifier, ) + assert metadata.identifier_type == "instance" assert metadata.name == "test_scorer" assert metadata.class_name == "TestScorer" - assert metadata.description == "A test scorer" + assert metadata.class_module == "test.module" + assert metadata.class_description == "A test scorer" assert metadata.scorer_type == "true_false" assert metadata.scorer_identifier == mock_identifier