diff --git a/packages/gooddata-sdk/src/gooddata_sdk/__init__.py b/packages/gooddata-sdk/src/gooddata_sdk/__init__.py index fe6e2c5af..b7aad8a4c 100644 --- a/packages/gooddata-sdk/src/gooddata_sdk/__init__.py +++ b/packages/gooddata-sdk/src/gooddata_sdk/__init__.py @@ -276,6 +276,7 @@ PopDatesetMetric, SimpleMetric, ) +from gooddata_api_client.model.allowed_relationship_type import AllowedRelationshipType from gooddata_sdk.compute.service import ComputeService from gooddata_sdk.sdk import GoodDataSdk from gooddata_sdk.table import ExecutionTable, TableService diff --git a/packages/gooddata-sdk/src/gooddata_sdk/compute/service.py b/packages/gooddata-sdk/src/gooddata_sdk/compute/service.py index cd9a5d522..b85c7115c 100644 --- a/packages/gooddata-sdk/src/gooddata_sdk/compute/service.py +++ b/packages/gooddata-sdk/src/gooddata_sdk/compute/service.py @@ -8,6 +8,7 @@ from gooddata_api_client import ApiException from gooddata_api_client.model.afm_cancel_tokens import AfmCancelTokens +from gooddata_api_client.model.allowed_relationship_type import AllowedRelationshipType from gooddata_api_client.model.chat_history_request import ChatHistoryRequest from gooddata_api_client.model.chat_history_result import ChatHistoryResult from gooddata_api_client.model.chat_request import ChatRequest @@ -135,17 +136,27 @@ def build_exec_def_from_chat_result( is_cancellable=is_cancellable, ) - def ai_chat(self, workspace_id: str, question: str) -> ChatResult: + def ai_chat( + self, + workspace_id: str, + question: str, + allowed_relationship_types: Optional[list[AllowedRelationshipType]] = None, + ) -> ChatResult: """ Chat with AI in GoodData workspace. Args: workspace_id (str): workspace identifier question (str): question for the AI + allowed_relationship_types (Optional[list[AllowedRelationshipType]]): list of allowed relationship types + to filter search results. If provided, only relationships of the specified types will be considered. Returns: ChatResult: Chat response """ - chat_request = ChatRequest(question=question) + chat_request_params: dict[str, Any] = {"question": question} + if allowed_relationship_types is not None: + chat_request_params["allowed_relationship_types"] = allowed_relationship_types + chat_request = ChatRequest(**chat_request_params) response = self._actions_api.ai_chat(workspace_id, chat_request, _check_return_type=False) return response @@ -160,17 +171,27 @@ def _parse_sse_events(self, raw: str) -> Iterator[Any]: except json.JSONDecodeError: continue - def ai_chat_stream(self, workspace_id: str, question: str) -> Iterator[Any]: + def ai_chat_stream( + self, + workspace_id: str, + question: str, + allowed_relationship_types: Optional[list[AllowedRelationshipType]] = None, + ) -> Iterator[Any]: """ Chat Stream with AI in GoodData workspace. Args: workspace_id (str): workspace identifier question (str): question for the AI + allowed_relationship_types (Optional[list[AllowedRelationshipType]]): list of allowed relationship types + to filter search results. If provided, only relationships of the specified types will be considered. Returns: Iterator[Any]: Yields parsed JSON objects from each SSE event's data field """ - chat_request = ChatRequest(question=question) + chat_request_params: dict[str, Any] = {"question": question} + if allowed_relationship_types is not None: + chat_request_params["allowed_relationship_types"] = allowed_relationship_types + chat_request = ChatRequest(**chat_request_params) response = self._actions_api.ai_chat_stream( workspace_id, chat_request, _check_return_type=False, _preload_content=False ) @@ -280,6 +301,7 @@ def search_ai( object_types: Optional[list[str]] = None, relevant_score_threshold: Optional[float] = None, title_to_descriptor_ratio: Optional[float] = None, + allowed_relationship_types: Optional[list[AllowedRelationshipType]] = None, ) -> SearchResult: """ Search for metadata objects using similarity search. @@ -293,6 +315,8 @@ def search_ai( "label", "date", "dataset", "visualization" and "dashboard". Defaults to None. relevant_score_threshold (Optional[float]): minimum relevance score threshold for results. Defaults to None. title_to_descriptor_ratio (Optional[float]): ratio of title score to descriptor score. Defaults to None. + allowed_relationship_types (Optional[list[AllowedRelationshipType]]): list of allowed relationship types + to filter search results. If provided, only relationships of the specified types will be considered. Returns: SearchResult: Search results @@ -311,6 +335,8 @@ def search_ai( search_params["relevant_score_threshold"] = relevant_score_threshold if title_to_descriptor_ratio is not None: search_params["title_to_descriptor_ratio"] = title_to_descriptor_ratio + if allowed_relationship_types is not None: + search_params["allowed_relationship_types"] = allowed_relationship_types search_request = SearchRequest(question=question, **search_params) response = self._actions_api.ai_search(workspace_id, search_request, _check_return_type=False) return response diff --git a/packages/gooddata-sdk/tests/compute/test_compute_service.py b/packages/gooddata-sdk/tests/compute/test_compute_service.py index f91dfa29a..a6c9b498b 100644 --- a/packages/gooddata-sdk/tests/compute/test_compute_service.py +++ b/packages/gooddata-sdk/tests/compute/test_compute_service.py @@ -2,7 +2,7 @@ from pathlib import Path import pytest -from gooddata_sdk import CatalogWorkspace +from gooddata_sdk import AllowedRelationshipType, CatalogWorkspace from gooddata_sdk.sdk import GoodDataSdk from tests_support.vcrpy_utils import get_vcr @@ -219,6 +219,85 @@ def test_ai_chat_stream(test_config): sdk.compute.reset_ai_chat_history(test_workspace_id) +@gd_vcr.use_cassette(str(_fixtures_dir / "ai_chat_with_allowed_relationship_types.yaml")) +def test_ai_chat_with_allowed_relationship_types(test_config): + """Test AI chat with allowed_relationship_types parameter.""" + sdk = GoodDataSdk.create(host_=test_config["host"], token_=test_config["token"]) + path = _current_dir / "load" / "ai" + test_workspace_id = test_config["workspace_test"] + + allowed_types = [ + AllowedRelationshipType(source_type="dashboard", target_type="visualization"), + AllowedRelationshipType(source_type="visualization", target_type="metric"), + ] + + try: + _setup_test_workspace(sdk, test_workspace_id, path) + response = sdk.compute.ai_chat( + test_workspace_id, + "Create a visualization for total revenue", + allowed_relationship_types=allowed_types, + ) + assert hasattr(response, "routing") + assert hasattr(response, "created_visualizations") + assert hasattr(response, "chat_history_interaction_id") + assert response.chat_history_interaction_id is not None + finally: + sdk.catalog_workspace.delete_workspace(test_workspace_id) + sdk.compute.reset_ai_chat_history(test_workspace_id) + + +@gd_vcr.use_cassette(str(_fixtures_dir / "ai_chat_stream_with_allowed_relationship_types.yaml")) +def test_ai_chat_stream_with_allowed_relationship_types(test_config): + """Test AI chat stream with allowed_relationship_types parameter.""" + sdk = GoodDataSdk.create(host_=test_config["host"], token_=test_config["token"]) + path = _current_dir / "load" / "ai" + test_workspace_id = test_config["workspace_test"] + + allowed_types = [ + AllowedRelationshipType(source_type="dashboard", target_type="visualization"), + ] + + try: + _setup_test_workspace(sdk, test_workspace_id, path) + buffer = {} + for chunk in sdk.compute.ai_chat_stream( + test_workspace_id, + "What is the total revenue for the year 2024?", + allowed_relationship_types=allowed_types, + ): + buffer = {**buffer, **chunk} + assert buffer is not None + finally: + sdk.catalog_workspace.delete_workspace(test_workspace_id) + sdk.compute.reset_ai_chat_history(test_workspace_id) + + +@gd_vcr.use_cassette(str(_fixtures_dir / "ai_search_with_allowed_relationship_types.yaml")) +def test_search_ai_with_allowed_relationship_types(test_config): + """Test AI search with allowed_relationship_types parameter.""" + sdk = GoodDataSdk.create(host_=test_config["host"], token_=test_config["token"]) + path = _current_dir / "load" / "ai" + test_workspace_id = test_config["workspace_test"] + + allowed_types = [ + AllowedRelationshipType(source_type="dashboard", target_type="visualization"), + AllowedRelationshipType(source_type="visualization", target_type="metric", allow_orphans=False), + ] + + try: + _setup_test_workspace(sdk, test_workspace_id, path) + result = sdk.compute.search_ai( + test_workspace_id, + "What is the total revenue?", + allowed_relationship_types=allowed_types, + ) + assert result is not None + assert hasattr(result, "results") + finally: + sdk.catalog_workspace.delete_workspace(test_workspace_id) + + @gd_vcr.use_cassette(str(_fixtures_dir / "build_exec_def_from_chat_result.yaml")) def test_build_exec_def_from_chat_result(test_config): """Test build execution definition from chat result."""