diff --git a/src/app/endpoints/query_v2.py b/src/app/endpoints/query_v2.py
index e6b14522f..1e5abc45b 100644
--- a/src/app/endpoints/query_v2.py
+++ b/src/app/endpoints/query_v2.py
@@ -8,7 +8,15 @@
 from fastapi import APIRouter, Depends, Request
 from llama_stack.apis.agents.openai_responses import (
+    OpenAIResponseMCPApprovalRequest,
+    OpenAIResponseMCPApprovalResponse,
     OpenAIResponseObject,
+    OpenAIResponseOutput,
+    OpenAIResponseOutputMessageFileSearchToolCall,
+    OpenAIResponseOutputMessageFunctionToolCall,
+    OpenAIResponseOutputMessageMCPCall,
+    OpenAIResponseOutputMessageMCPListTools,
+    OpenAIResponseOutputMessageWebSearchToolCall,
 )
 from llama_stack_client import AsyncLlamaStackClient
@@ -41,6 +49,7 @@
     get_topic_summary_system_prompt,
 )
 from utils.mcp_headers import mcp_headers_dependency
+from utils.query import parse_arguments_string
 from utils.responses import extract_text_from_response_output_item
 from utils.shields import (
     append_turn_to_conversation,
@@ -73,153 +82,160 @@ def _build_tool_call_summary(  # pylint: disable=too-many-return-statements,too-many-branches
-    output_item: Any,
+    output_item: OpenAIResponseOutput,
 ) -> tuple[Optional[ToolCallSummary], Optional[ToolResultSummary]]:
-    """Translate applicable Responses API tool outputs into ``ToolCallSummary`` records.
+    """Translate Responses API tool outputs into ToolCallSummary and ToolResultSummary records.
 
-    The OpenAI ``response.output`` array may contain any ``OpenAIResponseOutput`` variant:
-    ``message``, ``function_call``, ``file_search_call``, ``web_search_call``, ``mcp_call``,
-    ``mcp_list_tools``, or ``mcp_approval_request``. The OpenAI Spec supports more types
-    but as llamastack does not support them, yet they are not considered here.
+    Processes OpenAI response output items and extracts tool call and result information.
+
+    Args:
+        output_item: An OpenAIResponseOutput item from the response.output array
+
+    Returns:
+        A tuple of (ToolCallSummary, ToolResultSummary); either element may be None
+        when the current Llama Stack Responses API does not provide that information.
+
+    Supported tool types:
+    - function_call: Function tool calls with parsed arguments (no result)
+    - file_search_call: File search operations with results
+    - web_search_call: Web search operations (incomplete)
+    - mcp_call: MCP calls with server labels
+    - mcp_list_tools: MCP server tool listings
+    - mcp_approval_request: MCP approval requests (no result)
+    - mcp_approval_response: MCP approval responses (no call)
     """
     item_type = getattr(output_item, "type", None)
 
     if item_type == "function_call":
-        parsed_arguments = getattr(output_item, "arguments", "")
-        if isinstance(parsed_arguments, dict):
-            args = parsed_arguments
-        else:
-            args = {"arguments": parsed_arguments}
-
-        call_id = getattr(output_item, "id", None) or getattr(
-            output_item, "call_id", None
-        )
+        item = cast(OpenAIResponseOutputMessageFunctionToolCall, output_item)
         return (
             ToolCallSummary(
-                id=str(call_id),
-                name=getattr(output_item, "name", "function_call"),
-                args=args,
+                id=item.call_id,
+                name=item.name,
+                args=parse_arguments_string(item.arguments),
                 type="function_call",
             ),
-            None,
+            None,  # not supported by Responses API at all
         )
 
     if item_type == "file_search_call":
-        args = {
-            "queries": list(getattr(output_item, "queries", [])),
-            "status": getattr(output_item, "status", None),
-        }
-        results = getattr(output_item, "results", None)
-        response_payload: Optional[Any] = None
-        if results is not None:
-            # Store only the essential result metadata to avoid large payloads
+        item = cast(OpenAIResponseOutputMessageFileSearchToolCall, output_item)
+        response_payload: Optional[dict[str, Any]] = None
+        if item.results is not None:
             response_payload = {
-                "results": [
-                    {
-                        "file_id": (
-                            getattr(result, "file_id", None)
-                            if not isinstance(result, dict)
-                            else result.get("file_id")
-                        ),
-                        "filename": (
-                            getattr(result, "filename", None)
-                            if not isinstance(result, dict)
-                            else result.get("filename")
-                        ),
-                        "score": (
-                            getattr(result, "score", None)
-                            if not isinstance(result, dict)
-                            else result.get("score")
-                        ),
-                    }
-                    for result in results
-                ]
+                "results": [result.model_dump() for result in item.results]
             }
         return ToolCallSummary(
-            id=str(getattr(output_item, "id")),
+            id=item.id,
             name=DEFAULT_RAG_TOOL,
-            args=args,
+            args={"queries": item.queries},
             type="file_search_call",
         ), ToolResultSummary(
-            id=str(getattr(output_item, "id")),
-            status=str(getattr(output_item, "status", None)),
-            content=json.dumps(response_payload) if response_payload else None,
+            id=item.id,
+            status=item.status,
+            content=json.dumps(response_payload) if response_payload else "",
             type="file_search_call",
             round=1,
         )
 
+    # Incomplete OpenAI Responses API definition in LLS: action attribute not supported yet
     if item_type == "web_search_call":
-        args = {"status": getattr(output_item, "status", None)}
+        item = cast(OpenAIResponseOutputMessageWebSearchToolCall, output_item)
         return (
             ToolCallSummary(
-                id=str(getattr(output_item, "id")),
+                id=item.id,
                 name="web_search",
-                args=args,
+                args={},
                 type="web_search_call",
             ),
-            None,
+            ToolResultSummary(
+                id=item.id,
+                status=item.status,
+                content="",
+                type="web_search_call",
+                round=1,
+            ),
         )
 
     if item_type == "mcp_call":
-        parsed_arguments = getattr(output_item, "arguments", "")
-        args = {"arguments": parsed_arguments}
-        server_label = getattr(output_item, "server_label", None)
-        if server_label:
-            args["server_label"] = server_label
-        error = getattr(output_item, "error", None)
-        if error:
-            args["error"] = error
+        item = cast(OpenAIResponseOutputMessageMCPCall, output_item)
+        args = parse_arguments_string(item.arguments)
+        if item.server_label:
+            args["server_label"] = item.server_label
+        content = item.error if item.error else (item.output if item.output else "")
         return ToolCallSummary(
-            id=str(getattr(output_item, "id")),
-            name=getattr(output_item, "name", "mcp_call"),
+            id=item.id,
+            name=item.name,
             args=args,
             type="mcp_call",
         ), ToolResultSummary(
-            id=str(getattr(output_item, "id")),
-            status=str(getattr(output_item, "status", None)),
-            content=getattr(output_item, "output", ""),
+            id=item.id,
+            status="success" if item.error is None else "failure",
+            content=content,
            type="mcp_call",
            round=1,
        )
 
     if item_type == "mcp_list_tools":
-        tool_names: list[str] = []
-        for tool in getattr(output_item, "tools", []):
-            if hasattr(tool, "name"):
-                tool_names.append(str(getattr(tool, "name")))
-            elif isinstance(tool, dict) and tool.get("name"):
-                tool_names.append(str(tool.get("name")))
-        args = {
-            "server_label": getattr(output_item, "server_label", None),
-            "tools": tool_names,
+        item = cast(OpenAIResponseOutputMessageMCPListTools, output_item)
+        tools_info = [
+            {
+                "name": tool.name,
+                "description": tool.description,
+                "input_schema": tool.input_schema,
+            }
+            for tool in item.tools
+        ]
+        content_dict = {
+            "server_label": item.server_label,
+            "tools": tools_info,
         }
         return (
             ToolCallSummary(
-                id=str(getattr(output_item, "id")),
+                id=item.id,
                 name="mcp_list_tools",
-                args=args,
+                args={"server_label": item.server_label},
                 type="mcp_list_tools",
             ),
-            None,
+            ToolResultSummary(
+                id=item.id,
+                status="success",
+                content=json.dumps(content_dict),
+                type="mcp_list_tools",
+                round=1,
+            ),
         )
 
     if item_type == "mcp_approval_request":
-        parsed_arguments = getattr(output_item, "arguments", "")
-        args = {"arguments": parsed_arguments}
-        server_label = getattr(output_item, "server_label", None)
-        if server_label:
-            args["server_label"] = server_label
+        item = cast(OpenAIResponseMCPApprovalRequest, output_item)
+        args = parse_arguments_string(item.arguments)
         return (
             ToolCallSummary(
-                id=str(getattr(output_item, "id")),
-                name=getattr(output_item, "name", "mcp_approval_request"),
+                id=item.id,
+                name=item.name,
                 args=args,
                 type="tool_call",
             ),
             None,
         )
 
+    if item_type == "mcp_approval_response":
+        item = cast(OpenAIResponseMCPApprovalResponse, output_item)
+        content_dict = {}
+        if item.reason:
+            content_dict["reason"] = item.reason
+        return (
+            None,
+            ToolResultSummary(
+                id=item.approval_request_id,
+                status="success" if item.approve else "denied",
+                content=json.dumps(content_dict),
+                type="mcp_approval_response",
+                round=1,
+            ),
+        )
+
     return None, None
diff --git a/src/app/endpoints/streaming_query.py b/src/app/endpoints/streaming_query.py
index 9bc5b7a75..d81534328 100644
--- a/src/app/endpoints/streaming_query.py
+++ b/src/app/endpoints/streaming_query.py
@@ -10,7 +10,6 @@
 from typing import (
     Annotated,
     Any,
-    AsyncGenerator,
     AsyncIterator,
     Iterator,
     Optional,
@@ -369,7 +368,7 @@ def generic_llm_error(error: Exception, media_type: str) -> str:
     )
 
 
-async def stream_http_error(error: AbstractErrorResponse) -> AsyncGenerator[str, None]:
+def stream_http_error(error: AbstractErrorResponse) -> Iterator[str]:
     """
     Yield an SSE-formatted error response for generic LLM or API errors.
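
# --- Illustrative sketch (not part of the patch) ------------------------------
# Shows the argument handling that _build_tool_call_summary above now delegates
# to parse_arguments_string (added in src/utils/query.py later in this diff).
# The helper is repeated here so the snippet runs standalone; the sample
# argument strings and the printed expectations are only illustrative.
import json
from typing import Any


def parse_arguments_string(arguments_str: str) -> dict[str, Any]:
    """Parse a tool-call arguments string, mirroring src/utils/query.py."""
    # Strategy 1: the string is already a JSON object.
    try:
        parsed = json.loads(arguments_str)
        if isinstance(parsed, dict):
            return parsed
    except (json.JSONDecodeError, ValueError):
        pass
    # Strategy 2: the braces were omitted; wrap the content and retry.
    stripped = arguments_str.strip()
    if stripped and not stripped.startswith("{"):
        try:
            parsed = json.loads("{" + stripped + "}")
            if isinstance(parsed, dict):
                return parsed
        except (json.JSONDecodeError, ValueError):
            pass
    # Strategy 3: give up and keep the raw string under the "args" key.
    return {"args": arguments_str}


print(parse_arguments_string('{"operation": "sum"}'))  # {'operation': 'sum'}
print(parse_arguments_string('"query": "list pods"'))  # {'query': 'list pods'}
print(parse_arguments_string("plain text"))            # {'args': 'plain text'}
# -------------------------------------------------------------------------------
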
diff --git a/src/app/endpoints/streaming_query_v2.py b/src/app/endpoints/streaming_query_v2.py
index 847bcacbf..c947b208f 100644
--- a/src/app/endpoints/streaming_query_v2.py
+++ b/src/app/endpoints/streaming_query_v2.py
@@ -6,14 +6,13 @@
 from fastapi import APIRouter, Depends, Request
 from fastapi.responses import StreamingResponse
 from llama_stack.apis.agents.openai_responses import (
-    OpenAIResponseContentPartOutputText,
-    OpenAIResponseMessage,
     OpenAIResponseObject,
     OpenAIResponseObjectStream,
     OpenAIResponseObjectStreamResponseCompleted,
-    OpenAIResponseObjectStreamResponseContentPartAdded,
+    OpenAIResponseObjectStreamResponseFailed,
+    OpenAIResponseObjectStreamResponseOutputItemDone,
     OpenAIResponseObjectStreamResponseOutputTextDelta,
-    OpenAIResponseOutputMessageContentOutputText,
+    OpenAIResponseObjectStreamResponseOutputTextDone,
 )
 from llama_stack_client import AsyncLlamaStackClient
@@ -23,14 +22,19 @@
     validate_attachments_metadata,
 )
 from app.endpoints.query_v2 import (
+    _build_tool_call_summary,
     extract_token_usage_from_responses_api,
     get_topic_summary,
     parse_referenced_documents_from_responses_api,
     prepare_tools_for_responses_api,
 )
 from app.endpoints.streaming_query import (
+    LLM_TOKEN_EVENT,
+    LLM_TOOL_CALL_EVENT,
+    LLM_TOOL_RESULT_EVENT,
     format_stream_data,
     stream_end_event,
+    stream_event,
     stream_start_event,
     streaming_query_endpoint_handler_base,
 )
@@ -56,6 +60,7 @@
     cleanup_after_streaming,
     get_system_prompt,
 )
+from utils.query import create_violation_stream
 from utils.quota import consume_tokens, get_available_quotas
 from utils.suid import normalize_conversation_id, to_llama_stack_conversation_id
 from utils.mcp_headers import mcp_headers_dependency
@@ -65,7 +70,7 @@
 )
 from utils.token_counter import TokenCounter
 from utils.transcripts import store_transcript
-from utils.types import ToolCallSummary, TurnSummary
+from utils.types import TurnSummary
 
 logger = logging.getLogger("app.endpoints.handlers")
 router = APIRouter(tags=["streaming_query_v1"])
@@ -130,12 +135,10 @@ async def response_generator(  # pylint: disable=too-many-branches,too-many-stat
     # Accumulators for Responses API
     text_parts: list[str] = []
-    tool_item_registry: dict[str, dict[str, str]] = {}
     emitted_turn_complete = False
 
     # Use the conversation_id from context (either provided or newly created)
     conv_id = context.conversation_id
-    start_event_emitted = False
 
     # Track the latest response object from response.completed event
     latest_response_object: Optional[Any] = None
@@ -146,121 +149,122 @@ async def response_generator(  # pylint: disable=too-many-branches,too-many-stat
         event_type = getattr(chunk, "type", None)
         logger.debug("Processing chunk %d, type: %s", chunk_id, event_type)
 
-        # Emit start event on first chunk (conversation_id is always set at this point)
-        if not start_event_emitted:
-            yield stream_start_event(conv_id)
-            start_event_emitted = True
-
-        # Handle response.created event (just skip, no need to extract conversation_id)
+        # Emit start event when response is created
         if event_type == "response.created":
-            continue
+            yield stream_start_event(conv_id)
 
         # Text streaming
         if event_type == "response.output_text.delta":
-            delta = getattr(chunk, "delta", "")
-            if delta:
-                text_parts.append(delta)
-                yield format_stream_data(
+            delta_chunk = cast(
+                OpenAIResponseObjectStreamResponseOutputTextDelta, chunk
+            )
+            if delta_chunk.delta:
+                text_parts.append(delta_chunk.delta)
+                yield stream_event(
                     {
-                        "event": "token",
-                        "data": {
-                            "id": chunk_id,
-                            "token": delta,
-                        },
-                    }
+                        "id": chunk_id,
+                        "token": delta_chunk.delta,
+                    },
+                    LLM_TOKEN_EVENT,
+                    media_type,
                 )
                 chunk_id += 1
 
         # Final text of the output (capture, but emit at response.completed)
         elif event_type == "response.output_text.done":
-            final_text = getattr(chunk, "text", "")
-            if final_text:
-                summary.llm_response = final_text
+            done_chunk = cast(
+                OpenAIResponseObjectStreamResponseOutputTextDone, chunk
+            )
+            if done_chunk.text:
+                summary.llm_response = done_chunk.text
 
-        # Content part started - emit an empty token to kick off UI streaming if desired
+        # Content part started - emit an empty token to kick off UI streaming
         elif event_type == "response.content_part.added":
-            yield format_stream_data(
+            yield stream_event(
                 {
-                    "event": "token",
-                    "data": {
-                        "id": chunk_id,
-                        "token": "",
-                    },
-                }
+                    "id": chunk_id,
+                    "token": "",
+                },
+                LLM_TOKEN_EVENT,
+                media_type,
             )
             chunk_id += 1
 
-        # Track tool call items as they are added so we can build a summary later
-        elif event_type == "response.output_item.added":
-            item = getattr(chunk, "item", None)
-            item_type = getattr(item, "type", None)
-            if item and item_type == "function_call":
-                item_id = getattr(item, "id", "")
-                name = getattr(item, "name", "function_call")
-                call_id = getattr(item, "call_id", item_id)
-                if item_id:
-                    tool_item_registry[item_id] = {
-                        "name": name,
-                        "call_id": call_id,
-                    }
-
-        # Stream tool call arguments as tool_call events
-        elif event_type == "response.function_call_arguments.delta":
-            delta = getattr(chunk, "delta", "")
-            yield format_stream_data(
-                {
-                    "event": "tool_call",
-                    "data": {
-                        "id": chunk_id,
-                        "role": "tool_execution",
-                        "token": delta,
-                    },
-                }
+        # Tool calls and results are emitted together when an output item is done
+        # TODO(asimurka): support emitting tool calls and results separately when ready
+        elif event_type == "response.output_item.done":
+            done_chunk = cast(
+                OpenAIResponseObjectStreamResponseOutputItemDone, chunk
            )
            chunk_id += 1
-
-        # Finalize tool call arguments and append to summary
-        elif event_type in (
-            "response.function_call_arguments.done",
-            "response.mcp_call.arguments.done",
-        ):
-            item_id = getattr(chunk, "item_id", "")
-            arguments = getattr(chunk, "arguments", "")
-            meta = tool_item_registry.get(item_id, {})
-            summary.tool_calls.append(
-                ToolCallSummary(
-                    id=meta.get("call_id", item_id or "unknown"),
-                    name=meta.get("name", "tool_call"),
-                    args=(
-                        arguments if isinstance(arguments, dict) else {}
-                    ),  # Handle non-dict arguments
-                    type="tool_call",
+            if done_chunk.item.type == "message":
+                continue
+            tool_call, tool_result = _build_tool_call_summary(done_chunk.item)
+            if tool_call:
+                summary.tool_calls.append(tool_call)
+                yield stream_event(
+                    tool_call.model_dump(),
+                    LLM_TOOL_CALL_EVENT,
+                    media_type,
+                )
+            if tool_result:
+                summary.tool_results.append(tool_result)
+                yield stream_event(
+                    tool_result.model_dump(),
+                    LLM_TOOL_RESULT_EVENT,
+                    media_type,
                 )
-            )
 
         # Completed response - capture final text and response object
         elif event_type == "response.completed":
             # Capture the response object for token usage extraction
-            latest_response_object = getattr(chunk, "response", None)
+            completed_chunk = cast(
+                OpenAIResponseObjectStreamResponseCompleted, chunk
+            )
+            latest_response_object = completed_chunk.response
 
             if not emitted_turn_complete:
                 final_message = summary.llm_response or "".join(text_parts)
                 if not final_message:
                     final_message = "No response from the model"
                 summary.llm_response = final_message
-                yield format_stream_data(
+                yield stream_event(
                     {
-                        "event": "turn_complete",
-                        "data": {
-                            "id": chunk_id,
-                            "token": final_message,
-                        },
-                    }
+                        "id": chunk_id,
+                        "token": final_message,
+                    },
+                    "turn_complete",
+                    media_type,
                 )
                 chunk_id += 1
                 emitted_turn_complete = True
 
-        # Ignore other event types for now; could add heartbeats if desired
+        # Incomplete response - emit a generic error because LLS does not yet
+        # support the incomplete response's "incomplete_detail" attribute
+        elif event_type == "response.incomplete":
+            error_response = InternalServerErrorResponse.query_failed(
+                "An unexpected error occurred while processing the request."
+            )
+            logger.error("Error while obtaining answer for user question")
+            yield format_stream_data(
+                {"event": "error", "data": {**error_response.detail.model_dump()}}
+            )
+            return
+
+        # Failed response - emit error with custom cause from error message
+        elif event_type == "response.failed":
+            failed_chunk = cast(OpenAIResponseObjectStreamResponseFailed, chunk)
+            latest_response_object = failed_chunk.response
+            error_message = (
+                failed_chunk.response.error.message
+                if failed_chunk.response.error
+                else "An unexpected error occurred while processing the request."
+            )
+            error_response = InternalServerErrorResponse.query_failed(error_message)
+            logger.error("Error while obtaining answer for user question")
+            yield format_stream_data(
+                {"event": "error", "data": {**error_response.detail.model_dump()}}
+            )
+            return
 
     logger.debug(
         "Streaming complete - Tool calls: %d, Response chars: %d",
@@ -467,54 +471,3 @@ async def retrieve_response(  # pylint: disable=too-many-locals
     response_stream = cast(AsyncIterator[OpenAIResponseObjectStream], response)
 
     return response_stream, normalize_conversation_id(conversation_id)
-
-
-async def create_violation_stream(
-    message: str,
-    shield_model: Optional[str] = None,
-) -> AsyncIterator[OpenAIResponseObjectStream]:
-    """Generate a minimal streaming response for cases where input is blocked by a shield.
-
-    This yields only the essential streaming events to indicate that the input was rejected.
-    Dummy item identifiers are used solely for protocol compliance and are not used later.
-    """
-    response_id = "resp_shield_violation"
-
-    # Content part added (triggers empty initial token)
-    yield OpenAIResponseObjectStreamResponseContentPartAdded(
-        content_index=0,
-        response_id=response_id,
-        item_id="msg_shield_violation_1",
-        output_index=0,
-        part=OpenAIResponseContentPartOutputText(text=""),
-        sequence_number=0,
-    )
-
-    # Text delta
-    yield OpenAIResponseObjectStreamResponseOutputTextDelta(
-        content_index=1,
-        delta=message,
-        item_id="msg_shield_violation_2",
-        output_index=1,
-        sequence_number=1,
-    )
-
-    # Completed response
-    yield OpenAIResponseObjectStreamResponseCompleted(
-        response=OpenAIResponseObject(
-            id=response_id,
-            created_at=0,  # not used
-            model=shield_model or "shield",
-            output=[
-                OpenAIResponseMessage(
-                    id="msg_shield_violation_3",
-                    content=[
-                        OpenAIResponseOutputMessageContentOutputText(text=message)
-                    ],
-                    role="assistant",
-                    status="completed",
-                )
-            ],
-            status="completed",
-        )
-    )
diff --git a/src/models/responses.py b/src/models/responses.py
index bf4da2698..08f1fb504 100644
--- a/src/models/responses.py
+++ b/src/models/responses.py
@@ -1885,21 +1885,20 @@ def feedback_path_invalid(cls, path: str) -> "InternalServerErrorResponse":
         )
 
     @classmethod
-    def query_failed(cls, backend_url: str) -> "InternalServerErrorResponse":
+    def query_failed(cls, cause: str) -> "InternalServerErrorResponse":
         """
-        Create an InternalServerErrorResponse representing a failed query to an external backend.
+        Create an InternalServerErrorResponse representing a failed query.
 
         Parameters:
-            backend_url (str): The backend URL included in the error cause message.
+            cause (str): The error cause message.
 
         Returns:
             InternalServerErrorResponse: An error response with response "Error
-            while processing query" and cause "Failed to call backend:
-            {backend_url}".
+            while processing query" and the provided cause.
         """
         return cls(
             response="Error while processing query",
-            cause=f"Failed to call backend: {backend_url}",
+            cause=cause,
         )
 
     @classmethod
diff --git a/src/utils/query.py b/src/utils/query.py
new file mode 100644
index 000000000..c1650c27d
--- /dev/null
+++ b/src/utils/query.py
@@ -0,0 +1,122 @@
+"""Utility functions for working with queries."""
+
+import json
+from typing import Any, AsyncIterator, Optional
+
+
+from llama_stack.apis.agents.openai_responses import (
+    OpenAIResponseContentPartOutputText,
+    OpenAIResponseObject,
+    OpenAIResponseObjectStream,
+    OpenAIResponseObjectStreamResponseCreated,
+    OpenAIResponseObjectStreamResponseContentPartAdded,
+    OpenAIResponseObjectStreamResponseOutputTextDelta,
+    OpenAIResponseObjectStreamResponseOutputTextDone,
+    OpenAIResponseMessage,
+    OpenAIResponseOutputMessageContentOutputText,
+    OpenAIResponseObjectStreamResponseCompleted,
+)
+
+
+def parse_arguments_string(arguments_str: str) -> dict[str, Any]:
+    """
+    Try to parse an arguments string into a dictionary.
+
+    Attempts multiple parsing strategies:
+    1. Try parsing the string as-is as JSON (if it's already valid JSON)
+    2. Try wrapping the string in {} if it doesn't start with {
+    3. Return {"args": arguments_str} if all attempts fail
+
+    Args:
+        arguments_str: The arguments string to parse
+
+    Returns:
+        Parsed dictionary if successful, otherwise {"args": arguments_str}
+    """
+    # Try parsing as-is first (most common case)
+    try:
+        parsed = json.loads(arguments_str)
+        if isinstance(parsed, dict):
+            return parsed
+    except (json.JSONDecodeError, ValueError):
+        pass
+
+    # Try wrapping in {} if string doesn't start with {
+    # This handles cases where the string is just the content without braces
+    stripped = arguments_str.strip()
+    if stripped and not stripped.startswith("{"):
+        try:
+            wrapped = "{" + stripped + "}"
+            parsed = json.loads(wrapped)
+            if isinstance(parsed, dict):
+                return parsed
+        except (json.JSONDecodeError, ValueError):
+            pass
+
+    # Fallback: return the raw string under the "args" key
+    return {"args": arguments_str}
+
+
+async def create_violation_stream(
+    message: str,
+    shield_model: Optional[str] = None,
+) -> AsyncIterator[OpenAIResponseObjectStream]:
+    """Generate a minimal streaming response for cases where input is blocked by a shield.
+
+    This yields only the essential streaming events to indicate that the input was rejected.
+    Dummy item identifiers are used solely for protocol compliance and are not used later.
+    """
+    response_id = "resp_shield_violation"
+
+    # Create the response object with empty output at the beginning
+    response_obj = OpenAIResponseObject(
+        id=response_id,
+        created_at=0,  # not used
+        model=shield_model or "shield",
+        output=[],
+        status="in_progress",
+    )
+    yield OpenAIResponseObjectStreamResponseCreated(response=response_obj)
+
+    # Triggers empty initial token
+    yield OpenAIResponseObjectStreamResponseContentPartAdded(
+        content_index=0,
+        response_id=response_id,
+        item_id="msg_shield_violation_1",
+        output_index=0,
+        part=OpenAIResponseContentPartOutputText(text=""),
+        sequence_number=0,
+    )
+
+    # Text delta
+    yield OpenAIResponseObjectStreamResponseOutputTextDelta(
+        content_index=1,
+        delta=message,
+        item_id="msg_shield_violation_2",
+        output_index=1,
+        sequence_number=1,
+    )
+
+    # Output text done
+    yield OpenAIResponseObjectStreamResponseOutputTextDone(
+        content_index=2,
+        text=message,
+        item_id="msg_shield_violation_3",
+        output_index=2,
+        sequence_number=2,
+    )
+
+    # Fill the output when message is completed
+    response_obj.output = [
+        OpenAIResponseMessage(
+            id="msg_shield_violation_4",
+            content=[OpenAIResponseOutputMessageContentOutputText(text=message)],
+            role="assistant",
+            status="completed",
+        )
+    ]
+    # Update status to completed
+    response_obj.status = "completed"
+
+    # Completed response triggers turn complete event
+    yield OpenAIResponseObjectStreamResponseCompleted(response=response_obj)
diff --git a/src/utils/shields.py b/src/utils/shields.py
index ca99671c7..5fa14d33c 100644
--- a/src/utils/shields.py
+++ b/src/utils/shields.py
@@ -113,11 +113,11 @@ async def run_shield_moderation(
             shield_model=shield.provider_resource_id,
         )
 
-    # Known Llama Stack bug: BadRequestError is raised when violation is present
+    # Known Llama Stack bug: error is raised when violation is present
     # in the shield LLM response but has wrong format that cannot be parsed.
-    except BadRequestError:
+    except (BadRequestError, ValueError):
         logger.warning(
-            "Shield '%s' returned BadRequestError, treating as blocked",
+            "Shield '%s' violation detected, treating as blocked",
             shield.identifier,
         )
         metrics.llm_calls_validation_errors_total.inc()
diff --git a/src/utils/types.py b/src/utils/types.py
index 06f8d9e61..37cc8f89c 100644
--- a/src/utils/types.py
+++ b/src/utils/types.py
@@ -128,7 +128,7 @@ class ToolResultSummary(BaseModel):
     status: str = Field(
         ..., description="Status of the tool execution (e.g., 'success')"
     )
-    content: Any = Field(..., description="Content/result returned from the tool")
+    content: str = Field(..., description="Content/result returned from the tool")
     type: str = Field("tool_result", description="Type indicator for tool result")
     round: int = Field(..., description="Round number or step of tool execution")
 
@@ -193,9 +193,9 @@ def append_tool_calls_from_llama(self, tec: ToolExecutionStep) -> None:
                 ToolResultSummary(
                     id=call_id,
                     status="success" if resp else "failure",
-                    content=response_content,
+                    content=response_content or "",
                     type="tool_result",
-                    round=1,  # clarify meaning of this attribute
+                    round=1,
                 )
             )
             # Extract RAG chunks from knowledge_search tool responses
diff --git a/tests/integration/endpoints/test_query_v2_integration.py b/tests/integration/endpoints/test_query_v2_integration.py
index ec0397127..ffa90b5b9 100644
--- a/tests/integration/endpoints/test_query_v2_integration.py
+++ b/tests/integration/endpoints/test_query_v2_integration.py
@@ -367,6 +367,18 @@ async def test_query_v2_endpoint_with_tool_calls(
         "doc_url": "https://example.com/ansible-docs.txt",
         "link": "https://example.com/ansible-docs.txt",
     }
+    mock_result.model_dump = mocker.Mock(
+        return_value={
+            "file_id": "doc-1",
+            "filename": "ansible-docs.txt",
+            "score": 0.95,
+            "text": "Ansible is an open-source automation tool...",
+            "attributes": {
+                "doc_url": "https://example.com/ansible-docs.txt",
+                "link": "https://example.com/ansible-docs.txt",
+            },
+        }
+    )
     mock_tool_output.results = [mock_result]
 
     mock_message_output = mocker.MagicMock()
@@ -422,9 +434,13 @@ async def test_query_v2_endpoint_with_mcp_list_tools(
     mock_tool1 = mocker.MagicMock()
     mock_tool1.name = "list_pods"
+    mock_tool1.description = "List Kubernetes pods"
+    mock_tool1.input_schema = {"type": "object", "properties": {}}
 
     mock_tool2 = mocker.MagicMock()
     mock_tool2.name = "get_deployment"
+    mock_tool2.description = "Get Kubernetes deployment"
+    mock_tool2.input_schema = {"type": "object", "properties": {}}
 
     mock_mcp_list = mocker.MagicMock()
     mock_mcp_list.type = "mcp_list_tools"
@@ -494,8 +510,9 @@ async def test_query_v2_endpoint_with_multiple_tool_types(
     mock_function = mocker.MagicMock()
     mock_function.type = "function_call"
     mock_function.id = "func-2"
+    mock_function.call_id = "func-2"
     mock_function.name = "calculate"
-    mock_function.arguments = {"operation": "sum"}
+    mock_function.arguments = '{"operation": "sum"}'
     mock_function.status = "completed"
 
     mock_message = mocker.MagicMock()
diff --git a/tests/unit/app/endpoints/test_query_v2.py b/tests/unit/app/endpoints/test_query_v2.py
index fd4ece751..47d925bb5 100644
--- a/tests/unit/app/endpoints/test_query_v2.py
+++ b/tests/unit/app/endpoints/test_query_v2.py
@@ -341,8 +341,9 @@ async def test_retrieve_response_parses_output_and_tool_calls(
     tool_call_item = mocker.Mock()
     tool_call_item.type = "function_call"
     tool_call_item.id = "tc-1"
+    tool_call_item.call_id = "tc-1"
     tool_call_item.name = "do_something"
-    tool_call_item.arguments = {"x": 1}
+    tool_call_item.arguments = '{"x": 1}'
     tool_call_item.status = None  # Explicitly set to avoid Mock auto-creation
 
     response_obj = mocker.Mock()
@@ -898,6 +899,7 @@ def _create_file_search_output(mocker: MockerFixture) -> Any:
     # 2. Output item with file search tool call results
     output_item = mocker.Mock()
     output_item.type = "file_search_call"
+    output_item.id = "file-search-1"
     output_item.queries = (
         []
     )  # Ensure queries is a list to avoid iteration error in tool summary
@@ -909,6 +911,15 @@ def _create_file_search_output(mocker: MockerFixture) -> Any:
     result_1.text = "Sample text from file2.pdf"
     result_1.score = 0.95
     result_1.file_id = "file-123"
+    result_1.model_dump = mocker.Mock(
+        return_value={
+            "filename": "file2.pdf",
+            "attributes": {"url": "http://example.com/doc2"},
+            "text": "Sample text from file2.pdf",
+            "score": 0.95,
+            "file_id": "file-123",
+        }
+    )
 
     result_2 = mocker.Mock()
     result_2.filename = "file3.docx"
@@ -916,6 +927,15 @@ def _create_file_search_output(mocker: MockerFixture) -> Any:
     result_2.text = "Sample text from file3.docx"
     result_2.score = 0.85
     result_2.file_id = "file-456"
+    result_2.model_dump = mocker.Mock(
+        return_value={
+            "filename": "file3.docx",
+            "attributes": {},
+            "text": "Sample text from file3.docx",
+            "score": 0.85,
+            "file_id": "file-456",
+        }
+    )
 
     output_item.results = [result_1, result_2]
     return output_item
diff --git a/tests/unit/app/endpoints/test_streaming_query_v2.py b/tests/unit/app/endpoints/test_streaming_query_v2.py
index 20a379b72..29947c434 100644
--- a/tests/unit/app/endpoints/test_streaming_query_v2.py
+++ b/tests/unit/app/endpoints/test_streaming_query_v2.py
@@ -140,8 +140,8 @@ async def test_streaming_query_endpoint_handler_v2_success_yields_events(
         lambda conv_id: f"START:{conv_id}\n",
     )
     mocker.patch(
-        "app.endpoints.streaming_query_v2.format_stream_data",
-        lambda obj: f"EV:{obj['event']}:{obj['data'].get('token','')}\n",
+        "app.endpoints.streaming_query_v2.stream_event",
+        lambda data, event_type, media_type: f"EV:{event_type}:{data.get('token','')}\n",
     )
     mocker.patch(
         "app.endpoints.streaming_query_v2.stream_end_event",
@@ -164,9 +164,7 @@ async def fake_stream() -> AsyncIterator[Mock]:
         - a "response.created" event with a conversation id,
         - content and text delta events ("response.content_part.added",
           "response.output_text.delta"),
-        - function call events ("response.output_item.added",
-          "response.function_call_arguments.delta",
-          "response.function_call_arguments.done"),
+        - function call events ("response.output_item.done" with completed tool call),
         - a final "response.output_text.done" event and a "response.completed" event.
 
         Returns:
@@ -182,13 +180,8 @@ async def fake_stream() -> AsyncIterator[Mock]:
         yield Mock(type="response.output_text.delta", delta="world")
         item_mock = Mock(type="function_call", id="item1", call_id="call1")
         item_mock.name = "search"  # 'name' is a special Mock param, set explicitly
-        yield Mock(type="response.output_item.added", item=item_mock)
-        yield Mock(type="response.function_call_arguments.delta", delta='{"q":"x"}')
-        yield Mock(
-            type="response.function_call_arguments.done",
-            item_id="item1",
-            arguments='{"q":"x"}',
-        )
+        item_mock.arguments = '{"q":"x"}'
+        yield Mock(type="response.output_item.done", item=item_mock)
         yield Mock(type="response.output_text.done", text="Hello world")
         # Include a response object with output attribute for shield violation detection
         mock_response = Mock(output=[])
@@ -392,8 +385,8 @@ async def test_streaming_response_blocked_by_shield_moderation(
         lambda conv_id: f"START:{conv_id}\n",
     )
     mocker.patch(
-        "app.endpoints.streaming_query_v2.format_stream_data",
-        lambda obj: f"EV:{obj['event']}:{obj['data'].get('token','')}\n",
+        "app.endpoints.streaming_query_v2.stream_event",
+        lambda data, event_type, media_type: f"EV:{event_type}:{data.get('token','')}\n",
     )
     mocker.patch(
         "app.endpoints.streaming_query_v2.stream_end_event",
@@ -484,8 +477,8 @@ async def test_streaming_response_no_shield_violation(
         lambda conv_id: f"START:{conv_id}\n",
     )
     mocker.patch(
-        "app.endpoints.streaming_query_v2.format_stream_data",
-        lambda obj: f"EV:{obj['event']}:{obj['data'].get('token','')}\n",
+        "app.endpoints.streaming_query_v2.stream_event",
+        lambda data, event_type, media_type: f"EV:{event_type}:{data.get('token','')}\n",
     )
     mocker.patch(
         "app.endpoints.streaming_query_v2.stream_end_event",
diff --git a/tests/unit/models/responses/test_error_responses.py b/tests/unit/models/responses/test_error_responses.py
index 2f61680d6..aaf5047ab 100644
--- a/tests/unit/models/responses/test_error_responses.py
+++ b/tests/unit/models/responses/test_error_responses.py
@@ -543,14 +543,13 @@ def test_factory_feedback_path_invalid(self) -> None:
 
     def test_factory_query_failed(self) -> None:
         """Test InternalServerErrorResponse.query_failed() factory method."""
-        response = InternalServerErrorResponse.query_failed("https://api.example.com")
+        custom_cause = "Failed to call backend: https://api.example.com"
+        response = InternalServerErrorResponse.query_failed(custom_cause)
 
         assert isinstance(response, AbstractErrorResponse)
         assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
         assert isinstance(response.detail, DetailModel)
         assert response.detail.response == "Error while processing query"
-        assert (
-            response.detail.cause == "Failed to call backend: https://api.example.com"
-        )
+        assert response.detail.cause == custom_cause
 
     def test_factory_cache_unavailable(self) -> None:
         """Test InternalServerErrorResponse.cache_unavailable() factory method."""
diff --git a/tests/unit/models/responses/test_query_response.py b/tests/unit/models/responses/test_query_response.py
index 050f91ef8..a6c846f06 100644
--- a/tests/unit/models/responses/test_query_response.py
+++ b/tests/unit/models/responses/test_query_response.py
@@ -38,7 +38,7 @@ def test_complete_query_response_with_all_fields(self) -> None:
             ToolResultSummary(
                 id="call-1",
                 status="success",
-                content={"chunks_found": 5},
+                content='{"chunks_found": 5}',
                 type="tool_result",
                 round=1,
             )
@@ -73,7 +73,7 @@ def test_complete_query_response_with_all_fields(self) -> None:
         assert qr.tool_results is not None
         assert len(qr.tool_results) == 1
         assert qr.tool_results[0].status == "success"
-        assert qr.tool_results[0].content == {"chunks_found": 5}
+        assert qr.tool_results[0].content == '{"chunks_found": 5}'
         assert qr.tool_results[0].type == "tool_result"
         assert qr.tool_results[0].round == 1
         assert len(qr.referenced_documents) == 1
diff --git a/tests/unit/models/responses/test_successful_responses.py b/tests/unit/models/responses/test_successful_responses.py
index 2e7056245..880b8c494 100644
--- a/tests/unit/models/responses/test_successful_responses.py
+++ b/tests/unit/models/responses/test_successful_responses.py
@@ -287,7 +287,7 @@ def test_constructor_full(self) -> None:
             ToolResultSummary(
                 id="call-1",
                 status="success",
-                content={"chunks_found": 5},
+                content='{"chunks_found": 5}',
                 type="tool_result",
                 round=1,
             )
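
# --- Illustrative sketch (not part of the patch) ------------------------------
# Drains the shield-violation stream added in src/utils/query.py and prints the
# event types it yields, mirroring how response_generator consumes Responses API
# chunks.  The import path assumes src/ is on PYTHONPATH (the repository layout
# shown in this diff); the expected event order is an assumption based on the
# create_violation_stream implementation above.
import asyncio

from utils.query import create_violation_stream


async def main() -> None:
    async for event in create_violation_stream("Input blocked by shield"):
        # Expected order: response.created, response.content_part.added,
        # response.output_text.delta, response.output_text.done, response.completed
        print(getattr(event, "type", type(event).__name__))


if __name__ == "__main__":
    asyncio.run(main())
# -------------------------------------------------------------------------------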