From c2b172e5b63fe78f839c3990d29930f48dfa05e6 Mon Sep 17 00:00:00 2001
From: Maifee Ul Asad
Date: Sun, 25 Jan 2026 15:51:54 +0600
Subject: [PATCH 1/4] FEAT: implement batching for memory interface (#845)

- add `_SQLITE_MAX_BIND_VARS` constant to handle SQLite limits
- implement batching in `get_scores()` for the `score_ids` parameter
- implement batching in `get_message_pieces()` for the `prompt_ids`,
  `original_values`, `converted_values`, and `converted_value_sha256` parameters
- implement batching in `get_attack_results()` for the `attack_result_ids` and
  `objective_sha256` parameters
- implement batching in `get_scenario_results()` for `scenario_result_ids`
- refactor filter conditions as needed across the batched queries
- handle empty-list edge cases
---
 pyrit/memory/memory_interface.py | 155 ++++++++++++++++++++++++-------
 1 file changed, 120 insertions(+), 35 deletions(-)

diff --git a/pyrit/memory/memory_interface.py b/pyrit/memory/memory_interface.py
index 9fc682d6d..54ba450a0 100644
--- a/pyrit/memory/memory_interface.py
+++ b/pyrit/memory/memory_interface.py
@@ -50,6 +50,9 @@
 
 logger = logging.getLogger(__name__)
 
+# ref: https://www.sqlite.org/limits.html
+# The lowest default maximum is 999; intentionally set to half of that
+_SQLITE_MAX_BIND_VARS = 500
 
 Model = TypeVar("Model")
 
@@ -361,10 +364,9 @@ def get_scores(
         Returns:
             Sequence[Score]: A list of Score objects that match the specified filters.
         """
+        # Build base conditions without score_ids; those are handled below with batching
        conditions: list[Any] = []
 
-        if score_ids:
-            conditions.append(ScoreEntry.id.in_(score_ids))
         if score_type:
             conditions.append(ScoreEntry.score_type == score_type)
         if score_category:
@@ -374,6 +376,18 @@
         if sent_before:
             conditions.append(ScoreEntry.timestamp <= sent_before)
 
+        # Handle score_ids with batching to avoid SQLite bind variable limits
+        if score_ids:
+            all_entries: list[ScoreEntry] = []
+            for i in range(0, len(score_ids), _SQLITE_MAX_BIND_VARS):
+                batch = score_ids[i : i + _SQLITE_MAX_BIND_VARS]
+                batch_conditions = conditions + [ScoreEntry.id.in_(batch)]
+                batch_entries: Sequence[ScoreEntry] = self._query_entries(
+                    ScoreEntry, conditions=and_(*batch_conditions)
+                )
+                all_entries.extend(batch_entries)
+            return [entry.get_score() for entry in all_entries]
+
         if not conditions:
             return []
 
@@ -532,6 +546,7 @@ def get_message_pieces(
             Exception: If there is an error retrieving the prompts,
                 an exception is logged and an empty list is returned.
""" + # Build base conditions (without parameters that may need batching) conditions = [] if attack_id: conditions.append(self._get_message_pieces_attack_conditions(attack_id=str(attack_id))) @@ -539,9 +554,6 @@ def get_message_pieces( conditions.append(PromptMemoryEntry.role == role) if conversation_id: conditions.append(PromptMemoryEntry.conversation_id == str(conversation_id)) - if prompt_ids: - prompt_ids = [str(pi) for pi in prompt_ids] - conditions.append(PromptMemoryEntry.id.in_(prompt_ids)) if labels: conditions.extend(self._get_message_pieces_memory_label_conditions(memory_labels=labels)) if prompt_metadata: @@ -550,21 +562,59 @@ def get_message_pieces( conditions.append(PromptMemoryEntry.timestamp >= sent_after) if sent_before: conditions.append(PromptMemoryEntry.timestamp <= sent_before) - if original_values: - conditions.append(PromptMemoryEntry.original_value.in_(original_values)) - if converted_values: - conditions.append(PromptMemoryEntry.converted_value.in_(converted_values)) if data_type: conditions.append(PromptMemoryEntry.converted_value_data_type == data_type) if not_data_type: conditions.append(PromptMemoryEntry.converted_value_data_type != not_data_type) - if converted_value_sha256: + + # Identify which parameter needs batching (prioritize the one provided) + batch_param = None + batch_values = None + batch_column = None + + if prompt_ids: + batch_param = "prompt_ids" + batch_values = [str(pi) for pi in prompt_ids] + batch_column = PromptMemoryEntry.id + elif original_values and len(original_values) > _SQLITE_MAX_BIND_VARS: + batch_param = "original_values" + batch_values = list(original_values) + batch_column = PromptMemoryEntry.original_value + elif converted_values and len(converted_values) > _SQLITE_MAX_BIND_VARS: + batch_param = "converted_values" + batch_values = list(converted_values) + batch_column = PromptMemoryEntry.converted_value + elif converted_value_sha256 and len(converted_value_sha256) > _SQLITE_MAX_BIND_VARS: + batch_param = "converted_value_sha256" + batch_values = list(converted_value_sha256) + batch_column = PromptMemoryEntry.converted_value_sha256 + + # Add non-batched IN conditions + if original_values and batch_param != "original_values": + conditions.append(PromptMemoryEntry.original_value.in_(original_values)) + if converted_values and batch_param != "converted_values": + conditions.append(PromptMemoryEntry.converted_value.in_(converted_values)) + if converted_value_sha256 and batch_param != "converted_value_sha256": conditions.append(PromptMemoryEntry.converted_value_sha256.in_(converted_value_sha256)) try: - memory_entries: Sequence[PromptMemoryEntry] = self._query_entries( - PromptMemoryEntry, conditions=and_(*conditions) if conditions else None, join_scores=True - ) + if batch_values: + all_entries: MutableSequence[PromptMemoryEntry] = [] + for i in range(0, len(batch_values), _SQLITE_MAX_BIND_VARS): + batch = batch_values[i : i + _SQLITE_MAX_BIND_VARS] + batch_conditions = conditions + [batch_column.in_(batch)] + batch_entries: Sequence[PromptMemoryEntry] = self._query_entries( + PromptMemoryEntry, + conditions=and_(*batch_conditions) if batch_conditions else None, + join_scores=True, + ) + all_entries.extend(batch_entries) + memory_entries = all_entries + else: + memory_entries = self._query_entries( + PromptMemoryEntry, conditions=and_(*conditions) if conditions else None, join_scores=True + ) + message_pieces = [memory_entry.get_message_piece() for memory_entry in memory_entries] return 
sort_message_pieces(message_pieces=message_pieces) except Exception as e: @@ -1238,34 +1288,62 @@ def get_attack_results( Returns: Sequence[AttackResult]: A list of AttackResult objects that match the specified filters. """ + # Build base conditions (without parameters that may need batching) conditions: list[ColumnElement[bool]] = [] - if attack_result_ids is not None: - if len(attack_result_ids) == 0: - # Empty list means no results - return [] - conditions.append(AttackResultEntry.id.in_(attack_result_ids)) if conversation_id: conditions.append(AttackResultEntry.conversation_id == conversation_id) if objective: conditions.append(AttackResultEntry.objective.contains(objective)) - - if objective_sha256: - conditions.append(AttackResultEntry.objective_sha256.in_(objective_sha256)) if outcome: conditions.append(AttackResultEntry.outcome == outcome) if targeted_harm_categories: - # Use database-specific JSON query method conditions.append( self._get_attack_result_harm_category_condition(targeted_harm_categories=targeted_harm_categories) ) if labels: - # Use database-specific JSON query method conditions.append(self._get_attack_result_label_condition(labels=labels)) + # Handle empty lists + if attack_result_ids is not None and len(attack_result_ids) == 0: + return [] + if objective_sha256 is not None and len(objective_sha256) == 0: + return [] + + # Identify which parameter needs batching + batch_values = None + batch_column = None + batch_param_name = None + + if attack_result_ids and len(attack_result_ids) > _SQLITE_MAX_BIND_VARS: + batch_values = list(attack_result_ids) + batch_column = AttackResultEntry.id + batch_param_name = "attack_result_ids" + elif objective_sha256 and len(objective_sha256) > _SQLITE_MAX_BIND_VARS: + batch_values = list(objective_sha256) + batch_column = AttackResultEntry.objective_sha256 + batch_param_name = "objective_sha256" + + # Add non-batched IN conditions + if attack_result_ids and batch_param_name != "attack_result_ids": + conditions.append(AttackResultEntry.id.in_(attack_result_ids)) + if objective_sha256 and batch_param_name != "objective_sha256": + conditions.append(AttackResultEntry.objective_sha256.in_(objective_sha256)) + try: + if batch_values: + all_entries: list[AttackResultEntry] = [] + for i in range(0, len(batch_values), _SQLITE_MAX_BIND_VARS): + batch = batch_values[i : i + _SQLITE_MAX_BIND_VARS] + batch_conditions = list(conditions) + [batch_column.in_(batch)] + batch_entries: Sequence[AttackResultEntry] = self._query_entries( + AttackResultEntry, conditions=and_(*batch_conditions) if batch_conditions else None + ) + all_entries.extend(batch_entries) + return [entry.get_attack_result() for entry in all_entries] + entries: Sequence[AttackResultEntry] = self._query_entries( AttackResultEntry, conditions=and_(*conditions) if conditions else None ) @@ -1426,18 +1504,13 @@ def get_scenario_results( Returns: Sequence[ScenarioResult]: A list of ScenarioResult objects that match the specified filters. 
""" - conditions: list[ColumnElement[bool]] = [] + # Handle empty list + if scenario_result_ids is not None and len(scenario_result_ids) == 0: + return [] - if scenario_result_ids is not None: - if len(scenario_result_ids) == 0: - # Empty list means no results - return [] - conditions.append(ScenarioResultEntry.id.in_(scenario_result_ids)) + conditions: list[ColumnElement[bool]] = [] if scenario_name: - # Normalize CLI snake_case names (e.g., "foundry" or "content_harms") - # to class names (e.g., "Foundry" or "ContentHarms") - # This allows users to query with either format normalized_name = ScenarioResult.normalize_scenario_name(scenario_name) conditions.append(ScenarioResultEntry.scenario_name.contains(normalized_name)) @@ -1466,9 +1539,21 @@ def get_scenario_results( conditions.append(self._get_scenario_result_target_model_condition(model_name=objective_target_model_name)) try: - entries: Sequence[ScenarioResultEntry] = self._query_entries( - ScenarioResultEntry, conditions=and_(*conditions) if conditions else None - ) + # Handle scenario_result_ids with batching if needed + if scenario_result_ids and len(scenario_result_ids) > _SQLITE_MAX_BIND_VARS: + all_entries: MutableSequence[ScenarioResultEntry] = [] + for i in range(0, len(scenario_result_ids), _SQLITE_MAX_BIND_VARS): + batch = list(scenario_result_ids)[i : i + _SQLITE_MAX_BIND_VARS] + batch_conditions = list(conditions) + [ScenarioResultEntry.id.in_(batch)] + batch_entries: Sequence[ScenarioResultEntry] = self._query_entries( + ScenarioResultEntry, conditions=and_(*batch_conditions) if batch_conditions else None + ) + all_entries.extend(batch_entries) + entries = all_entries + else: + if scenario_result_ids: + conditions.append(ScenarioResultEntry.id.in_(scenario_result_ids)) + entries = self._query_entries(ScenarioResultEntry, conditions=and_(*conditions) if conditions else None) # Convert entries to ScenarioResults and populate attack_results efficiently scenario_results = [] From bd53d22431e432595ee183379ae791a87338c963 Mon Sep 17 00:00:00 2001 From: Maifee Ul Asad Date: Sun, 25 Jan 2026 15:53:10 +0600 Subject: [PATCH 2/4] TEST: batching scale tests for memory interface SQLite limits (#845) ``` python -m pytest tests/unit/memory/memory_interface/test_batching_scale.py -v ``` --- .../memory_interface/test_batching_scale.py | 179 ++++++++++++++++++ 1 file changed, 179 insertions(+) create mode 100644 tests/unit/memory/memory_interface/test_batching_scale.py diff --git a/tests/unit/memory/memory_interface/test_batching_scale.py b/tests/unit/memory/memory_interface/test_batching_scale.py new file mode 100644 index 000000000..2862977be --- /dev/null +++ b/tests/unit/memory/memory_interface/test_batching_scale.py @@ -0,0 +1,179 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Tests for batching functionality to handle large numbers of IDs. +This addresses the scaling bug where methods like get_scores_by_prompt_ids +fail when querying with many IDs due to SQLite bind variable limits. 
+""" + +import uuid + +from pyrit.memory import MemoryInterface +from pyrit.memory.memory_interface import _SQLITE_MAX_BIND_VARS +from pyrit.models import MessagePiece, Score + + +def _create_message_piece(conversation_id: str = None, original_value: str = "test message") -> MessagePiece: + """Create a sample message piece for testing.""" + return MessagePiece( + id=str(uuid.uuid4()), + role="user", + original_value=original_value, + converted_value=original_value, + sequence=0, + conversation_id=conversation_id or str(uuid.uuid4()), + labels={"test": "label"}, + attack_identifier={"id": str(uuid.uuid4())}, + ) + + +def _create_score(message_piece_id: str) -> Score: + """Create a sample score for testing.""" + return Score( + score_value="0.5", + score_value_description="test score", + score_type="float_scale", + score_category=["test"], + score_rationale="test rationale", + score_metadata={}, + scorer_class_identifier={"__type__": "TestScorer"}, + message_piece_id=message_piece_id, + ) + + +class TestBatchingScale: + """Tests for batching when querying with many IDs.""" + + def test_get_message_pieces_with_many_prompt_ids(self, sqlite_instance: MemoryInterface): + """Test that get_message_pieces works with more IDs than the batch limit.""" + # Create more message pieces than the batch limit + num_pieces = _SQLITE_MAX_BIND_VARS + 100 + pieces = [_create_message_piece() for _ in range(num_pieces)] + + # Add to memory + sqlite_instance.add_message_pieces_to_memory(message_pieces=pieces) + + # Query with all IDs - this should work with batching + all_ids = [piece.id for piece in pieces] + results = sqlite_instance.get_message_pieces(prompt_ids=all_ids) + + assert len(results) == num_pieces, f"Expected {num_pieces} results, got {len(results)}" + + def test_get_message_pieces_with_exact_batch_size(self, sqlite_instance: MemoryInterface): + """Test that get_message_pieces works with exactly the batch limit.""" + num_pieces = _SQLITE_MAX_BIND_VARS + pieces = [_create_message_piece() for _ in range(num_pieces)] + + sqlite_instance.add_message_pieces_to_memory(message_pieces=pieces) + + all_ids = [piece.id for piece in pieces] + results = sqlite_instance.get_message_pieces(prompt_ids=all_ids) + + assert len(results) == num_pieces + + def test_get_message_pieces_with_double_batch_size(self, sqlite_instance: MemoryInterface): + """Test that get_message_pieces works with double the batch limit.""" + num_pieces = _SQLITE_MAX_BIND_VARS * 2 + pieces = [_create_message_piece() for _ in range(num_pieces)] + + sqlite_instance.add_message_pieces_to_memory(message_pieces=pieces) + + all_ids = [piece.id for piece in pieces] + results = sqlite_instance.get_message_pieces(prompt_ids=all_ids) + + assert len(results) == num_pieces + + def test_get_scores_with_many_score_ids(self, sqlite_instance: MemoryInterface): + """Test that get_scores works with more IDs than the batch limit.""" + # Create message pieces first (scores need to reference them) + num_scores = _SQLITE_MAX_BIND_VARS + 100 + pieces = [_create_message_piece() for _ in range(num_scores)] + sqlite_instance.add_message_pieces_to_memory(message_pieces=pieces) + + # Create and add scores + scores = [_create_score(str(piece.id)) for piece in pieces] + sqlite_instance.add_scores_to_memory(scores=scores) + + # Query with all score IDs - this should work with batching + all_score_ids = [str(score.id) for score in scores] + results = sqlite_instance.get_scores(score_ids=all_score_ids) + + assert len(results) == num_scores, f"Expected {num_scores} results, got 
{len(results)}" + + def test_get_prompt_scores_with_many_prompt_ids(self, sqlite_instance: MemoryInterface): + """Test that get_prompt_scores works with more prompt IDs than the batch limit.""" + # Create message pieces + num_pieces = _SQLITE_MAX_BIND_VARS + 50 + pieces = [_create_message_piece() for _ in range(num_pieces)] + sqlite_instance.add_message_pieces_to_memory(message_pieces=pieces) + + # Create and add scores for half of them + num_scores = num_pieces // 2 + scores = [_create_score(str(pieces[i].id)) for i in range(num_scores)] + sqlite_instance.add_scores_to_memory(scores=scores) + + # Query with all prompt IDs - should return scores for pieces that have them + all_prompt_ids = [piece.id for piece in pieces] + results = sqlite_instance.get_prompt_scores(prompt_ids=all_prompt_ids) + + assert len(results) == num_scores, f"Expected {num_scores} results, got {len(results)}" + + def test_get_message_pieces_batching_preserves_other_filters(self, sqlite_instance: MemoryInterface): + """Test that batching still applies other filter conditions correctly.""" + # Create pieces with different roles + num_pieces = _SQLITE_MAX_BIND_VARS + 50 + user_pieces = [_create_message_piece() for _ in range(num_pieces)] + for piece in user_pieces: + piece.role = "user" + + assistant_pieces = [_create_message_piece() for _ in range(50)] + for piece in assistant_pieces: + piece.role = "assistant" + + all_pieces = user_pieces + assistant_pieces + sqlite_instance.add_message_pieces_to_memory(message_pieces=all_pieces) + + # Query with all IDs but filter by role + all_ids = [piece.id for piece in all_pieces] + results = sqlite_instance.get_message_pieces(prompt_ids=all_ids, role="user") + + assert len(results) == num_pieces, f"Expected {num_pieces} user pieces, got {len(results)}" + + def test_get_message_pieces_small_list_still_works(self, sqlite_instance: MemoryInterface): + """Test that small ID lists (under batch limit) still work correctly.""" + num_pieces = 10 + pieces = [_create_message_piece() for _ in range(num_pieces)] + sqlite_instance.add_message_pieces_to_memory(message_pieces=pieces) + + all_ids = [piece.id for piece in pieces] + results = sqlite_instance.get_message_pieces(prompt_ids=all_ids) + + assert len(results) == num_pieces + + def test_get_message_pieces_with_many_original_values(self, sqlite_instance: MemoryInterface): + """Test that get_message_pieces works with many original_values exceeding batch limit.""" + num_pieces = _SQLITE_MAX_BIND_VARS + 100 + # Create pieces with unique original values + pieces = [_create_message_piece(original_value=f"unique_value_{i}") for i in range(num_pieces)] + sqlite_instance.add_message_pieces_to_memory(message_pieces=pieces) + + # Query with all original values + all_values = [piece.original_value for piece in pieces] + results = sqlite_instance.get_message_pieces(original_values=all_values) + + assert len(results) == num_pieces, f"Expected {num_pieces} results, got {len(results)}" + + def test_get_message_pieces_with_many_converted_value_sha256(self, sqlite_instance: MemoryInterface): + """Test that get_message_pieces works with many converted_value_sha256 exceeding batch limit.""" + num_pieces = _SQLITE_MAX_BIND_VARS + 100 + pieces = [_create_message_piece(original_value=f"unique_value_{i}") for i in range(num_pieces)] + sqlite_instance.add_message_pieces_to_memory(message_pieces=pieces) + + # Get SHA256 hashes from stored pieces + stored_pieces = sqlite_instance.get_message_pieces() + all_hashes = [piece.converted_value_sha256 for piece in 
stored_pieces if piece.converted_value_sha256] + + if len(all_hashes) > _SQLITE_MAX_BIND_VARS: + results = sqlite_instance.get_message_pieces(converted_value_sha256=all_hashes) + assert len(results) == len(all_hashes) From 1022e644e17a31c9aa682793ef06c74594264996 Mon Sep 17 00:00:00 2001 From: Maifee Ul Asad Date: Mon, 26 Jan 2026 17:23:13 +0600 Subject: [PATCH 3/4] FIX: independent batching for memory interface from review (#845) - independent batching of all parameters - extracted batch in condition - helper functions --- pyrit/memory/memory_interface.py | 173 ++++++++++--------------------- 1 file changed, 53 insertions(+), 120 deletions(-) diff --git a/pyrit/memory/memory_interface.py b/pyrit/memory/memory_interface.py index 54ba450a0..5a4985852 100644 --- a/pyrit/memory/memory_interface.py +++ b/pyrit/memory/memory_interface.py @@ -57,6 +57,31 @@ Model = TypeVar("Model") +def _batched_in_condition(column: InstrumentedAttribute, values: Sequence[Any]) -> ColumnElement[bool]: + """ + Create a batched IN condition to avoid SQLite bind variable limits. + + When the number of values exceeds _SQLITE_MAX_BIND_VARS, this function + creates an OR of multiple IN conditions, each with at most _SQLITE_MAX_BIND_VARS values. + + Args: + column: The SQLAlchemy column to filter on. + values: The list of values to filter by. + + Returns: + A SQLAlchemy condition (either a single IN or OR of multiple INs). + """ + if len(values) <= _SQLITE_MAX_BIND_VARS: + return column.in_(values) + + # Batch the values and create OR of IN conditions + conditions = [] + for i in range(0, len(values), _SQLITE_MAX_BIND_VARS): + batch = values[i : i + _SQLITE_MAX_BIND_VARS] + conditions.append(column.in_(batch)) + return or_(*conditions) + + class MemoryInterface(abc.ABC): """ Abstract interface for conversation memory storage systems. @@ -364,9 +389,10 @@ def get_scores( Returns: Sequence[Score]: A list of Score objects that match the specified filters. """ - # Build base conditions without score_ids, we will handle that with batching conditions: list[Any] = [] + if score_ids: + conditions.append(_batched_in_condition(ScoreEntry.id, list(score_ids))) if score_type: conditions.append(ScoreEntry.score_type == score_type) if score_category: @@ -376,18 +402,6 @@ def get_scores( if sent_before: conditions.append(ScoreEntry.timestamp <= sent_before) - # Handle score_ids with batching to avoid SQLite bind variable limits - if score_ids: - all_entries: list[ScoreEntry] = [] - for i in range(0, len(score_ids), _SQLITE_MAX_BIND_VARS): - batch = score_ids[i : i + _SQLITE_MAX_BIND_VARS] - batch_conditions = conditions + [ScoreEntry.id.in_(batch)] - batch_entries: Sequence[ScoreEntry] = self._query_entries( - ScoreEntry, conditions=and_(*batch_conditions) - ) - all_entries.extend(batch_entries) - return [entry.get_score() for entry in all_entries] - if not conditions: return [] @@ -546,7 +560,6 @@ def get_message_pieces( Exception: If there is an error retrieving the prompts, an exception is logged and an empty list is returned. 
""" - # Build base conditions (without parameters that may need batching) conditions = [] if attack_id: conditions.append(self._get_message_pieces_attack_conditions(attack_id=str(attack_id))) @@ -554,6 +567,8 @@ def get_message_pieces( conditions.append(PromptMemoryEntry.role == role) if conversation_id: conditions.append(PromptMemoryEntry.conversation_id == str(conversation_id)) + if prompt_ids: + conditions.append(_batched_in_condition(PromptMemoryEntry.id, [str(pi) for pi in prompt_ids])) if labels: conditions.extend(self._get_message_pieces_memory_label_conditions(memory_labels=labels)) if prompt_metadata: @@ -562,59 +577,21 @@ def get_message_pieces( conditions.append(PromptMemoryEntry.timestamp >= sent_after) if sent_before: conditions.append(PromptMemoryEntry.timestamp <= sent_before) + if original_values: + conditions.append(_batched_in_condition(PromptMemoryEntry.original_value, list(original_values))) + if converted_values: + conditions.append(_batched_in_condition(PromptMemoryEntry.converted_value, list(converted_values))) if data_type: conditions.append(PromptMemoryEntry.converted_value_data_type == data_type) if not_data_type: conditions.append(PromptMemoryEntry.converted_value_data_type != not_data_type) - - # Identify which parameter needs batching (prioritize the one provided) - batch_param = None - batch_values = None - batch_column = None - - if prompt_ids: - batch_param = "prompt_ids" - batch_values = [str(pi) for pi in prompt_ids] - batch_column = PromptMemoryEntry.id - elif original_values and len(original_values) > _SQLITE_MAX_BIND_VARS: - batch_param = "original_values" - batch_values = list(original_values) - batch_column = PromptMemoryEntry.original_value - elif converted_values and len(converted_values) > _SQLITE_MAX_BIND_VARS: - batch_param = "converted_values" - batch_values = list(converted_values) - batch_column = PromptMemoryEntry.converted_value - elif converted_value_sha256 and len(converted_value_sha256) > _SQLITE_MAX_BIND_VARS: - batch_param = "converted_value_sha256" - batch_values = list(converted_value_sha256) - batch_column = PromptMemoryEntry.converted_value_sha256 - - # Add non-batched IN conditions - if original_values and batch_param != "original_values": - conditions.append(PromptMemoryEntry.original_value.in_(original_values)) - if converted_values and batch_param != "converted_values": - conditions.append(PromptMemoryEntry.converted_value.in_(converted_values)) - if converted_value_sha256 and batch_param != "converted_value_sha256": - conditions.append(PromptMemoryEntry.converted_value_sha256.in_(converted_value_sha256)) + if converted_value_sha256: + conditions.append(_batched_in_condition(PromptMemoryEntry.converted_value_sha256, list(converted_value_sha256))) try: - if batch_values: - all_entries: MutableSequence[PromptMemoryEntry] = [] - for i in range(0, len(batch_values), _SQLITE_MAX_BIND_VARS): - batch = batch_values[i : i + _SQLITE_MAX_BIND_VARS] - batch_conditions = conditions + [batch_column.in_(batch)] - batch_entries: Sequence[PromptMemoryEntry] = self._query_entries( - PromptMemoryEntry, - conditions=and_(*batch_conditions) if batch_conditions else None, - join_scores=True, - ) - all_entries.extend(batch_entries) - memory_entries = all_entries - else: - memory_entries = self._query_entries( - PromptMemoryEntry, conditions=and_(*conditions) if conditions else None, join_scores=True - ) - + memory_entries: Sequence[PromptMemoryEntry] = self._query_entries( + PromptMemoryEntry, conditions=and_(*conditions) if conditions else None, 
join_scores=True + ) message_pieces = [memory_entry.get_message_piece() for memory_entry in memory_entries] return sort_message_pieces(message_pieces=message_pieces) except Exception as e: @@ -1288,13 +1265,20 @@ def get_attack_results( Returns: Sequence[AttackResult]: A list of AttackResult objects that match the specified filters. """ - # Build base conditions (without parameters that may need batching) conditions: list[ColumnElement[bool]] = [] + if attack_result_ids is not None: + if len(attack_result_ids) == 0: + return [] + conditions.append(_batched_in_condition(AttackResultEntry.id, list(attack_result_ids))) if conversation_id: conditions.append(AttackResultEntry.conversation_id == conversation_id) if objective: conditions.append(AttackResultEntry.objective.contains(objective)) + if objective_sha256: + if len(objective_sha256) == 0: + return [] + conditions.append(_batched_in_condition(AttackResultEntry.objective_sha256, list(objective_sha256))) if outcome: conditions.append(AttackResultEntry.outcome == outcome) @@ -1302,48 +1286,10 @@ def get_attack_results( conditions.append( self._get_attack_result_harm_category_condition(targeted_harm_categories=targeted_harm_categories) ) - if labels: conditions.append(self._get_attack_result_label_condition(labels=labels)) - # Handle empty lists - if attack_result_ids is not None and len(attack_result_ids) == 0: - return [] - if objective_sha256 is not None and len(objective_sha256) == 0: - return [] - - # Identify which parameter needs batching - batch_values = None - batch_column = None - batch_param_name = None - - if attack_result_ids and len(attack_result_ids) > _SQLITE_MAX_BIND_VARS: - batch_values = list(attack_result_ids) - batch_column = AttackResultEntry.id - batch_param_name = "attack_result_ids" - elif objective_sha256 and len(objective_sha256) > _SQLITE_MAX_BIND_VARS: - batch_values = list(objective_sha256) - batch_column = AttackResultEntry.objective_sha256 - batch_param_name = "objective_sha256" - - # Add non-batched IN conditions - if attack_result_ids and batch_param_name != "attack_result_ids": - conditions.append(AttackResultEntry.id.in_(attack_result_ids)) - if objective_sha256 and batch_param_name != "objective_sha256": - conditions.append(AttackResultEntry.objective_sha256.in_(objective_sha256)) - try: - if batch_values: - all_entries: list[AttackResultEntry] = [] - for i in range(0, len(batch_values), _SQLITE_MAX_BIND_VARS): - batch = batch_values[i : i + _SQLITE_MAX_BIND_VARS] - batch_conditions = list(conditions) + [batch_column.in_(batch)] - batch_entries: Sequence[AttackResultEntry] = self._query_entries( - AttackResultEntry, conditions=and_(*batch_conditions) if batch_conditions else None - ) - all_entries.extend(batch_entries) - return [entry.get_attack_result() for entry in all_entries] - entries: Sequence[AttackResultEntry] = self._query_entries( AttackResultEntry, conditions=and_(*conditions) if conditions else None ) @@ -1504,12 +1450,14 @@ def get_scenario_results( Returns: Sequence[ScenarioResult]: A list of ScenarioResult objects that match the specified filters. 
""" - # Handle empty list if scenario_result_ids is not None and len(scenario_result_ids) == 0: return [] conditions: list[ColumnElement[bool]] = [] + if scenario_result_ids: + conditions.append(_batched_in_condition(ScenarioResultEntry.id, list(scenario_result_ids))) + if scenario_name: normalized_name = ScenarioResult.normalize_scenario_name(scenario_name) conditions.append(ScenarioResultEntry.scenario_name.contains(normalized_name)) @@ -1527,33 +1475,18 @@ def get_scenario_results( conditions.append(ScenarioResultEntry.completion_time <= added_before) if labels: - # Use database-specific JSON query method conditions.append(self._get_scenario_result_label_condition(labels=labels)) if objective_target_endpoint: - # Use database-specific JSON query method conditions.append(self._get_scenario_result_target_endpoint_condition(endpoint=objective_target_endpoint)) if objective_target_model_name: - # Use database-specific JSON query method conditions.append(self._get_scenario_result_target_model_condition(model_name=objective_target_model_name)) try: - # Handle scenario_result_ids with batching if needed - if scenario_result_ids and len(scenario_result_ids) > _SQLITE_MAX_BIND_VARS: - all_entries: MutableSequence[ScenarioResultEntry] = [] - for i in range(0, len(scenario_result_ids), _SQLITE_MAX_BIND_VARS): - batch = list(scenario_result_ids)[i : i + _SQLITE_MAX_BIND_VARS] - batch_conditions = list(conditions) + [ScenarioResultEntry.id.in_(batch)] - batch_entries: Sequence[ScenarioResultEntry] = self._query_entries( - ScenarioResultEntry, conditions=and_(*batch_conditions) if batch_conditions else None - ) - all_entries.extend(batch_entries) - entries = all_entries - else: - if scenario_result_ids: - conditions.append(ScenarioResultEntry.id.in_(scenario_result_ids)) - entries = self._query_entries(ScenarioResultEntry, conditions=and_(*conditions) if conditions else None) + entries: Sequence[ScenarioResultEntry] = self._query_entries( + ScenarioResultEntry, conditions=and_(*conditions) if conditions else None + ) # Convert entries to ScenarioResults and populate attack_results efficiently scenario_results = [] @@ -1571,7 +1504,7 @@ def get_scenario_results( # Query all AttackResults in a single batch if there are any if all_conversation_ids: # Build condition to query multiple conversation IDs at once - attack_conditions = [AttackResultEntry.conversation_id.in_(all_conversation_ids)] + attack_conditions = [_batched_in_condition(AttackResultEntry.conversation_id, all_conversation_ids)] attack_entries: Sequence[AttackResultEntry] = self._query_entries( AttackResultEntry, conditions=and_(*attack_conditions) ) From 6a563df8196b7635eb2e5cf5075614e352827a9e Mon Sep 17 00:00:00 2001 From: Maifee Ul Asad Date: Mon, 26 Jan 2026 17:24:12 +0600 Subject: [PATCH 4/4] TEST: independent batching test for memory interface from review (#845) - tests focusing on independent batching of all parameters --- .../memory_interface/test_batching_scale.py | 321 +++++++++++++++++- 1 file changed, 319 insertions(+), 2 deletions(-) diff --git a/tests/unit/memory/memory_interface/test_batching_scale.py b/tests/unit/memory/memory_interface/test_batching_scale.py index 2862977be..4ba301ee2 100644 --- a/tests/unit/memory/memory_interface/test_batching_scale.py +++ b/tests/unit/memory/memory_interface/test_batching_scale.py @@ -7,20 +7,29 @@ fail when querying with many IDs due to SQLite bind variable limits. 
""" +import hashlib import uuid +from unittest.mock import MagicMock, patch + +from sqlalchemy import Column, Integer +from sqlalchemy.orm import declarative_base from pyrit.memory import MemoryInterface -from pyrit.memory.memory_interface import _SQLITE_MAX_BIND_VARS +from pyrit.memory.memory_interface import _SQLITE_MAX_BIND_VARS, _batched_in_condition from pyrit.models import MessagePiece, Score def _create_message_piece(conversation_id: str = None, original_value: str = "test message") -> MessagePiece: """Create a sample message piece for testing.""" + converted_value = original_value + # Compute SHA256 for converted_value so filtering by sha256 works + sha256 = hashlib.sha256(converted_value.encode("utf-8")).hexdigest() return MessagePiece( id=str(uuid.uuid4()), role="user", original_value=original_value, - converted_value=original_value, + converted_value=converted_value, + converted_value_sha256=sha256, sequence=0, conversation_id=conversation_id or str(uuid.uuid4()), labels={"test": "label"}, @@ -42,6 +51,157 @@ def _create_score(message_piece_id: str) -> Score: ) +class TestBatchedInCondition: + """Tests for the _batched_in_condition helper function.""" + + def test_batched_in_condition_small_list(self): + """Test that small lists generate a simple IN condition.""" + Base = declarative_base() + + class TestModel(Base): + __tablename__ = "test" + id = Column(Integer, primary_key=True) + + values = list(range(10)) + condition = _batched_in_condition(TestModel.id, values) + + # Should be a simple IN clause, not an OR + assert "IN" in str(condition) + assert "OR" not in str(condition) + + def test_batched_in_condition_exact_batch_size(self): + """Test with exactly _SQLITE_MAX_BIND_VARS values.""" + Base = declarative_base() + + class TestModel(Base): + __tablename__ = "test" + id = Column(Integer, primary_key=True) + + values = list(range(_SQLITE_MAX_BIND_VARS)) + condition = _batched_in_condition(TestModel.id, values) + + # Should still be a simple IN clause at the limit + assert "IN" in str(condition) + # May or may not have OR depending on implementation at boundary + + def test_batched_in_condition_over_batch_size(self): + """Test with values exceeding _SQLITE_MAX_BIND_VARS.""" + Base = declarative_base() + + class TestModel(Base): + __tablename__ = "test" + id = Column(Integer, primary_key=True) + + values = list(range(_SQLITE_MAX_BIND_VARS + 100)) + condition = _batched_in_condition(TestModel.id, values) + + # Should generate OR of multiple IN clauses + condition_str = str(condition) + assert "OR" in condition_str + assert "IN" in condition_str + + def test_batched_in_condition_double_batch_size(self): + """Test with double the batch size.""" + Base = declarative_base() + + class TestModel(Base): + __tablename__ = "test" + id = Column(Integer, primary_key=True) + + values = list(range(_SQLITE_MAX_BIND_VARS * 2)) + condition = _batched_in_condition(TestModel.id, values) + + # Should generate multiple batches + condition_str = str(condition) + assert "OR" in condition_str + # Should have at least 2 IN clauses + assert condition_str.count("IN") >= 2 + + def test_batched_in_condition_three_batches(self): + """Test with enough values to require three batches.""" + Base = declarative_base() + + class TestModel(Base): + __tablename__ = "test" + id = Column(Integer, primary_key=True) + + values = list(range(_SQLITE_MAX_BIND_VARS * 2 + 100)) + condition = _batched_in_condition(TestModel.id, values) + + condition_str = str(condition) + assert "OR" in condition_str + # Should have at least 
3 IN clauses + assert condition_str.count("IN") >= 3 + + def test_batched_in_condition_empty_list(self): + """Test with an empty list.""" + Base = declarative_base() + + class TestModel(Base): + __tablename__ = "test" + id = Column(Integer, primary_key=True) + + values = [] + condition = _batched_in_condition(TestModel.id, values) + + # Empty list should still generate valid SQL + condition_str = str(condition) + assert "IN" in condition_str + + def test_batched_in_condition_multiple_columns(self): + """Test combining multiple batched conditions with AND logic.""" + from sqlalchemy import String, and_ + + Base = declarative_base() + + class TestModel(Base): + __tablename__ = "test" + id = Column(Integer, primary_key=True) + name = Column(String) + email = Column(String) + + # Create multiple large value lists for different columns + num_values = (_SQLITE_MAX_BIND_VARS * 2) + 100 + id_values = list(range(num_values)) + name_values = [f"name_{i}" for i in range(num_values)] + email_values = [f"email_{i}@test.com" for i in range(num_values)] + + # Create batched conditions for each column + id_condition = _batched_in_condition(TestModel.id, id_values) + name_condition = _batched_in_condition(TestModel.name, name_values) + email_condition = _batched_in_condition(TestModel.email, email_values) + + # Combine with AND (simulating real query behavior) + combined_condition = and_(id_condition, name_condition, email_condition) + combined_str = str(combined_condition) + + # Verify all three columns are present in the query + assert "id" in combined_str.lower() + assert "name" in combined_str.lower() + assert "email" in combined_str.lower() + + # Verify OR clauses are present (batching is active) + assert combined_str.count("OR") >= 3 # At least one OR per batched column + + # Verify AND combines the conditions + assert "AND" in combined_str + + # Verify `id` count matches expected batches + expected_id_batches = (num_values + _SQLITE_MAX_BIND_VARS - 1) // _SQLITE_MAX_BIND_VARS + actual_id_batches = combined_str.count("id IN") + assert actual_id_batches == expected_id_batches + + # Verify `name` count matches expected batches + expected_name_batches = (num_values + _SQLITE_MAX_BIND_VARS - 1) // _SQLITE_MAX_BIND_VARS + actual_name_batches = combined_str.count("name IN") + assert actual_name_batches == expected_name_batches + + # Verify `email` count matches expected batches + expected_email_batches = (num_values + _SQLITE_MAX_BIND_VARS - 1) // _SQLITE_MAX_BIND_VARS + actual_email_batches = combined_str.count("email IN") + assert actual_email_batches == expected_email_batches + + class TestBatchingScale: """Tests for batching when querying with many IDs.""" @@ -177,3 +337,160 @@ def test_get_message_pieces_with_many_converted_value_sha256(self, sqlite_instan if len(all_hashes) > _SQLITE_MAX_BIND_VARS: results = sqlite_instance.get_message_pieces(converted_value_sha256=all_hashes) assert len(results) == len(all_hashes) + + def test_get_message_pieces_combines_filters_correctly(self, sqlite_instance: MemoryInterface): + """Test that multiple filters can be combined (e.g., prompt_ids AND role).""" + # Create message pieces with different roles + num_pieces = 50 + user_pieces = [_create_message_piece() for _ in range(num_pieces)] + for piece in user_pieces: + piece.role = "user" + + assistant_pieces = [_create_message_piece() for _ in range(num_pieces)] + for piece in assistant_pieces: + piece.role = "assistant" + + all_pieces = user_pieces + assistant_pieces + 
sqlite_instance.add_message_pieces_to_memory(message_pieces=all_pieces) + + # Query with both prompt_ids AND role filter + user_ids = [piece.id for piece in user_pieces] + results = sqlite_instance.get_message_pieces(prompt_ids=user_ids, role="user") + + # Should return only user pieces (intersection of both filters) + assert len(results) == num_pieces + assert all(r.role == "user" for r in results) + + # Query with role filter and a subset of IDs + subset_ids = user_ids[:10] + results = sqlite_instance.get_message_pieces(prompt_ids=subset_ids, role="user") + assert len(results) == 10 + + def test_get_message_pieces_multiple_large_params_simultaneously(self, sqlite_instance: MemoryInterface): + """Test batching with multiple parameters exceeding batch limit simultaneously.""" + # Create enough pieces to exceed batch limit with unique values + num_pieces = _SQLITE_MAX_BIND_VARS + 200 + pieces = [ + _create_message_piece(original_value=f"original_value_{i}") for i in range(num_pieces) + ] + + # Add to memory + sqlite_instance.add_message_pieces_to_memory(message_pieces=pieces) + + # Get all stored pieces to extract their IDs and SHA256 hashes + stored_pieces = sqlite_instance.get_message_pieces() + assert len(stored_pieces) >= num_pieces + + # Extract multiple large parameter lists + all_ids = [piece.id for piece in stored_pieces[:num_pieces]] + all_original_values = [piece.original_value for piece in stored_pieces[:num_pieces]] + all_sha256 = [piece.converted_value_sha256 for piece in stored_pieces[:num_pieces]] + + # Query with multiple large parameters simultaneously + # This tests that ALL parameters are batched correctly, not just one + results = sqlite_instance.get_message_pieces( + prompt_ids=all_ids, + original_values=all_original_values, + converted_value_sha256=all_sha256, + ) + + # Should return all pieces that match ALL conditions (intersection) + assert len(results) == num_pieces, ( + f"Expected {num_pieces} results when filtering with multiple large parameters, " + f"got {len(results)}" + ) + + # Verify all returned pieces match all filter criteria + result_ids = {r.id for r in results} + result_original_values = {r.original_value for r in results} + result_sha256 = {r.converted_value_sha256 for r in results} + + assert result_ids == set(all_ids), "Returned IDs don't match filter" + assert result_original_values == set(all_original_values), "Returned original_values don't match filter" + assert result_sha256 == set(all_sha256), "Returned SHA256 hashes don't match filter" + + def test_get_message_pieces_multiple_batched_params_with_query_spy(self, sqlite_instance: MemoryInterface): + """Test that batching generates correct queries when multiple params exceed limit.""" + # Create pieces exceeding batch limit + num_pieces = _SQLITE_MAX_BIND_VARS + 100 + pieces = [ + _create_message_piece(original_value=f"value_{i}") for i in range(num_pieces) + ] + sqlite_instance.add_message_pieces_to_memory(message_pieces=pieces) + + # Get stored pieces + stored_pieces = sqlite_instance.get_message_pieces() + all_ids = [piece.id for piece in stored_pieces[:num_pieces]] + all_original_values = [piece.original_value for piece in stored_pieces[:num_pieces]] + + # Mock _query_entries to track how it's called + original_query = sqlite_instance._query_entries + call_count = 0 + captured_conditions = [] + + def spy_query(*args, **kwargs): + nonlocal call_count + call_count += 1 + if "conditions" in kwargs and kwargs["conditions"] is not None: + captured_conditions.append(str(kwargs["conditions"])) + 
return original_query(*args, **kwargs) + + with patch.object(sqlite_instance, "_query_entries", side_effect=spy_query): + results = sqlite_instance.get_message_pieces( + prompt_ids=all_ids, original_values=all_original_values + ) + + # Should get all results despite batching + assert len(results) == num_pieces + + # Should have been called (could be 1 call with OR conditions) + assert call_count >= 1 + + # Verify query conditions include both filters + if captured_conditions: + combined_conditions = " ".join(captured_conditions) + # Both column filters should be present in the query + assert "id" in combined_conditions.lower() or "prompt" in combined_conditions.lower() + assert "original_value" in combined_conditions.lower() + + def test_get_message_pieces_triple_large_params_preserves_intersection(self, sqlite_instance: MemoryInterface): + """Test that filtering with 3 large parameter lists returns correct intersection.""" + # Create a large set of pieces + total_pieces = _SQLITE_MAX_BIND_VARS + 150 + pieces = [ + _create_message_piece( + conversation_id=str(uuid.uuid4()), original_value=f"content_{i}" + ) + for i in range(total_pieces) + ] + sqlite_instance.add_message_pieces_to_memory(message_pieces=pieces) + + # Get stored pieces + stored_pieces = sqlite_instance.get_message_pieces() + + # Create three overlapping large filter lists + # List 1: All IDs + filter_ids = [p.id for p in stored_pieces[:total_pieces]] + + # List 2: All original values + filter_original_values = [p.original_value for p in stored_pieces[:total_pieces]] + + # List 3: Subset of SHA256 hashes (to test intersection) + subset_size = _SQLITE_MAX_BIND_VARS + 50 + filter_sha256 = [p.converted_value_sha256 for p in stored_pieces[:subset_size]] + + # Query with all three large parameters + results = sqlite_instance.get_message_pieces( + prompt_ids=filter_ids, + original_values=filter_original_values, + converted_value_sha256=filter_sha256, + ) + + # Should return only the intersection (subset_size items) + assert ( + len(results) == subset_size + ), f"Expected {subset_size} results from intersection, got {len(results)}" + + # Verify all results have SHA256 in the filter list + result_sha256 = {r.converted_value_sha256 for r in results} + assert result_sha256.issubset(set(filter_sha256)), "Results contain unexpected SHA256 values"
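
Editor's note: a standalone sketch of the pattern this series converges on. When a filter list exceeds SQLite's bind-variable budget, the PATCH 3 helper ORs together IN clauses of at most `_SQLITE_MAX_BIND_VARS` values, so the result stays a single condition that composes with other filters via `and_()` in one query. The snippet below exercises that idea against a throwaway in-memory SQLite database; the `Item` model, table name, and 1,100-row dataset are illustrative assumptions, not part of PyRIT.

```python
# Minimal, self-contained demo of the OR-of-IN batching pattern (assumes
# SQLAlchemy 2.x; the Item model and row counts are made up for illustration).
from typing import Any, Sequence

from sqlalchemy import Integer, String, create_engine, or_, select
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column

_SQLITE_MAX_BIND_VARS = 500  # half of SQLite's lowest default limit of 999


class Base(DeclarativeBase):
    pass


class Item(Base):
    __tablename__ = "items"
    id: Mapped[int] = mapped_column(Integer, primary_key=True)
    name: Mapped[str] = mapped_column(String)


def batched_in_condition(column, values: Sequence[Any]):
    """OR together IN clauses of at most _SQLITE_MAX_BIND_VARS values each."""
    if len(values) <= _SQLITE_MAX_BIND_VARS:
        return column.in_(values)
    return or_(
        *(
            column.in_(values[i : i + _SQLITE_MAX_BIND_VARS])
            for i in range(0, len(values), _SQLITE_MAX_BIND_VARS)
        )
    )


engine = create_engine("sqlite://")  # throwaway in-memory database
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add_all(Item(id=i, name=f"item_{i}") for i in range(1100))
    session.commit()

    # 1,100 ids need three batches (500 + 500 + 100); the resulting condition
    # is still one WHERE clause, so it keeps AND-composition with other filters.
    ids = list(range(1100))
    rows = session.scalars(select(Item).where(batched_in_condition(Item.id, ids))).all()
    assert len(rows) == 1100
```

A side effect worth noting: because the helper returns a single composable condition rather than issuing one query per batch (the PATCH 1 approach), result ordering, `join_scores`-style options, and empty-list semantics are handled once, in the normal query path.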