-
Notifications
You must be signed in to change notification settings - Fork 478
feat(api): AI services backend — refine prompt tool call #3687
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
0b6ddcb
0f04e37
e88293d
be41428
454a876
7ce78e4
601bbee
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1 @@ | ||
|
|
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,17 @@ | ||
| from oss.src.core.ai_services.dtos import ( | ||
| AIServicesStatus, | ||
| ToolCallRequest, | ||
| ToolCallResponse, | ||
| ) | ||
|
|
||
|
|
||
# API-layer response model for the status endpoint. It adds nothing to the
# core ``AIServicesStatus`` DTO; the subclass only gives the HTTP layer its
# own named type.
class AIServicesStatusResponse(AIServicesStatus):
    ...
|
|
||
|
|
||
# API-layer request body for a tool call. Inherits every field from the core
# ``ToolCallRequest`` DTO without modification.
class ToolCallRequestModel(ToolCallRequest):
    ...
|
|
||
|
|
||
# API-layer response body for a tool call. Inherits every field from the core
# ``ToolCallResponse`` DTO without modification.
class ToolCallResponseModel(ToolCallResponse):
    ...
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,123 @@ | ||
| from __future__ import annotations | ||
|
|
||
| from fastapi import APIRouter, HTTPException, Request, status | ||
| from pydantic import ValidationError | ||
|
|
||
| from oss.src.utils.common import is_ee | ||
| from oss.src.utils.exceptions import intercept_exceptions | ||
| from oss.src.utils.throttling import check_throttle | ||
|
|
||
| from oss.src.core.ai_services.dtos import TOOL_REFINE_PROMPT | ||
| from oss.src.core.ai_services.service import AIServicesService | ||
| from oss.src.apis.fastapi.ai_services.models import ( | ||
| AIServicesStatusResponse, | ||
| ToolCallRequestModel, | ||
| ToolCallResponseModel, | ||
| ) | ||
|
|
||
|
|
||
| if is_ee(): | ||
| from ee.src.models.shared_models import Permission | ||
| from ee.src.utils.permissions import check_action_access, FORBIDDEN_EXCEPTION | ||
|
|
||
|
|
||
| _RATE_LIMIT_BURST = 10 | ||
| _RATE_LIMIT_PER_MIN = 30 | ||
|
|
||
|
|
||
class AIServicesRouter:
    """HTTP surface for the AI services feature.

    Wraps an ``AIServicesService`` and registers two routes on ``self.router``:

    * ``GET /status``      -> :meth:`get_status`
    * ``POST /tools/call`` -> :meth:`call_tool`
    """

    def __init__(
        self,
        *,
        ai_services_service: AIServicesService,
    ):
        """Store the backing service and register the endpoints.

        Args:
            ai_services_service: Service used to report status and run tools.
        """
        self.service = ai_services_service
        self.router = APIRouter()

        self.router.add_api_route(
            "/status",
            self.get_status,
            methods=["GET"],
            operation_id="ai_services_status",
            status_code=status.HTTP_200_OK,
            response_model=AIServicesStatusResponse,
            response_model_exclude_none=True,
        )

        self.router.add_api_route(
            "/tools/call",
            self.call_tool,
            methods=["POST"],
            operation_id="ai_services_tools_call",
            status_code=status.HTTP_200_OK,
            response_model=ToolCallResponseModel,
            response_model_exclude_none=True,
        )

    @staticmethod
    def _retry_after_header(retry_after_seconds) -> str:
        """Convert a throttle wait time into a ``Retry-After`` header value.

        The header carries whole seconds, so the wait is rounded *up*:
        truncating would tell clients to retry before the bucket has
        refilled, and any wait below one second would be advertised as ``0``.
        A missing or zero wait falls back to one second.
        """
        import math  # local import: only needed on the throttled path

        return str(max(1, math.ceil(retry_after_seconds or 0)))

    @intercept_exceptions()
    async def get_status(self, request: Request) -> AIServicesStatusResponse:
        """Return the AI services status for the calling user.

        Outside EE, ``allow_tools`` is always ``True``. On EE it is the result
        of the ``EDIT_WORKFLOWS`` permission check for the request's user and
        project. The value is forwarded to ``self.service.status``.
        """
        allow_tools = True

        if is_ee():
            # NOTE(review): a reviewer asked whether all tools are meant to
            # mutate (comment truncated in the captured page) -- i.e. whether
            # EDIT_WORKFLOWS is the right permission for every tool.
            allow_tools = await check_action_access(  # type: ignore
                user_uid=request.state.user_id,
                project_id=request.state.project_id,
                permission=Permission.EDIT_WORKFLOWS,  # type: ignore
            )

        return self.service.status(allow_tools=allow_tools)

    @intercept_exceptions()
    async def call_tool(
        self,
        request: Request,
        *,
        tool_call: ToolCallRequestModel,
    ) -> ToolCallResponseModel:
        """Run a single tool on behalf of the caller.

        Checks, in order: service enabled, EE permission, rate limit, known
        tool name. Only then is the call delegated to the service.

        Raises:
            HTTPException: 503 when AI services are disabled; 429 (with a
                ``Retry-After`` header) when throttled; 400 for an unknown
                tool or arguments the service rejects.
            FORBIDDEN_EXCEPTION: on EE, when the caller lacks
                ``EDIT_WORKFLOWS`` on the project.
        """
        if not self.service.enabled:
            raise HTTPException(status_code=503, detail="AI services are disabled")

        if is_ee():
            if not await check_action_access(  # type: ignore
                user_uid=request.state.user_id,
                project_id=request.state.project_id,
                permission=Permission.EDIT_WORKFLOWS,  # type: ignore
            ):
                raise FORBIDDEN_EXCEPTION  # type: ignore

        # Router-level rate limit
        # NOTE(review): a reviewer left a question on this rate limit; its
        # text is truncated in the captured page.
        key = {
            "ep": "ai_services",
            "tool": tool_call.name,
            # getattr: these attributes may be absent from request.state.
            "org": getattr(request.state, "organization_id", None),
            "user": getattr(request.state, "user_id", None),
        }
        result = await check_throttle(
            key,
            max_capacity=_RATE_LIMIT_BURST,
            refill_rate=_RATE_LIMIT_PER_MIN,
        )
        if not result.allow:
            raise HTTPException(
                status_code=429,
                detail="Rate limit exceeded",
                headers={
                    "Retry-After": self._retry_after_header(
                        result.retry_after_seconds
                    )
                },
            )

        # Tool routing + strict request validation
        # NOTE(review): eventually this might be pushed down to the
        # dispatcher, which would raise a domain-level exception that is
        # caught here and turned into an HTTP exception.
        if tool_call.name != TOOL_REFINE_PROMPT:
            raise HTTPException(status_code=400, detail="Unknown tool")

        try:
            return await self.service.call_tool(
                name=tool_call.name,
                arguments=tool_call.arguments,
            )
        except ValidationError as e:
            # Must precede ValueError: pydantic's ValidationError subclasses it.
            raise HTTPException(status_code=400, detail=e.errors()) from e
        except ValueError as e:
            # Unknown tool or invalid argument shape
            raise HTTPException(status_code=400, detail=str(e)) from e
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1 @@ | ||
|
|
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,101 @@ | ||
| from __future__ import annotations | ||
|
|
||
| from typing import Any, Dict, Optional, Tuple | ||
|
|
||
| import httpx | ||
|
|
||
| from oss.src.utils.logging import get_module_logger | ||
|
|
||
|
|
||
| log = get_module_logger(__name__) | ||
|
|
||
|
|
||
class AgentaAIServicesClient:
    """Thin HTTP client to call Agenta Cloud workflow invocation APIs."""

    def __init__(
        self,
        *,
        api_url: str,
        api_key: str,
        timeout_s: float = 20.0,
    ):
        """Configure the client.

        Args:
            api_url: Base URL of the upstream API; trailing slashes are removed.
            api_key: Key sent in the ``Authorization: ApiKey ...`` header.
            timeout_s: Timeout, in seconds, handed to ``httpx.AsyncClient``.
        """
        self.api_url = api_url.rstrip("/")
        self.api_key = api_key
        self.timeout_s = timeout_s

    async def invoke_deployed_prompt(
        self,
        *,
        application_slug: str,
        environment_slug: str,
        inputs: Dict[str, Any],
    ) -> Tuple[Optional[Any], Optional[str]]:
        """Invoke a deployed prompt by app/environment slug.

        NOTE: This targets the cloud completion runner endpoint.

        This method does not raise for upstream failures. They are reported
        in-band as a dict ``{"_error": True, "status_code": ..., "detail": ...}``
        in the first tuple slot:

        * non-2xx response -> upstream status code, parsed body as ``detail``
        * timeout          -> ``504`` / ``"Upstream timeout"``
        * any other error  -> ``502`` / ``"Upstream error"``

        Returns:
            ``(raw_response, trace_id)``. ``raw_response`` is the decoded JSON
            body, or ``None`` when the body is not valid JSON. ``trace_id`` is
            read from ``trace_id`` / ``traceId`` when the body is a dict, and
            is ``None`` otherwise and on every error path.
        """
        # NOTE(review): a reviewer commented on this URL ("This turns into
        # ..."); the rest of the comment is truncated in the captured page.
        url = f"{self.api_url}/services/completion/run"

        payload: Dict[str, Any] = {
            "inputs": inputs,
            "environment": environment_slug,
            "app": application_slug,
        }

        headers = {
            "Authorization": f"ApiKey {self.api_key}",
            "Content-Type": "application/json",
            "Accept": "application/json",
        }

        try:
            async with httpx.AsyncClient(timeout=self.timeout_s) as client:
                res = await client.post(url, json=payload, headers=headers)

                # Non-2xx responses still carry useful error payloads
                data: Any = None
                try:
                    data = res.json()
                except Exception:  # pylint: disable=broad-exception-caught
                    # Best effort: a non-JSON body is reported as ``None``.
                    data = None

                if res.status_code < 200 or res.status_code >= 300:
                    log.warning(
                        "[ai-services] Upstream invoke failed",
                        status_code=res.status_code,
                        url=url,
                    )
                    # Surface as tool execution error (caller maps to isError)
                    # NOTE(review): consider domain-level exceptions defined
                    # via Pydantic models and raised here, as done elsewhere
                    # in the codebase.
                    return {
                        "_error": True,
                        "status_code": res.status_code,
                        "detail": data,
                    }, None

                trace_id = None
                if isinstance(data, dict):
                    trace_id = data.get("trace_id") or data.get("traceId")

                # NOTE(review): same suggestion for the returned data --
                # return it via DTOs.
                return data, trace_id

        except httpx.TimeoutException:
            log.warning("[ai-services] Upstream invoke timed out", url=url)
            return {
                "_error": True,
                "status_code": 504,
                "detail": "Upstream timeout",
            }, None

        except Exception as e:  # pylint: disable=broad-exception-caught
            log.warning("[ai-services] Upstream invoke error", url=url, error=str(e))
            return {
                "_error": True,
                "status_code": 502,
                "detail": "Upstream error",
            }, None
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,49 @@ | ||
| from typing import Any, Dict, List, Literal, Optional | ||
|
|
||
|
|
||
| from pydantic import BaseModel, ConfigDict, Field | ||
|
|
||
|
|
||
# Name of the refine-prompt tool. The router compares the incoming
# ``tool_call.name`` against this value and rejects anything else.
TOOL_REFINE_PROMPT = "tools.agenta.api.refine_prompt"
|
|
||
|
|
||
# Description of a single tool as listed in ``AIServicesStatus.tools``.
class ToolDefinition(BaseModel):
    # Identifier and human-readable labels.
    name: str
    title: str
    description: str
    # Schemas for the tool's arguments and result.
    # NOTE(review): presumably JSON Schema documents -- confirm with producer.
    inputSchema: Dict[str, Any]
    outputSchema: Dict[str, Any]
|
|
||
|
|
||
# Status payload: the ``enabled`` flag plus the list of tool definitions.
class AIServicesStatus(BaseModel):
    enabled: bool
    # Defaults to an empty list when no tools are supplied.
    tools: List[ToolDefinition] = Field(default_factory=list)
|
|
||
|
|
||
# Request to run one tool: its name plus free-form arguments.
class ToolCallRequest(BaseModel):
    name: str
    # Untyped at this level; defaults to an empty dict when omitted.
    arguments: Dict[str, Any] = Field(default_factory=dict)
|
|
||
|
|
||
# One text item in a tool-call response; ``type`` is fixed to "text".
class ToolCallTextContent(BaseModel):
    type: Literal["text"] = "text"
    text: str
|
|
||
|
|
||
# Optional metadata attached to a tool-call response.
class ToolCallMeta(BaseModel):
    # Trace identifier, when one is available.
    trace_id: Optional[str] = None
|
|
||
|
|
||
# Result of a tool call.
class ToolCallResponse(BaseModel):
    # Text content items; empty by default.
    content: List[ToolCallTextContent] = Field(default_factory=list)
    # Optional structured form of the result.
    structuredContent: Optional[Dict[str, Any]] = None
    # False by default; the field name indicates an error flag for the call.
    isError: bool = False
    meta: Optional[ToolCallMeta] = None
|
|
||
|
|
||
# Arguments model for the refine-prompt tool (per its name).
class RefinePromptArguments(BaseModel):
    # Reject any argument key that is not declared below.
    model_config = ConfigDict(extra="forbid")

    # Required, non-empty, capped at 100k characters.
    # NOTE(review): by its name this is a JSON-encoded prompt template;
    # nothing here validates that it parses as JSON.
    prompt_template_json: str = Field(min_length=1, max_length=100_000)
    # Optional free-text inputs, each capped at 10k characters.
    guidelines: Optional[str] = Field(default=None, max_length=10_000)
    context: Optional[str] = Field(default=None, max_length=10_000)
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this meant to become an env var or a feature flag?