7 changes: 6 additions & 1 deletion backend/app/core/enum.py
@@ -10,4 +10,9 @@ class BiasCategories(Enum):
     Generic = "generic"
     Healthcare = "healthcare"
     Education = "education"
-    All = "all"
+    All = "all"
+
+class GuardrailOnFail(Enum):
+    Exception = "exception"
+    Fix = "fix"
+    Rephrase = "rephrase"
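For context, a minimal sketch (not part of the diff) of how the new enum behaves at the API boundary: because the config models declare the field as GuardrailOnFail, Pydantic/SQLModel coerces the JSON string by value, so a payload carrying "rephrase" yields GuardrailOnFail.Rephrase. "Demo" below is a stand-in model, not a class from this repo.

# Hedged sketch, assuming standard Pydantic by-value enum coercion.
from enum import Enum

from pydantic import BaseModel


class GuardrailOnFail(Enum):
    Exception = "exception"
    Fix = "fix"
    Rephrase = "rephrase"


class Demo(BaseModel):
    on_fail: GuardrailOnFail = GuardrailOnFail.Fix


assert Demo(on_fail="rephrase").on_fail is GuardrailOnFail.Rephrase
assert Demo().on_fail is GuardrailOnFail.Fix  # default mirrors BaseValidatorConfig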
11 changes: 9 additions & 2 deletions backend/app/core/guardrail_controller.py
@@ -5,10 +5,17 @@
 from app.models.guardrail_config import ValidatorConfigItem
 
 def build_guard(validator_items):
-    validators = [v_item.build() for v_item in validator_items]
+    validators = []
+
+    for v_item in validator_items:
+        validator = v_item.build(
+            on_fail=v_item.resolve_on_fail()
+        )
+        validators.append(validator)
+
     return Guard().use_many(*validators)
 
 def get_validator_config_models():
     annotated_args = get_args(ValidatorConfigItem)
     union_type = annotated_args[0]
-    return get_args(union_type)
+    return get_args(union_type)
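build_guard now resolves each item's on_fail once and threads it through build() as a keyword argument. A hedged usage sketch — BanListSafetyValidatorConfig and build_guard come from this PR, the banned word is the placeholder the tests use, and Guard.validate is assumed from the guardrails library:

# Sketch only: wiring one config item into a Guard after this change.
from app.core.guardrail_controller import build_guard
from app.models.ban_list_safety_validator_config import BanListSafetyValidatorConfig

item = BanListSafetyValidatorConfig(
    type="ban_list",
    banned_words=["chakki"],  # placeholder word borrowed from the tests
    on_fail="rephrase",       # coerced to GuardrailOnFail.Rephrase
)

# Each validator receives the resolved action (an OnFailAction or a callable).
guard = build_guard([item])
outcome = guard.validate("This sentence contains chakki.")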
4 changes: 4 additions & 0 deletions backend/app/core/on_fail_actions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
+from guardrails.validators import FailResult
+
+def rephrase_query_on_fail(value: str, fail_result: FailResult):
+    return f"Please rephrase the query without unsafe content. {fail_result.error_message}"
4 changes: 2 additions & 2 deletions backend/app/models/ban_list_safety_validator_config.py
@@ -8,8 +8,8 @@ class BanListSafetyValidatorConfig(BaseValidatorConfig):
     type: Literal["ban_list"]
     banned_words: List[str]  # list of banned words to be redacted
 
-    def build(self):
+    def build(self, *, on_fail):
         return BanList(
             banned_words=self.banned_words,
-            on_fail=self.resolve_on_fail(),
+            on_fail=on_fail,
         )
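The bare * makes on_fail keyword-only, so a stale positional call site fails loudly instead of silently binding the wrong argument; the same two-line change repeats in the gender-assumption, lexical-slur, and PII configs below. A quick sketch of the contract:

from app.models.ban_list_safety_validator_config import BanListSafetyValidatorConfig

cfg = BanListSafetyValidatorConfig(type="ban_list", banned_words=["chakki"])

validator = cfg.build(on_fail=cfg.resolve_on_fail())  # what build_guard now does
# cfg.build(cfg.resolve_on_fail())  # would raise TypeError: on_fail is keyword-only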
30 changes: 15 additions & 15 deletions backend/app/models/base_validator_config.py
@@ -1,28 +1,28 @@
-from typing import Any, Literal, Optional
+from typing import Any, Optional
 
 from guardrails import OnFailAction
+from guardrails.validators import Validator
 from sqlmodel import SQLModel
 
-ON_FAIL_STR = Literal["exception", "fix", "noop", "reask"]
-
+from app.core.enum import GuardrailOnFail
+from app.core.on_fail_actions import rephrase_query_on_fail
+
+
+_ON_FAIL_MAP = {
+    GuardrailOnFail.Fix: OnFailAction.FIX,
+    GuardrailOnFail.Exception: OnFailAction.EXCEPTION,
+    GuardrailOnFail.Rephrase: rephrase_query_on_fail,
+}
+
 class BaseValidatorConfig(SQLModel):
-    on_fail: Optional[ON_FAIL_STR] = OnFailAction.FIX
+    on_fail: GuardrailOnFail = GuardrailOnFail.Fix
 
     model_config = {"arbitrary_types_allowed": True}
 
     def resolve_on_fail(self):
-        if self.on_fail is None:
-            return None
-
-        try:
-            return OnFailAction[self.on_fail.upper()]
-        except KeyError:
-            raise ValueError(
-                f"Invalid on_fail value: {self.on_fail}. "
-                "Expected one of: exception, fix, noop, reask"
-            )
+        return _ON_FAIL_MAP[self.on_fail]
 
-    def build(self) -> Any:
+    def build(self, *, on_fail) -> Validator:
         raise NotImplementedError(
             f"{self.__class__.__name__} must implement build()"
         )
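resolve_on_fail() collapses from string parsing plus error handling to a dict lookup, and Rephrase now resolves to a callable rather than an OnFailAction member; invalid strings (including the dropped "noop" and "reask") are rejected earlier, at Pydantic parse time. A hedged sketch of the mapping, using this PR's classes:

from guardrails import OnFailAction

from app.core.enum import GuardrailOnFail
from app.core.on_fail_actions import rephrase_query_on_fail
from app.models.ban_list_safety_validator_config import BanListSafetyValidatorConfig

cfg = BanListSafetyValidatorConfig(type="ban_list", banned_words=[], on_fail="fix")
assert cfg.resolve_on_fail() is OnFailAction.FIX

cfg.on_fail = GuardrailOnFail.Rephrase  # direct assignment, for illustration only
assert cfg.resolve_on_fail() is rephrase_query_on_fail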
4 changes: 2 additions & 2 deletions backend/app/models/gender_assumption_bias_safety_validator_config.py
@@ -8,8 +8,8 @@ class GenderAssumptionBiasSafetyValidatorConfig(BaseValidatorConfig):
     type: Literal["gender_assumption_bias"]
     categories: Optional[List[BiasCategories]] = [BiasCategories.All]  # preferred categories (based on sector)
 
-    def build(self):
+    def build(self, *, on_fail):
         return GenderAssumptionBias(
             categories=self.categories,
-            on_fail=self.resolve_on_fail(),
+            on_fail=on_fail,
         )
4 changes: 2 additions & 2 deletions backend/app/models/lexical_slur_safety_validator_config.py
@@ -9,9 +9,9 @@ class LexicalSlurSafetyValidatorConfig(BaseValidatorConfig):
     languages: List[str] = ["en", "hi"]  # list of languages to check slurs in
     severity: Literal["low", "medium", "high", "all"] = "all"  # severity level of slurs to check
 
-    def build(self):
+    def build(self, *, on_fail):
         return LexicalSlur(
             languages=self.languages,
             severity=SlurSeverity(self.severity),
-            on_fail=self.resolve_on_fail(),
+            on_fail=on_fail,
         )
4 changes: 2 additions & 2 deletions backend/app/models/pii_remover_safety_validator_config.py
@@ -9,9 +9,9 @@ class PIIRemoverSafetyValidatorConfig(BaseValidatorConfig):
     entity_types: Optional[List[str]] = None  # list of PII entity types to remove
     threshold: float = 0.5  # confidence threshold for PII detection
 
-    def build(self):
+    def build(self, *, on_fail):
         return PIIRemover(
             entity_types=self.entity_types,
             threshold=self.threshold,
-            on_fail=self.resolve_on_fail(),
+            on_fail=on_fail,
         )
24 changes: 23 additions & 1 deletion backend/app/tests/test_guardrails_api_integration.py
@@ -169,4 +169,26 @@ def test_input_guardrails_with_validator_actions_exception(integration_client):
 
     body = response.json()
     assert body["success"] is False
-    assert "chakki" in body["error"]
+    assert "chakki" in body["error"]
+
+
+def test_input_guardrails_with_validator_actions_rephrase(integration_client):
+    response = integration_client.post(
+        "/api/v1/guardrails/input/",
+        json={
+            "request_id": request_id,
+            "input": "This sentence contains chakki.",
+            "validators": [
+                {
+                    "type": "uli_slur_match",
+                    "severity": "all",
+                    "on_fail": "rephrase",
+                }
+            ],
+        },
+    )
+
+    assert response.status_code == 200
+    body = response.json()
+    assert body["success"] is True
+    assert "Please rephrase the query without unsafe content. Mentioned toxic words" in body["data"]["safe_input"]