Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 43 additions & 25 deletions pyrit/memory/memory_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,38 @@

logger = logging.getLogger(__name__)

# ref: https://www.sqlite.org/limits.html
# The lowest default bind-variable limit (SQLITE_MAX_VARIABLE_NUMBER) is 999;
# we intentionally use roughly half of that to leave headroom.
_SQLITE_MAX_BIND_VARS = 500

Model = TypeVar("Model")


def _batched_in_condition(column: InstrumentedAttribute, values: Sequence[Any]) -> ColumnElement[bool]:
    """
    Build an ``IN`` filter on ``column``, split into chunks of bounded size.

    When the number of values is at most ``_SQLITE_MAX_BIND_VARS`` a single
    ``column.in_(...)`` condition is returned. Otherwise the values are cut into
    consecutive chunks of at most ``_SQLITE_MAX_BIND_VARS`` items and the
    per-chunk ``IN`` conditions are combined with ``or_``.

    NOTE(review): the SQLite variable limit appears to apply to the whole
    statement, so OR-ing the chunks inside one WHERE clause may not lower the
    total number of bound parameters. Confirm against the SQLite limits
    documentation; very large inputs may need separate queries instead.

    Args:
        column: The SQLAlchemy column to filter on.
        values: The values to match. Any iterable collection is accepted; it is
            copied into a list so that non-subscriptable inputs (e.g. a set)
            can be chunked too.

    Returns:
        A SQLAlchemy condition: either a single IN or an OR of multiple INs.
    """
    # Materialize once: slicing below requires a real sequence, and without this
    # a set would work for small inputs but raise TypeError for large ones.
    items = list(values)
    if len(items) <= _SQLITE_MAX_BIND_VARS:
        return column.in_(items)

    # One IN condition per chunk, in the original order of the values.
    return or_(
        *(
            column.in_(items[start : start + _SQLITE_MAX_BIND_VARS])
            for start in range(0, len(items), _SQLITE_MAX_BIND_VARS)
        )
    )


class MemoryInterface(abc.ABC):
"""
Abstract interface for conversation memory storage systems.
Expand Down Expand Up @@ -364,7 +392,7 @@ def get_scores(
conditions: list[Any] = []

if score_ids:
conditions.append(ScoreEntry.id.in_(score_ids))
conditions.append(_batched_in_condition(ScoreEntry.id, list(score_ids)))
if score_type:
conditions.append(ScoreEntry.score_type == score_type)
if score_category:
Expand Down Expand Up @@ -540,8 +568,7 @@ def get_message_pieces(
if conversation_id:
conditions.append(PromptMemoryEntry.conversation_id == str(conversation_id))
if prompt_ids:
prompt_ids = [str(pi) for pi in prompt_ids]
conditions.append(PromptMemoryEntry.id.in_(prompt_ids))
conditions.append(_batched_in_condition(PromptMemoryEntry.id, [str(pi) for pi in prompt_ids]))
if labels:
conditions.extend(self._get_message_pieces_memory_label_conditions(memory_labels=labels))
if prompt_metadata:
Expand All @@ -551,15 +578,15 @@ def get_message_pieces(
if sent_before:
conditions.append(PromptMemoryEntry.timestamp <= sent_before)
if original_values:
conditions.append(PromptMemoryEntry.original_value.in_(original_values))
conditions.append(_batched_in_condition(PromptMemoryEntry.original_value, list(original_values)))
if converted_values:
conditions.append(PromptMemoryEntry.converted_value.in_(converted_values))
conditions.append(_batched_in_condition(PromptMemoryEntry.converted_value, list(converted_values)))
if data_type:
conditions.append(PromptMemoryEntry.converted_value_data_type == data_type)
if not_data_type:
conditions.append(PromptMemoryEntry.converted_value_data_type != not_data_type)
if converted_value_sha256:
conditions.append(PromptMemoryEntry.converted_value_sha256.in_(converted_value_sha256))
conditions.append(_batched_in_condition(PromptMemoryEntry.converted_value_sha256, list(converted_value_sha256)))

try:
memory_entries: Sequence[PromptMemoryEntry] = self._query_entries(
Expand Down Expand Up @@ -1242,27 +1269,24 @@ def get_attack_results(

if attack_result_ids is not None:
if len(attack_result_ids) == 0:
# Empty list means no results
return []
conditions.append(AttackResultEntry.id.in_(attack_result_ids))
conditions.append(_batched_in_condition(AttackResultEntry.id, list(attack_result_ids)))
if conversation_id:
conditions.append(AttackResultEntry.conversation_id == conversation_id)
if objective:
conditions.append(AttackResultEntry.objective.contains(objective))

if objective_sha256:
conditions.append(AttackResultEntry.objective_sha256.in_(objective_sha256))
if len(objective_sha256) == 0:
return []
conditions.append(_batched_in_condition(AttackResultEntry.objective_sha256, list(objective_sha256)))
if outcome:
conditions.append(AttackResultEntry.outcome == outcome)

if targeted_harm_categories:
# Use database-specific JSON query method
conditions.append(
self._get_attack_result_harm_category_condition(targeted_harm_categories=targeted_harm_categories)
)

if labels:
# Use database-specific JSON query method
conditions.append(self._get_attack_result_label_condition(labels=labels))

try:
Expand Down Expand Up @@ -1426,18 +1450,15 @@ def get_scenario_results(
Returns:
Sequence[ScenarioResult]: A list of ScenarioResult objects that match the specified filters.
"""
if scenario_result_ids is not None and len(scenario_result_ids) == 0:
return []

conditions: list[ColumnElement[bool]] = []

if scenario_result_ids is not None:
if len(scenario_result_ids) == 0:
# Empty list means no results
return []
conditions.append(ScenarioResultEntry.id.in_(scenario_result_ids))
if scenario_result_ids:
conditions.append(_batched_in_condition(ScenarioResultEntry.id, list(scenario_result_ids)))

if scenario_name:
# Normalize CLI snake_case names (e.g., "foundry" or "content_harms")
# to class names (e.g., "Foundry" or "ContentHarms")
# This allows users to query with either format
normalized_name = ScenarioResult.normalize_scenario_name(scenario_name)
conditions.append(ScenarioResultEntry.scenario_name.contains(normalized_name))

Expand All @@ -1454,15 +1475,12 @@ def get_scenario_results(
conditions.append(ScenarioResultEntry.completion_time <= added_before)

if labels:
# Use database-specific JSON query method
conditions.append(self._get_scenario_result_label_condition(labels=labels))

if objective_target_endpoint:
# Use database-specific JSON query method
conditions.append(self._get_scenario_result_target_endpoint_condition(endpoint=objective_target_endpoint))

if objective_target_model_name:
# Use database-specific JSON query method
conditions.append(self._get_scenario_result_target_model_condition(model_name=objective_target_model_name))

try:
Expand All @@ -1486,7 +1504,7 @@ def get_scenario_results(
# Query all AttackResults in a single batch if there are any
if all_conversation_ids:
# Build condition to query multiple conversation IDs at once
attack_conditions = [AttackResultEntry.conversation_id.in_(all_conversation_ids)]
attack_conditions = [_batched_in_condition(AttackResultEntry.conversation_id, all_conversation_ids)]
attack_entries: Sequence[AttackResultEntry] = self._query_entries(
AttackResultEntry, conditions=and_(*attack_conditions)
)
Expand Down
Loading