From 80469d5bc292c11da7e91b0da852b6006ced3233 Mon Sep 17 00:00:00 2001 From: Richard Lundeen Date: Fri, 23 Jan 2026 10:55:38 -0800 Subject: [PATCH 1/5] refactoring register metadata --- doc/api.rst | 1 + pyrit/cli/frontend_core.py | 8 +- pyrit/models/__init__.py | 3 +- pyrit/models/identifiers.py | 63 +++++++- pyrit/prompt_converter/prompt_converter.py | 4 +- pyrit/prompt_target/common/prompt_target.py | 4 +- pyrit/registry/__init__.py | 5 +- pyrit/registry/base.py | 18 +-- .../class_registries/base_class_registry.py | 14 +- .../class_registries/initializer_registry.py | 16 +- .../class_registries/scenario_registry.py | 7 +- .../base_instance_registry.py | 5 +- .../instance_registries/scorer_registry.py | 7 +- tests/unit/cli/test_frontend_core.py | 32 ++-- tests/unit/registry/test_base.py | 150 ++++++++++++++---- .../registry/test_base_instance_registry.py | 7 +- tests/unit/registry/test_scorer_registry.py | 10 +- 17 files changed, 252 insertions(+), 102 deletions(-) diff --git a/doc/api.rst b/doc/api.rst index c475c1923..ca966c587 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -332,6 +332,7 @@ API Reference group_conversation_message_pieces_by_sequence group_message_pieces_into_conversations HarmDefinition + Identifiable Identifier ImagePathDataTypeSerializer AllowedCategories diff --git a/pyrit/cli/frontend_core.py b/pyrit/cli/frontend_core.py index ca1d6f2ad..3739b54cd 100644 --- a/pyrit/cli/frontend_core.py +++ b/pyrit/cli/frontend_core.py @@ -387,7 +387,7 @@ def format_scenario_metadata(*, scenario_metadata: ScenarioMetadata) -> None: _print_header(text=scenario_metadata.name) print(f" Class: {scenario_metadata.class_name}") - description = scenario_metadata.description + description = scenario_metadata.class_description if description: print(" Description:") print(_format_wrapped_text(text=description, indent=" ")) @@ -428,7 +428,7 @@ def format_initializer_metadata(*, initializer_metadata: "InitializerMetadata") """ _print_header(text=initializer_metadata.name) print(f" Class: {initializer_metadata.class_name}") - print(f" Name: {initializer_metadata.initializer_name}") + print(f" Name: {initializer_metadata.display_name}") print(f" Execution Order: {initializer_metadata.execution_order}") if initializer_metadata.required_env_vars: @@ -438,9 +438,9 @@ def format_initializer_metadata(*, initializer_metadata: "InitializerMetadata") else: print(" Required Environment Variables: None") - if initializer_metadata.description: + if initializer_metadata.class_description: print(" Description:") - print(_format_wrapped_text(text=initializer_metadata.description, indent=" ")) + print(_format_wrapped_text(text=initializer_metadata.class_description, indent=" ")) def validate_database(*, database: str) -> str: diff --git a/pyrit/models/__init__.py b/pyrit/models/__init__.py index 5d6e30665..2d4d0a80e 100644 --- a/pyrit/models/__init__.py +++ b/pyrit/models/__init__.py @@ -22,7 +22,7 @@ ) from pyrit.models.embeddings import EmbeddingData, EmbeddingResponse, EmbeddingSupport, EmbeddingUsageInformation from pyrit.models.harm_definition import HarmDefinition, ScaleDescription, get_all_harm_definitions -from pyrit.models.identifiers import Identifier +from pyrit.models.identifiers import Identifiable, Identifier from pyrit.models.literals import ChatMessageRole, PromptDataType, PromptResponseError, SeedType from pyrit.models.message import ( Message, @@ -82,6 +82,7 @@ "group_conversation_message_pieces_by_sequence", "group_message_pieces_into_conversations", "HarmDefinition", + "Identifiable", "Identifier", "ImagePathDataTypeSerializer", "Message", diff --git a/pyrit/models/identifiers.py b/pyrit/models/identifiers.py index c70eabdf2..10792d99f 100644 --- a/pyrit/models/identifiers.py +++ b/pyrit/models/identifiers.py @@ -1,13 +1,74 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. +import hashlib +import json from abc import abstractmethod +from dataclasses import asdict, dataclass, field, fields, is_dataclass +from typing import Any -class Identifier: +class Identifiable: + """ + Abstract base class for objects that can provide an identifier dictionary. + + This is a legacy interface that will eventually be replaced by Identifier dataclass. + Classes implementing this interface should return a dict describing their identity. + """ + @abstractmethod def get_identifier(self) -> dict[str, str]: pass def __str__(self) -> str: return f"{self.get_identifier}" + + +@dataclass(frozen=True) +class Identifier: + """ + Base dataclass for identifying PyRIT components. + + This frozen dataclass provides a stable identifier for registry items, + targets, scorers, attacks, converters, and other components. The hash is computed at creation + time from the core fields and remains constant. + + This class serves as: + 1. Base for registry metadata (replacing RegistryItemMetadata) + 2. Future replacement for get_identifier() dict patterns + + All component-specific identifier types should extend this with additional fields. + """ + + name: str # The snake_case identifier name (e.g., "self_ask_refusal") + class_name: str # The actual class name, equivalent to __type__ (e.g., "SelfAskRefusalScorer") + class_module: str # The module path, equivalent to __module__ (e.g., "pyrit.score.self_ask_refusal_scorer") + class_description: str # Description from docstring or manual override + hash: str = field(init=False, compare=False) + + def __post_init__(self) -> None: + """Compute the identifier hash from core fields.""" + # Use object.__setattr__ since this is a frozen dataclass + object.__setattr__(self, "hash", self._compute_hash()) + + def _compute_hash(self) -> str: + """ + Compute a stable SHA256 hash from all identifier fields. + + All fields except 'hash' itself are automatically included. + + Returns: + A hex string of the SHA256 hash. + """ + hashable_dict: dict[str, Any] = { + f.name: getattr(self, f.name) for f in fields(self) if f.name != "hash" + } + config_json = json.dumps(hashable_dict, sort_keys=True, separators=(",", ":"), default=_dataclass_encoder) + return hashlib.sha256(config_json.encode("utf-8")).hexdigest() + + +def _dataclass_encoder(obj: Any) -> Any: + """JSON encoder that handles dataclasses by converting them to dicts.""" + if is_dataclass(obj) and not isinstance(obj, type): + return asdict(obj) + raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable") \ No newline at end of file diff --git a/pyrit/prompt_converter/prompt_converter.py b/pyrit/prompt_converter/prompt_converter.py index db9497b81..55a315a00 100644 --- a/pyrit/prompt_converter/prompt_converter.py +++ b/pyrit/prompt_converter/prompt_converter.py @@ -9,7 +9,7 @@ from typing import get_args from pyrit import prompt_converter -from pyrit.models import Identifier, PromptDataType +from pyrit.models import Identifiable, PromptDataType @dataclass @@ -31,7 +31,7 @@ def __str__(self) -> str: return f"{self.output_type}: {self.output_text}" -class PromptConverter(abc.ABC, Identifier): +class PromptConverter(abc.ABC, Identifiable): """ Base class for converters that transform prompts into a different representation or format. diff --git a/pyrit/prompt_target/common/prompt_target.py b/pyrit/prompt_target/common/prompt_target.py index 78b466e04..f39ba537c 100644 --- a/pyrit/prompt_target/common/prompt_target.py +++ b/pyrit/prompt_target/common/prompt_target.py @@ -6,12 +6,12 @@ from typing import Any, Dict, List, Optional from pyrit.memory import CentralMemory, MemoryInterface -from pyrit.models import Identifier, Message +from pyrit.models import Identifiable, Message logger = logging.getLogger(__name__) -class PromptTarget(abc.ABC, Identifier): +class PromptTarget(abc.ABC, Identifiable): """ Abstract base class for prompt targets. diff --git a/pyrit/registry/__init__.py b/pyrit/registry/__init__.py index ba0566020..7c8c8c1fa 100644 --- a/pyrit/registry/__init__.py +++ b/pyrit/registry/__init__.py @@ -3,7 +3,8 @@ """Registry module for PyRIT class and instance registries.""" -from pyrit.registry.base import RegistryItemMetadata, RegistryProtocol +from pyrit.models.identifiers import Identifier +from pyrit.registry.base import RegistryProtocol from pyrit.registry.class_registries import ( BaseClassRegistry, ClassEntry, @@ -32,9 +33,9 @@ "discover_in_directory", "discover_in_package", "discover_subclasses_in_loaded_modules", + "Identifier", "InitializerMetadata", "InitializerRegistry", - "RegistryItemMetadata", "RegistryProtocol", "registry_name_to_class_name", "ScenarioMetadata", diff --git a/pyrit/registry/base.py b/pyrit/registry/base.py index a9c284561..89b29248f 100644 --- a/pyrit/registry/base.py +++ b/pyrit/registry/base.py @@ -8,9 +8,10 @@ and instance registries (which store T instances). """ -from dataclasses import dataclass from typing import Any, Dict, Iterator, List, Optional, Protocol, TypeVar, runtime_checkable +from pyrit.models.identifiers import Identifier + # Type variable for metadata (invariant for Protocol compatibility) MetadataT = TypeVar("MetadataT") @@ -77,21 +78,6 @@ def __iter__(self) -> Iterator[str]: ... -@dataclass(frozen=True) -class RegistryItemMetadata: - """ - Base dataclass for registry item metadata. - - This dataclass provides descriptive information about a registered item - (either a class or an instance). It is NOT the item itself - it's a - structured object describing the item. - - All registry-specific metadata types should extend this with additional fields. - """ - - name: str # The snake_case registry name (e.g., "self_ask_refusal") - class_name: str # The actual class name (e.g., "SelfAskRefusalScorer") - description: str # Description from docstring or manual override def _matches_filters( diff --git a/pyrit/registry/class_registries/base_class_registry.py b/pyrit/registry/class_registries/base_class_registry.py index bd6ae3f27..549581584 100644 --- a/pyrit/registry/class_registries/base_class_registry.py +++ b/pyrit/registry/class_registries/base_class_registry.py @@ -19,7 +19,8 @@ from abc import ABC, abstractmethod from typing import Callable, Dict, Generic, Iterator, List, Optional, Type, TypeVar -from pyrit.registry.base import RegistryItemMetadata, RegistryProtocol +from pyrit.models.identifiers import Identifier +from pyrit.registry.base import RegistryProtocol from pyrit.registry.name_utils import class_name_to_registry_name # Type variable for the registered class type @@ -182,11 +183,11 @@ def _build_metadata(self, name: str, entry: ClassEntry[T]) -> MetadataT: """ pass - def _build_base_metadata(self, name: str, entry: ClassEntry[T]) -> RegistryItemMetadata: + def _build_base_metadata(self, name: str, entry: ClassEntry[T]) -> Identifier: """ Build the common base metadata for a registered class. - This helper extracts fields common to all registries: name, class_name, description. + This helper extracts fields common to all registries: name, class_name, class_description. Subclasses can use this for building common fields if needed. Args: @@ -194,7 +195,7 @@ def _build_base_metadata(self, name: str, entry: ClassEntry[T]) -> RegistryItemM entry: The ClassEntry containing the registered class. Returns: - A RegistryItemMetadata dataclass with common fields. + An Identifier dataclass with common fields. """ registered_class = entry.registered_class @@ -205,10 +206,11 @@ def _build_base_metadata(self, name: str, entry: ClassEntry[T]) -> RegistryItemM else: description = entry.description or "No description available" - return RegistryItemMetadata( + return Identifier( name=name, class_name=registered_class.__name__, - description=description, + class_module=registered_class.__module__, + class_description=description, ) def get_class(self, name: str) -> Type[T]: diff --git a/pyrit/registry/class_registries/initializer_registry.py b/pyrit/registry/class_registries/initializer_registry.py index fed6d7b6c..6c8cc51e4 100644 --- a/pyrit/registry/class_registries/initializer_registry.py +++ b/pyrit/registry/class_registries/initializer_registry.py @@ -16,7 +16,7 @@ from pathlib import Path from typing import TYPE_CHECKING, Dict, Optional -from pyrit.registry.base import RegistryItemMetadata +from pyrit.models.identifiers import Identifier from pyrit.registry.class_registries.base_class_registry import ( BaseClassRegistry, ClassEntry, @@ -34,14 +34,14 @@ @dataclass(frozen=True) -class InitializerMetadata(RegistryItemMetadata): +class InitializerMetadata(Identifier): """ Metadata describing a registered PyRITInitializer class. Use get_class() to get the actual class. """ - initializer_name: str + display_name: str required_env_vars: tuple[str, ...] execution_order: int @@ -210,8 +210,9 @@ def _build_metadata(self, name: str, entry: ClassEntry["PyRITInitializer"]) -> I return InitializerMetadata( name=name, class_name=initializer_class.__name__, - description=instance.description, - initializer_name=instance.name, + class_module=initializer_class.__module__, + class_description=instance.description, + display_name=instance.name, required_env_vars=tuple(instance.required_env_vars), execution_order=instance.execution_order, ) @@ -220,8 +221,9 @@ def _build_metadata(self, name: str, entry: ClassEntry["PyRITInitializer"]) -> I return InitializerMetadata( name=name, class_name=initializer_class.__name__, - description="Error loading initializer metadata", - initializer_name=name, + class_module=initializer_class.__module__, + class_description="Error loading initializer metadata", + display_name=name, required_env_vars=(), execution_order=100, ) diff --git a/pyrit/registry/class_registries/scenario_registry.py b/pyrit/registry/class_registries/scenario_registry.py index 40693f0cd..9f68f8fe6 100644 --- a/pyrit/registry/class_registries/scenario_registry.py +++ b/pyrit/registry/class_registries/scenario_registry.py @@ -15,7 +15,7 @@ from pathlib import Path from typing import TYPE_CHECKING, Optional -from pyrit.registry.base import RegistryItemMetadata +from pyrit.models.identifiers import Identifier from pyrit.registry.class_registries.base_class_registry import ( BaseClassRegistry, ClassEntry, @@ -33,7 +33,7 @@ @dataclass(frozen=True) -class ScenarioMetadata(RegistryItemMetadata): +class ScenarioMetadata(Identifier): """ Metadata describing a registered Scenario class. @@ -172,7 +172,8 @@ def _build_metadata(self, name: str, entry: ClassEntry["Scenario"]) -> ScenarioM return ScenarioMetadata( name=name, class_name=scenario_class.__name__, - description=description, + class_module=scenario_class.__module__, + class_description=description, default_strategy=scenario_class.get_default_strategy().value, all_strategies=tuple(s.value for s in strategy_class.get_all_strategies()), aggregate_strategies=tuple(s.value for s in strategy_class.get_aggregate_strategies()), diff --git a/pyrit/registry/instance_registries/base_instance_registry.py b/pyrit/registry/instance_registries/base_instance_registry.py index 351e927e0..2c12fab99 100644 --- a/pyrit/registry/instance_registries/base_instance_registry.py +++ b/pyrit/registry/instance_registries/base_instance_registry.py @@ -18,10 +18,11 @@ from abc import ABC, abstractmethod from typing import Any, Dict, Generic, Iterator, List, Optional, TypeVar -from pyrit.registry.base import RegistryItemMetadata, RegistryProtocol +from pyrit.models.identifiers import Identifier +from pyrit.registry.base import RegistryProtocol T = TypeVar("T") # The type of instances stored -MetadataT = TypeVar("MetadataT", bound=RegistryItemMetadata) +MetadataT = TypeVar("MetadataT", bound=Identifier) class BaseInstanceRegistry(ABC, RegistryProtocol[MetadataT], Generic[T, MetadataT]): diff --git a/pyrit/registry/instance_registries/scorer_registry.py b/pyrit/registry/instance_registries/scorer_registry.py index 7a9411a8b..6c574f72b 100644 --- a/pyrit/registry/instance_registries/scorer_registry.py +++ b/pyrit/registry/instance_registries/scorer_registry.py @@ -13,7 +13,7 @@ from dataclasses import dataclass from typing import TYPE_CHECKING, Optional -from pyrit.registry.base import RegistryItemMetadata +from pyrit.models.identifiers import Identifier from pyrit.registry.instance_registries.base_instance_registry import ( BaseInstanceRegistry, ) @@ -29,7 +29,7 @@ @dataclass(frozen=True) -class ScorerMetadata(RegistryItemMetadata): +class ScorerMetadata(Identifier): """ Metadata describing a registered scorer instance. @@ -132,7 +132,8 @@ def _build_metadata(self, name: str, instance: "Scorer") -> ScorerMetadata: return ScorerMetadata( name=name, class_name=instance.__class__.__name__, - description=description, + class_module=instance.__class__.__module__, + class_description=description, scorer_type=scorer_type, scorer_identifier=instance.scorer_identifier, ) diff --git a/tests/unit/cli/test_frontend_core.py b/tests/unit/cli/test_frontend_core.py index 22b51294e..cb02a1b7f 100644 --- a/tests/unit/cli/test_frontend_core.py +++ b/tests/unit/cli/test_frontend_core.py @@ -330,7 +330,8 @@ async def test_print_scenarios_list_with_scenarios(self, capsys): ScenarioMetadata( name="test_scenario", class_name="TestScenario", - description="Test description", + class_module="test.scenarios", + class_description="Test description", default_strategy="default", all_strategies=(), aggregate_strategies=(), @@ -370,8 +371,9 @@ async def test_print_initializers_list_with_initializers(self, capsys): InitializerMetadata( name="test_init", class_name="TestInit", - description="Test initializer", - initializer_name="test", + class_module="test.initializers", + class_description="Test initializer", + display_name="test", execution_order=100, required_env_vars=(), ) @@ -410,7 +412,8 @@ def test_format_scenario_metadata_basic(self, capsys): scenario_metadata = ScenarioMetadata( name="test_scenario", class_name="TestScenario", - description="", + class_module="test.scenarios", + class_description="", default_strategy="", all_strategies=(), aggregate_strategies=(), @@ -430,7 +433,8 @@ def test_format_scenario_metadata_with_description(self, capsys): scenario_metadata = ScenarioMetadata( name="test_scenario", class_name="TestScenario", - description="This is a test scenario", + class_module="test.scenarios", + class_description="This is a test scenario", default_strategy="", all_strategies=(), aggregate_strategies=(), @@ -448,7 +452,8 @@ def test_format_scenario_metadata_with_strategies(self, capsys): scenario_metadata = ScenarioMetadata( name="test_scenario", class_name="TestScenario", - description="", + class_module="test.scenarios", + class_description="", default_strategy="strategy1", all_strategies=("strategy1", "strategy2"), aggregate_strategies=(), @@ -468,8 +473,9 @@ def test_format_initializer_metadata_basic(self, capsys) -> None: initializer_metadata = InitializerMetadata( name="test_init", class_name="TestInit", - description="", - initializer_name="test", + class_module="test.initializers", + class_description="", + display_name="test", required_env_vars=(), execution_order=100, ) @@ -486,8 +492,9 @@ def test_format_initializer_metadata_with_env_vars(self, capsys) -> None: initializer_metadata = InitializerMetadata( name="test_init", class_name="TestInit", - description="", - initializer_name="test", + class_module="test.initializers", + class_description="", + display_name="test", required_env_vars=("VAR1", "VAR2"), execution_order=100, ) @@ -503,8 +510,9 @@ def test_format_initializer_metadata_with_description(self, capsys) -> None: initializer_metadata = InitializerMetadata( name="test_init", class_name="TestInit", - description="Test description", - initializer_name="test", + class_module="test.initializers", + class_description="Test description", + display_name="test", required_env_vars=(), execution_order=100, ) diff --git a/tests/unit/registry/test_base.py b/tests/unit/registry/test_base.py index e96d695ec..2c1f739d3 100644 --- a/tests/unit/registry/test_base.py +++ b/tests/unit/registry/test_base.py @@ -5,11 +5,12 @@ import pytest -from pyrit.registry.base import RegistryItemMetadata, _matches_filters +from pyrit.models.identifiers import Identifier +from pyrit.registry.base import _matches_filters @dataclass(frozen=True) -class MetadataWithTags(RegistryItemMetadata): +class MetadataWithTags(Identifier): """Test metadata with a tags field for list filtering tests.""" tags: tuple[str, ...] @@ -20,57 +21,63 @@ class TestMatchesFilters: def test_matches_filters_exact_match_string(self): """Test that exact string matches work.""" - metadata = RegistryItemMetadata( + metadata = Identifier( name="test_item", class_name="TestClass", - description="A test item", + class_module="test.module", + class_description="A test item", ) assert _matches_filters(metadata, include_filters={"name": "test_item"}) is True assert _matches_filters(metadata, include_filters={"class_name": "TestClass"}) is True def test_matches_filters_no_match_string(self): """Test that non-matching strings return False.""" - metadata = RegistryItemMetadata( + metadata = Identifier( name="test_item", class_name="TestClass", - description="A test item", + class_module="test.module", + class_description="A test item", ) assert _matches_filters(metadata, include_filters={"name": "other_item"}) is False assert _matches_filters(metadata, include_filters={"class_name": "OtherClass"}) is False def test_matches_filters_multiple_filters_all_match(self): """Test that all filters must match.""" - metadata = RegistryItemMetadata( + metadata = Identifier( name="test_item", class_name="TestClass", - description="A test item", + class_module="test.module", + class_description="A test item", ) assert _matches_filters(metadata, include_filters={"name": "test_item", "class_name": "TestClass"}) is True def test_matches_filters_multiple_filters_partial_match(self): """Test that partial matches return False when not all filters match.""" - metadata = RegistryItemMetadata( + metadata = Identifier( name="test_item", class_name="TestClass", - description="A test item", + class_module="test.module", + class_description="A test item", ) assert _matches_filters(metadata, include_filters={"name": "test_item", "class_name": "OtherClass"}) is False def test_matches_filters_key_not_in_metadata(self): """Test that filtering on a non-existent key returns False.""" - metadata = RegistryItemMetadata( + metadata = Identifier( name="test_item", class_name="TestClass", - description="A test item", + class_module="test.module", + class_description="A test item", ) assert _matches_filters(metadata, include_filters={"nonexistent_key": "value"}) is False def test_matches_filters_empty_filters(self): """Test that empty filters return True.""" - metadata = RegistryItemMetadata( + metadata = Identifier( name="test_item", class_name="TestClass", - description="A test item", + class_module="test.module", + class_description="A test item", ) assert _matches_filters(metadata) is True @@ -79,7 +86,8 @@ def test_matches_filters_list_value_contains_filter(self): metadata = MetadataWithTags( name="test_item", class_name="TestClass", - description="A test item", + class_module="test.module", + class_description="A test item", tags=("tag1", "tag2", "tag3"), ) assert _matches_filters(metadata, include_filters={"tags": "tag1"}) is True @@ -90,17 +98,19 @@ def test_matches_filters_list_value_not_contains_filter(self): metadata = MetadataWithTags( name="test_item", class_name="TestClass", - description="A test item", + class_module="test.module", + class_description="A test item", tags=("tag1", "tag2", "tag3"), ) assert _matches_filters(metadata, include_filters={"tags": "missing_tag"}) is False def test_matches_filters_exclude_exact_match(self): """Test that exclude filters work for exact matches.""" - metadata = RegistryItemMetadata( + metadata = Identifier( name="test_item", class_name="TestClass", - description="A test item", + class_module="test.module", + class_description="A test item", ) assert _matches_filters(metadata, exclude_filters={"name": "test_item"}) is False assert _matches_filters(metadata, exclude_filters={"name": "other_item"}) is True @@ -110,7 +120,8 @@ def test_matches_filters_exclude_list_value(self): metadata = MetadataWithTags( name="test_item", class_name="TestClass", - description="A test item", + class_module="test.module", + class_description="A test item", tags=("tag1", "tag2", "tag3"), ) assert _matches_filters(metadata, exclude_filters={"tags": "tag1"}) is False @@ -118,20 +129,22 @@ def test_matches_filters_exclude_list_value(self): def test_matches_filters_exclude_nonexistent_key(self): """Test that exclude filters for non-existent keys don't exclude the item.""" - metadata = RegistryItemMetadata( + metadata = Identifier( name="test_item", class_name="TestClass", - description="A test item", + class_module="test.module", + class_description="A test item", ) # Non-existent key in exclude filter should not exclude the item assert _matches_filters(metadata, exclude_filters={"nonexistent_key": "value"}) is True def test_matches_filters_combined_include_and_exclude(self): """Test combined include and exclude filters.""" - metadata = RegistryItemMetadata( + metadata = Identifier( name="test_item", class_name="TestClass", - description="A test item", + class_module="test.module", + class_description="A test item", ) # Include matches, exclude doesn't -> should pass assert ( @@ -156,27 +169,96 @@ def test_matches_filters_combined_include_and_exclude(self): ) -class TestRegistryItemMetadata: - """Tests for the RegistryItemMetadata dataclass.""" +class TestIdentifier: + """Tests for the Identifier dataclass and hash computation.""" - def test_registry_item_metadata_creation(self): - """Test creating a RegistryItemMetadata instance.""" - metadata = RegistryItemMetadata( + def test_identifier_creation(self): + """Test creating an Identifier instance.""" + metadata = Identifier( name="test_scorer", class_name="TestScorer", - description="A test scorer for testing", + class_module="pyrit.test.scorer", + class_description="A test scorer for testing", ) assert metadata.name == "test_scorer" assert metadata.class_name == "TestScorer" - assert metadata.description == "A test scorer for testing" + assert metadata.class_module == "pyrit.test.scorer" + assert metadata.class_description == "A test scorer for testing" - def test_registry_item_metadata_is_frozen(self): - """Test that RegistryItemMetadata is immutable.""" - metadata = RegistryItemMetadata( + def test_identifier_is_frozen(self): + """Test that Identifier is immutable.""" + metadata = Identifier( name="test_item", class_name="TestClass", - description="Description here", + class_module="test.module", + class_description="Description here", ) with pytest.raises(AttributeError): metadata.name = "new_name" # type: ignore[misc] + + def test_identifier_hash_computed_at_creation(self): + """Test that hash is computed when the Identifier is created.""" + identifier = Identifier( + name="test_item", + class_name="TestClass", + class_module="test.module", + class_description="A test description", + ) + assert identifier.hash is not None + assert len(identifier.hash) == 64 # SHA256 hex length + + def test_identifier_hash_is_deterministic(self): + """Test that the same inputs produce the same hash.""" + identifier1 = Identifier( + name="test_item", + class_name="TestClass", + class_module="test.module", + class_description="A test description", + ) + identifier2 = Identifier( + name="test_item", + class_name="TestClass", + class_module="test.module", + class_description="A test description", + ) + assert identifier1.hash == identifier2.hash + + def test_identifier_hash_differs_for_different_inputs(self): + """Test that different inputs produce different hashes.""" + identifier1 = Identifier( + name="test_item", + class_name="TestClass", + class_module="test.module", + class_description="A test description", + ) + identifier2 = Identifier( + name="different_item", + class_name="TestClass", + class_module="test.module", + class_description="A test description", + ) + assert identifier1.hash != identifier2.hash + + def test_identifier_hash_is_immutable(self): + """Test that the hash cannot be modified.""" + identifier = Identifier( + name="test_item", + class_name="TestClass", + class_module="test.module", + class_description="A test description", + ) + with pytest.raises(AttributeError): + identifier.hash = "new_hash" # type: ignore[misc] + + def test_identifier_subclass_inherits_hash(self): + """Test that subclasses of Identifier also get a computed hash.""" + metadata = MetadataWithTags( + name="test_item", + class_name="TestClass", + class_module="test.module", + class_description="A test description", + tags=("tag1", "tag2"), + ) + assert metadata.hash is not None + assert len(metadata.hash) == 64 diff --git a/tests/unit/registry/test_base_instance_registry.py b/tests/unit/registry/test_base_instance_registry.py index 886de0028..4760c1f05 100644 --- a/tests/unit/registry/test_base_instance_registry.py +++ b/tests/unit/registry/test_base_instance_registry.py @@ -3,12 +3,12 @@ from dataclasses import dataclass -from pyrit.registry.base import RegistryItemMetadata +from pyrit.models.identifiers import Identifier from pyrit.registry.instance_registries.base_instance_registry import BaseInstanceRegistry @dataclass(frozen=True) -class SampleItemMetadata(RegistryItemMetadata): +class SampleItemMetadata(Identifier): """Sample metadata with an extra field.""" category: str @@ -22,7 +22,8 @@ def _build_metadata(self, name: str, instance: str) -> SampleItemMetadata: return SampleItemMetadata( name=name, class_name="str", - description=f"Description for {instance}", + class_module="builtins", + class_description=f"Description for {instance}", category="test" if "test" in instance.lower() else "other", ) diff --git a/tests/unit/registry/test_scorer_registry.py b/tests/unit/registry/test_scorer_registry.py index 638b52684..edd7e103f 100644 --- a/tests/unit/registry/test_scorer_registry.py +++ b/tests/unit/registry/test_scorer_registry.py @@ -246,13 +246,13 @@ def test_build_metadata_includes_scorer_identifier(self): assert isinstance(metadata[0].scorer_identifier, ScorerIdentifier) def test_build_metadata_description_from_docstring(self): - """Test that description is derived from the scorer's docstring.""" + """Test that class_description is derived from the scorer's docstring.""" scorer = MockTrueFalseScorer() self.registry.register_instance(scorer, name="tf_scorer") metadata = self.registry.list_metadata() # MockTrueFalseScorer has a docstring - assert "Mock TrueFalseScorer for testing" in metadata[0].description + assert "Mock TrueFalseScorer for testing" in metadata[0].class_description class TestScorerRegistryListMetadataFiltering: @@ -362,13 +362,15 @@ def test_scorer_metadata_has_required_fields(self): metadata = ScorerMetadata( name="test_scorer", class_name="TestScorer", - description="A test scorer", + class_module="test.module", + class_description="A test scorer", scorer_type="true_false", scorer_identifier=mock_identifier, ) assert metadata.name == "test_scorer" assert metadata.class_name == "TestScorer" - assert metadata.description == "A test scorer" + assert metadata.class_module == "test.module" + assert metadata.class_description == "A test scorer" assert metadata.scorer_type == "true_false" assert metadata.scorer_identifier == mock_identifier From 7b9118e272b567a422c834f87d5f556c3432aabf Mon Sep 17 00:00:00 2001 From: Richard Lundeen Date: Fri, 23 Jan 2026 10:58:04 -0800 Subject: [PATCH 2/5] pre-commit --- pyrit/models/identifiers.py | 6 ++---- pyrit/registry/base.py | 4 ---- 2 files changed, 2 insertions(+), 8 deletions(-) diff --git a/pyrit/models/identifiers.py b/pyrit/models/identifiers.py index 10792d99f..23d980f5b 100644 --- a/pyrit/models/identifiers.py +++ b/pyrit/models/identifiers.py @@ -60,9 +60,7 @@ def _compute_hash(self) -> str: Returns: A hex string of the SHA256 hash. """ - hashable_dict: dict[str, Any] = { - f.name: getattr(self, f.name) for f in fields(self) if f.name != "hash" - } + hashable_dict: dict[str, Any] = {f.name: getattr(self, f.name) for f in fields(self) if f.name != "hash"} config_json = json.dumps(hashable_dict, sort_keys=True, separators=(",", ":"), default=_dataclass_encoder) return hashlib.sha256(config_json.encode("utf-8")).hexdigest() @@ -71,4 +69,4 @@ def _dataclass_encoder(obj: Any) -> Any: """JSON encoder that handles dataclasses by converting them to dicts.""" if is_dataclass(obj) and not isinstance(obj, type): return asdict(obj) - raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable") \ No newline at end of file + raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable") diff --git a/pyrit/registry/base.py b/pyrit/registry/base.py index 89b29248f..5f5e37400 100644 --- a/pyrit/registry/base.py +++ b/pyrit/registry/base.py @@ -10,8 +10,6 @@ from typing import Any, Dict, Iterator, List, Optional, Protocol, TypeVar, runtime_checkable -from pyrit.models.identifiers import Identifier - # Type variable for metadata (invariant for Protocol compatibility) MetadataT = TypeVar("MetadataT") @@ -78,8 +76,6 @@ def __iter__(self) -> Iterator[str]: ... - - def _matches_filters( metadata: Any, *, From 650da3e42dfbdc70f9a25ac867a25a394589209a Mon Sep 17 00:00:00 2001 From: Richard Lundeen Date: Fri, 23 Jan 2026 14:00:13 -0800 Subject: [PATCH 3/5] pr feedback --- doc/api.rst | 1 + pyrit/models/__init__.py | 3 ++- pyrit/models/identifiers.py | 5 ++++- .../class_registries/base_class_registry.py | 1 + .../class_registries/initializer_registry.py | 2 ++ .../class_registries/scenario_registry.py | 1 + .../instance_registries/scorer_registry.py | 1 + tests/unit/cli/test_frontend_core.py | 8 +++++++ tests/unit/registry/test_base.py | 22 +++++++++++++++++++ .../registry/test_base_instance_registry.py | 1 + tests/unit/registry/test_scorer_registry.py | 2 ++ 11 files changed, 45 insertions(+), 2 deletions(-) diff --git a/doc/api.rst b/doc/api.rst index ca966c587..23919b52e 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -334,6 +334,7 @@ API Reference HarmDefinition Identifiable Identifier + IdentifierType ImagePathDataTypeSerializer AllowedCategories AttackOutcome diff --git a/pyrit/models/__init__.py b/pyrit/models/__init__.py index 2d4d0a80e..2a43439bf 100644 --- a/pyrit/models/__init__.py +++ b/pyrit/models/__init__.py @@ -22,7 +22,7 @@ ) from pyrit.models.embeddings import EmbeddingData, EmbeddingResponse, EmbeddingSupport, EmbeddingUsageInformation from pyrit.models.harm_definition import HarmDefinition, ScaleDescription, get_all_harm_definitions -from pyrit.models.identifiers import Identifiable, Identifier +from pyrit.models.identifiers import Identifiable, Identifier, IdentifierType from pyrit.models.literals import ChatMessageRole, PromptDataType, PromptResponseError, SeedType from pyrit.models.message import ( Message, @@ -84,6 +84,7 @@ "HarmDefinition", "Identifiable", "Identifier", + "IdentifierType", "ImagePathDataTypeSerializer", "Message", "MessagePiece", diff --git a/pyrit/models/identifiers.py b/pyrit/models/identifiers.py index 23d980f5b..8bc8b141b 100644 --- a/pyrit/models/identifiers.py +++ b/pyrit/models/identifiers.py @@ -5,7 +5,9 @@ import json from abc import abstractmethod from dataclasses import asdict, dataclass, field, fields, is_dataclass -from typing import Any +from typing import Any, Literal + +IdentifierType = Literal["class", "instance"] class Identifiable: @@ -44,6 +46,7 @@ class Identifier: class_name: str # The actual class name, equivalent to __type__ (e.g., "SelfAskRefusalScorer") class_module: str # The module path, equivalent to __module__ (e.g., "pyrit.score.self_ask_refusal_scorer") class_description: str # Description from docstring or manual override + identifier_type: IdentifierType # Whether this identifies a "class" or "instance" hash: str = field(init=False, compare=False) def __post_init__(self) -> None: diff --git a/pyrit/registry/class_registries/base_class_registry.py b/pyrit/registry/class_registries/base_class_registry.py index 549581584..5f2f5bc10 100644 --- a/pyrit/registry/class_registries/base_class_registry.py +++ b/pyrit/registry/class_registries/base_class_registry.py @@ -207,6 +207,7 @@ def _build_base_metadata(self, name: str, entry: ClassEntry[T]) -> Identifier: description = entry.description or "No description available" return Identifier( + identifier_type="class", name=name, class_name=registered_class.__name__, class_module=registered_class.__module__, diff --git a/pyrit/registry/class_registries/initializer_registry.py b/pyrit/registry/class_registries/initializer_registry.py index 6c8cc51e4..96dedad2f 100644 --- a/pyrit/registry/class_registries/initializer_registry.py +++ b/pyrit/registry/class_registries/initializer_registry.py @@ -208,6 +208,7 @@ def _build_metadata(self, name: str, entry: ClassEntry["PyRITInitializer"]) -> I try: instance = initializer_class() return InitializerMetadata( + identifier_type="class", name=name, class_name=initializer_class.__name__, class_module=initializer_class.__module__, @@ -219,6 +220,7 @@ def _build_metadata(self, name: str, entry: ClassEntry["PyRITInitializer"]) -> I except Exception as e: logger.warning(f"Failed to get metadata for {name}: {e}") return InitializerMetadata( + identifier_type="class", name=name, class_name=initializer_class.__name__, class_module=initializer_class.__module__, diff --git a/pyrit/registry/class_registries/scenario_registry.py b/pyrit/registry/class_registries/scenario_registry.py index 9f68f8fe6..01215662f 100644 --- a/pyrit/registry/class_registries/scenario_registry.py +++ b/pyrit/registry/class_registries/scenario_registry.py @@ -170,6 +170,7 @@ def _build_metadata(self, name: str, entry: ClassEntry["Scenario"]) -> ScenarioM max_dataset_size = dataset_config.max_dataset_size return ScenarioMetadata( + identifier_type="class", name=name, class_name=scenario_class.__name__, class_module=scenario_class.__module__, diff --git a/pyrit/registry/instance_registries/scorer_registry.py b/pyrit/registry/instance_registries/scorer_registry.py index 6c574f72b..72798cf15 100644 --- a/pyrit/registry/instance_registries/scorer_registry.py +++ b/pyrit/registry/instance_registries/scorer_registry.py @@ -130,6 +130,7 @@ def _build_metadata(self, name: str, instance: "Scorer") -> ScorerMetadata: scorer_type = "unknown" return ScorerMetadata( + identifier_type="instance", name=name, class_name=instance.__class__.__name__, class_module=instance.__class__.__module__, diff --git a/tests/unit/cli/test_frontend_core.py b/tests/unit/cli/test_frontend_core.py index cb02a1b7f..764a036aa 100644 --- a/tests/unit/cli/test_frontend_core.py +++ b/tests/unit/cli/test_frontend_core.py @@ -328,6 +328,7 @@ async def test_print_scenarios_list_with_scenarios(self, capsys): mock_registry = MagicMock() mock_registry.list_metadata.return_value = [ ScenarioMetadata( + identifier_type="class", name="test_scenario", class_name="TestScenario", class_module="test.scenarios", @@ -369,6 +370,7 @@ async def test_print_initializers_list_with_initializers(self, capsys): mock_registry = MagicMock() mock_registry.list_metadata.return_value = [ InitializerMetadata( + identifier_type="class", name="test_init", class_name="TestInit", class_module="test.initializers", @@ -410,6 +412,7 @@ def test_format_scenario_metadata_basic(self, capsys): """Test format_scenario_metadata with basic metadata.""" scenario_metadata = ScenarioMetadata( + identifier_type="class", name="test_scenario", class_name="TestScenario", class_module="test.scenarios", @@ -431,6 +434,7 @@ def test_format_scenario_metadata_with_description(self, capsys): """Test format_scenario_metadata with description.""" scenario_metadata = ScenarioMetadata( + identifier_type="class", name="test_scenario", class_name="TestScenario", class_module="test.scenarios", @@ -450,6 +454,7 @@ def test_format_scenario_metadata_with_description(self, capsys): def test_format_scenario_metadata_with_strategies(self, capsys): """Test format_scenario_metadata with strategies.""" scenario_metadata = ScenarioMetadata( + identifier_type="class", name="test_scenario", class_name="TestScenario", class_module="test.scenarios", @@ -471,6 +476,7 @@ def test_format_scenario_metadata_with_strategies(self, capsys): def test_format_initializer_metadata_basic(self, capsys) -> None: """Test format_initializer_metadata with basic metadata.""" initializer_metadata = InitializerMetadata( + identifier_type="class", name="test_init", class_name="TestInit", class_module="test.initializers", @@ -490,6 +496,7 @@ def test_format_initializer_metadata_basic(self, capsys) -> None: def test_format_initializer_metadata_with_env_vars(self, capsys) -> None: """Test format_initializer_metadata with environment variables.""" initializer_metadata = InitializerMetadata( + identifier_type="class", name="test_init", class_name="TestInit", class_module="test.initializers", @@ -508,6 +515,7 @@ def test_format_initializer_metadata_with_env_vars(self, capsys) -> None: def test_format_initializer_metadata_with_description(self, capsys) -> None: """Test format_initializer_metadata with description.""" initializer_metadata = InitializerMetadata( + identifier_type="class", name="test_init", class_name="TestInit", class_module="test.initializers", diff --git a/tests/unit/registry/test_base.py b/tests/unit/registry/test_base.py index 2c1f739d3..feb8e326a 100644 --- a/tests/unit/registry/test_base.py +++ b/tests/unit/registry/test_base.py @@ -22,6 +22,7 @@ class TestMatchesFilters: def test_matches_filters_exact_match_string(self): """Test that exact string matches work.""" metadata = Identifier( + identifier_type="class", name="test_item", class_name="TestClass", class_module="test.module", @@ -33,6 +34,7 @@ def test_matches_filters_exact_match_string(self): def test_matches_filters_no_match_string(self): """Test that non-matching strings return False.""" metadata = Identifier( + identifier_type="class", name="test_item", class_name="TestClass", class_module="test.module", @@ -44,6 +46,7 @@ def test_matches_filters_no_match_string(self): def test_matches_filters_multiple_filters_all_match(self): """Test that all filters must match.""" metadata = Identifier( + identifier_type="class", name="test_item", class_name="TestClass", class_module="test.module", @@ -54,6 +57,7 @@ def test_matches_filters_multiple_filters_all_match(self): def test_matches_filters_multiple_filters_partial_match(self): """Test that partial matches return False when not all filters match.""" metadata = Identifier( + identifier_type="class", name="test_item", class_name="TestClass", class_module="test.module", @@ -64,6 +68,7 @@ def test_matches_filters_multiple_filters_partial_match(self): def test_matches_filters_key_not_in_metadata(self): """Test that filtering on a non-existent key returns False.""" metadata = Identifier( + identifier_type="class", name="test_item", class_name="TestClass", class_module="test.module", @@ -74,6 +79,7 @@ def test_matches_filters_key_not_in_metadata(self): def test_matches_filters_empty_filters(self): """Test that empty filters return True.""" metadata = Identifier( + identifier_type="class", name="test_item", class_name="TestClass", class_module="test.module", @@ -84,6 +90,7 @@ def test_matches_filters_empty_filters(self): def test_matches_filters_list_value_contains_filter(self): """Test filtering when metadata value is a list and filter value is in the list.""" metadata = MetadataWithTags( + identifier_type="class", name="test_item", class_name="TestClass", class_module="test.module", @@ -96,6 +103,7 @@ def test_matches_filters_list_value_contains_filter(self): def test_matches_filters_list_value_not_contains_filter(self): """Test filtering when metadata value is a list and filter value is not in the list.""" metadata = MetadataWithTags( + identifier_type="class", name="test_item", class_name="TestClass", class_module="test.module", @@ -107,6 +115,7 @@ def test_matches_filters_list_value_not_contains_filter(self): def test_matches_filters_exclude_exact_match(self): """Test that exclude filters work for exact matches.""" metadata = Identifier( + identifier_type="class", name="test_item", class_name="TestClass", class_module="test.module", @@ -118,6 +127,7 @@ def test_matches_filters_exclude_exact_match(self): def test_matches_filters_exclude_list_value(self): """Test exclude filters work for list values.""" metadata = MetadataWithTags( + identifier_type="class", name="test_item", class_name="TestClass", class_module="test.module", @@ -130,6 +140,7 @@ def test_matches_filters_exclude_list_value(self): def test_matches_filters_exclude_nonexistent_key(self): """Test that exclude filters for non-existent keys don't exclude the item.""" metadata = Identifier( + identifier_type="class", name="test_item", class_name="TestClass", class_module="test.module", @@ -141,6 +152,7 @@ def test_matches_filters_exclude_nonexistent_key(self): def test_matches_filters_combined_include_and_exclude(self): """Test combined include and exclude filters.""" metadata = Identifier( + identifier_type="class", name="test_item", class_name="TestClass", class_module="test.module", @@ -175,11 +187,13 @@ class TestIdentifier: def test_identifier_creation(self): """Test creating an Identifier instance.""" metadata = Identifier( + identifier_type="class", name="test_scorer", class_name="TestScorer", class_module="pyrit.test.scorer", class_description="A test scorer for testing", ) + assert metadata.identifier_type == "class" assert metadata.name == "test_scorer" assert metadata.class_name == "TestScorer" assert metadata.class_module == "pyrit.test.scorer" @@ -188,6 +202,7 @@ def test_identifier_creation(self): def test_identifier_is_frozen(self): """Test that Identifier is immutable.""" metadata = Identifier( + identifier_type="class", name="test_item", class_name="TestClass", class_module="test.module", @@ -200,6 +215,7 @@ def test_identifier_is_frozen(self): def test_identifier_hash_computed_at_creation(self): """Test that hash is computed when the Identifier is created.""" identifier = Identifier( + identifier_type="instance", name="test_item", class_name="TestClass", class_module="test.module", @@ -211,12 +227,14 @@ def test_identifier_hash_computed_at_creation(self): def test_identifier_hash_is_deterministic(self): """Test that the same inputs produce the same hash.""" identifier1 = Identifier( + identifier_type="class", name="test_item", class_name="TestClass", class_module="test.module", class_description="A test description", ) identifier2 = Identifier( + identifier_type="class", name="test_item", class_name="TestClass", class_module="test.module", @@ -227,12 +245,14 @@ def test_identifier_hash_is_deterministic(self): def test_identifier_hash_differs_for_different_inputs(self): """Test that different inputs produce different hashes.""" identifier1 = Identifier( + identifier_type="class", name="test_item", class_name="TestClass", class_module="test.module", class_description="A test description", ) identifier2 = Identifier( + identifier_type="class", name="different_item", class_name="TestClass", class_module="test.module", @@ -243,6 +263,7 @@ def test_identifier_hash_differs_for_different_inputs(self): def test_identifier_hash_is_immutable(self): """Test that the hash cannot be modified.""" identifier = Identifier( + identifier_type="class", name="test_item", class_name="TestClass", class_module="test.module", @@ -254,6 +275,7 @@ def test_identifier_hash_is_immutable(self): def test_identifier_subclass_inherits_hash(self): """Test that subclasses of Identifier also get a computed hash.""" metadata = MetadataWithTags( + identifier_type="class", name="test_item", class_name="TestClass", class_module="test.module", diff --git a/tests/unit/registry/test_base_instance_registry.py b/tests/unit/registry/test_base_instance_registry.py index 4760c1f05..075c6756a 100644 --- a/tests/unit/registry/test_base_instance_registry.py +++ b/tests/unit/registry/test_base_instance_registry.py @@ -20,6 +20,7 @@ class ConcreteTestRegistry(BaseInstanceRegistry[str, SampleItemMetadata]): def _build_metadata(self, name: str, instance: str) -> SampleItemMetadata: """Build test metadata from a string instance.""" return SampleItemMetadata( + identifier_type="instance", name=name, class_name="str", class_module="builtins", diff --git a/tests/unit/registry/test_scorer_registry.py b/tests/unit/registry/test_scorer_registry.py index edd7e103f..31feabc2a 100644 --- a/tests/unit/registry/test_scorer_registry.py +++ b/tests/unit/registry/test_scorer_registry.py @@ -360,6 +360,7 @@ def test_scorer_metadata_has_required_fields(self): mock_identifier = ScorerIdentifier(type="test_type") metadata = ScorerMetadata( + identifier_type="instance", name="test_scorer", class_name="TestScorer", class_module="test.module", @@ -368,6 +369,7 @@ def test_scorer_metadata_has_required_fields(self): scorer_identifier=mock_identifier, ) + assert metadata.identifier_type == "instance" assert metadata.name == "test_scorer" assert metadata.class_name == "TestScorer" assert metadata.class_module == "test.module" From 85b0c978cc2e98fbaff32bd2b98daef5ec364c2b Mon Sep 17 00:00:00 2001 From: Richard Lundeen Date: Fri, 23 Jan 2026 14:59:00 -0800 Subject: [PATCH 4/5] pr feedback --- pyrit/models/identifiers.py | 24 +- tests/unit/models/test_identifiers.py | 331 ++++++++++++++++++++++++++ 2 files changed, 350 insertions(+), 5 deletions(-) create mode 100644 tests/unit/models/test_identifiers.py diff --git a/pyrit/models/identifiers.py b/pyrit/models/identifiers.py index 8bc8b141b..f4a758b76 100644 --- a/pyrit/models/identifiers.py +++ b/pyrit/models/identifiers.py @@ -45,8 +45,11 @@ class Identifier: name: str # The snake_case identifier name (e.g., "self_ask_refusal") class_name: str # The actual class name, equivalent to __type__ (e.g., "SelfAskRefusalScorer") class_module: str # The module path, equivalent to __module__ (e.g., "pyrit.score.self_ask_refusal_scorer") - class_description: str # Description from docstring or manual override - identifier_type: IdentifierType # Whether this identifies a "class" or "instance" + + class_description: str = field(metadata={"exclude_from_storage": True}) + + # Whether this identifies a "class" or "instance" + identifier_type: IdentifierType = field(metadata={"exclude_from_storage": True}) hash: str = field(init=False, compare=False) def __post_init__(self) -> None: @@ -56,17 +59,28 @@ def __post_init__(self) -> None: def _compute_hash(self) -> str: """ - Compute a stable SHA256 hash from all identifier fields. + Compute a stable SHA256 hash from storable identifier fields. - All fields except 'hash' itself are automatically included. + Fields marked with metadata={"exclude_from_storage": True} and 'hash' itself + are excluded from the hash computation. Returns: A hex string of the SHA256 hash. """ - hashable_dict: dict[str, Any] = {f.name: getattr(self, f.name) for f in fields(self) if f.name != "hash"} + hashable_dict: dict[str, Any] = { + f.name: getattr(self, f.name) + for f in fields(self) + if f.name != "hash" and not f.metadata.get("exclude_from_storage", False) + } config_json = json.dumps(hashable_dict, sort_keys=True, separators=(",", ":"), default=_dataclass_encoder) return hashlib.sha256(config_json.encode("utf-8")).hexdigest() + def to_storage_dict(self) -> dict[str, Any]: + """Return only fields suitable for DB storage.""" + return { + f.name: getattr(self, f.name) for f in fields(self) if not f.metadata.get("exclude_from_storage", False) + } + def _dataclass_encoder(obj: Any) -> Any: """JSON encoder that handles dataclasses by converting them to dicts.""" diff --git a/tests/unit/models/test_identifiers.py b/tests/unit/models/test_identifiers.py new file mode 100644 index 000000000..ec5cc52ae --- /dev/null +++ b/tests/unit/models/test_identifiers.py @@ -0,0 +1,331 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +from dataclasses import dataclass, field + +import pytest + +from pyrit.models.identifiers import Identifiable, Identifier, IdentifierType + + +class TestIdentifiable: + """Tests for the Identifiable abstract base class.""" + + def test_identifiable_get_identifier_is_abstract(self): + """Test that get_identifier is an abstract method that must be implemented.""" + + class ConcreteIdentifiable(Identifiable): + def get_identifier(self) -> dict[str, str]: + return {"type": "test", "name": "example"} + + obj = ConcreteIdentifiable() + result = obj.get_identifier() + assert result == {"type": "test", "name": "example"} + + def test_identifiable_str_returns_identifier(self): + """Test that __str__ returns the get_identifier method reference.""" + + class ConcreteIdentifiable(Identifiable): + def get_identifier(self) -> dict[str, str]: + return {"type": "test"} + + obj = ConcreteIdentifiable() + # __str__ returns the method reference string + assert "get_identifier" in str(obj) + + +class TestIdentifier: + """Tests for the Identifier dataclass.""" + + def test_identifier_creation(self): + """Test creating an Identifier instance with all required fields.""" + identifier = Identifier( + identifier_type="class", + name="test_scorer", + class_name="TestScorer", + class_module="pyrit.test.scorer", + class_description="A test scorer for testing", + ) + assert identifier.identifier_type == "class" + assert identifier.name == "test_scorer" + assert identifier.class_name == "TestScorer" + assert identifier.class_module == "pyrit.test.scorer" + assert identifier.class_description == "A test scorer for testing" + + def test_identifier_is_frozen(self): + """Test that Identifier is immutable.""" + identifier = Identifier( + identifier_type="instance", + name="test_item", + class_name="TestClass", + class_module="test.module", + class_description="Description here", + ) + + with pytest.raises(AttributeError): + identifier.name = "new_name" # type: ignore[misc] + + def test_identifier_type_literal_class(self): + """Test identifier_type with 'class' value.""" + identifier = Identifier( + identifier_type="class", + name="test", + class_name="Test", + class_module="test", + class_description="", + ) + assert identifier.identifier_type == "class" + + def test_identifier_type_literal_instance(self): + """Test identifier_type with 'instance' value.""" + identifier = Identifier( + identifier_type="instance", + name="test", + class_name="Test", + class_module="test", + class_description="", + ) + assert identifier.identifier_type == "instance" + + +class TestIdentifierHash: + """Tests for Identifier hash computation.""" + + def test_hash_computed_at_creation(self): + """Test that hash is computed when the Identifier is created.""" + identifier = Identifier( + identifier_type="class", + name="test_item", + class_name="TestClass", + class_module="test.module", + class_description="A test description", + ) + assert identifier.hash is not None + assert len(identifier.hash) == 64 # SHA256 hex length + + def test_hash_is_deterministic(self): + """Test that the same inputs produce the same hash.""" + identifier1 = Identifier( + identifier_type="class", + name="test_item", + class_name="TestClass", + class_module="test.module", + class_description="A test description", + ) + identifier2 = Identifier( + identifier_type="class", + name="test_item", + class_name="TestClass", + class_module="test.module", + class_description="A test description", + ) + assert identifier1.hash == identifier2.hash + + def test_hash_differs_for_different_storable_fields(self): + """Test that different storable field values produce different hashes.""" + identifier1 = Identifier( + identifier_type="class", + name="test_item", + class_name="TestClass", + class_module="test.module", + class_description="Description", + ) + identifier2 = Identifier( + identifier_type="class", + name="different_item", + class_name="TestClass", + class_module="test.module", + class_description="Description", + ) + assert identifier1.hash != identifier2.hash + + def test_hash_excludes_class_description(self): + """Test that class_description is excluded from hash computation.""" + identifier1 = Identifier( + identifier_type="class", + name="test_item", + class_name="TestClass", + class_module="test.module", + class_description="First description", + ) + identifier2 = Identifier( + identifier_type="class", + name="test_item", + class_name="TestClass", + class_module="test.module", + class_description="Completely different description", + ) + # Hash should be the same since class_description is excluded + assert identifier1.hash == identifier2.hash + + def test_hash_excludes_identifier_type(self): + """Test that identifier_type is excluded from hash computation.""" + identifier1 = Identifier( + identifier_type="class", + name="test_item", + class_name="TestClass", + class_module="test.module", + class_description="Description", + ) + identifier2 = Identifier( + identifier_type="instance", + name="test_item", + class_name="TestClass", + class_module="test.module", + class_description="Description", + ) + # Hash should be the same since identifier_type is excluded + assert identifier1.hash == identifier2.hash + + def test_hash_is_immutable(self): + """Test that the hash cannot be modified.""" + identifier = Identifier( + identifier_type="class", + name="test_item", + class_name="TestClass", + class_module="test.module", + class_description="A test description", + ) + with pytest.raises(AttributeError): + identifier.hash = "new_hash" # type: ignore[misc] + + +class TestIdentifierStorage: + """Tests for Identifier storage functionality.""" + + def test_to_storage_dict_excludes_marked_fields(self): + """Test that to_storage_dict excludes fields marked with exclude_from_storage.""" + identifier = Identifier( + identifier_type="class", + name="test_item", + class_name="TestClass", + class_module="test.module", + class_description="A test description", + ) + storage_dict = identifier.to_storage_dict() + + # Should include storable fields + assert "name" in storage_dict + assert "class_name" in storage_dict + assert "class_module" in storage_dict + assert "hash" in storage_dict + + # Should exclude non-storable fields + assert "class_description" not in storage_dict + assert "identifier_type" not in storage_dict + + def test_to_storage_dict_values_match(self): + """Test that to_storage_dict values match the original identifier.""" + identifier = Identifier( + identifier_type="instance", + name="my_scorer", + class_name="MyScorer", + class_module="pyrit.score.my_scorer", + class_description="My custom scorer", + ) + storage_dict = identifier.to_storage_dict() + + assert storage_dict["name"] == "my_scorer" + assert storage_dict["class_name"] == "MyScorer" + assert storage_dict["class_module"] == "pyrit.score.my_scorer" + assert storage_dict["hash"] == identifier.hash + + +class TestIdentifierSubclass: + """Tests for Identifier subclassing behavior.""" + + def test_subclass_inherits_hash_computation(self): + """Test that subclasses of Identifier also get a computed hash.""" + + @dataclass(frozen=True) + class ExtendedIdentifier(Identifier): + extra_field: str + + extended = ExtendedIdentifier( + identifier_type="class", + name="test_item", + class_name="TestClass", + class_module="test.module", + class_description="A test description", + extra_field="extra_value", + ) + assert extended.hash is not None + assert len(extended.hash) == 64 + + def test_subclass_extra_fields_included_in_hash(self): + """Test that subclass extra fields (not marked) are included in hash.""" + + @dataclass(frozen=True) + class ExtendedIdentifier(Identifier): + extra_field: str + + extended1 = ExtendedIdentifier( + identifier_type="class", + name="test_item", + class_name="TestClass", + class_module="test.module", + class_description="Description", + extra_field="value1", + ) + extended2 = ExtendedIdentifier( + identifier_type="class", + name="test_item", + class_name="TestClass", + class_module="test.module", + class_description="Description", + extra_field="value2", + ) + # Different extra_field values should produce different hashes + assert extended1.hash != extended2.hash + + def test_subclass_excluded_fields_not_in_hash(self): + """Test that subclass fields marked exclude_from_storage are excluded from hash.""" + + @dataclass(frozen=True) + class ExtendedIdentifier(Identifier): + display_only: str = field(metadata={"exclude_from_storage": True}) + + extended1 = ExtendedIdentifier( + identifier_type="class", + name="test_item", + class_name="TestClass", + class_module="test.module", + class_description="Description", + display_only="display1", + ) + extended2 = ExtendedIdentifier( + identifier_type="class", + name="test_item", + class_name="TestClass", + class_module="test.module", + class_description="Description", + display_only="display2", + ) + # display_only is excluded, so hashes should match + assert extended1.hash == extended2.hash + + def test_subclass_to_storage_dict_includes_extra_storable_fields(self): + """Test that to_storage_dict includes subclass storable fields.""" + + @dataclass(frozen=True) + class ExtendedIdentifier(Identifier): + extra_field: str + display_only: str = field(metadata={"exclude_from_storage": True}) + + extended = ExtendedIdentifier( + identifier_type="class", + name="test_item", + class_name="TestClass", + class_module="test.module", + class_description="Description", + extra_field="extra_value", + display_only="display_value", + ) + storage_dict = extended.to_storage_dict() + + # Extra storable field should be included + assert "extra_field" in storage_dict + assert storage_dict["extra_field"] == "extra_value" + + # Display-only field should be excluded + assert "display_only" not in storage_dict From c512b64d3750c4fa949fe1600d3eb56b4eea6ba4 Mon Sep 17 00:00:00 2001 From: Richard Lundeen Date: Fri, 23 Jan 2026 15:12:07 -0800 Subject: [PATCH 5/5] pre-commit --- tests/unit/models/test_identifiers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/models/test_identifiers.py b/tests/unit/models/test_identifiers.py index ec5cc52ae..845a65dce 100644 --- a/tests/unit/models/test_identifiers.py +++ b/tests/unit/models/test_identifiers.py @@ -5,7 +5,7 @@ import pytest -from pyrit.models.identifiers import Identifiable, Identifier, IdentifierType +from pyrit.models.identifiers import Identifiable, Identifier class TestIdentifiable: