diff --git a/agentrun/model/__model_proxy_async_template.py b/agentrun/model/__model_proxy_async_template.py
index dfa15a7..847dc52 100644
--- a/agentrun/model/__model_proxy_async_template.py
+++ b/agentrun/model/__model_proxy_async_template.py
@@ -9,6 +9,7 @@
 import pydash
 
 from agentrun.model.api.data import BaseInfo, ModelDataAPI
+from agentrun.model.api.model_api import ModelAPI
 from agentrun.utils.config import Config
 from agentrun.utils.model import Status
 from agentrun.utils.resource import ResourceBase
@@ -30,6 +31,7 @@ class ModelProxy(
     ModelProxyImmutableProps,
     ModelProxyMutableProps,
     ModelProxySystemProps,
+    ModelAPI,
     ResourceBase,
 ):
     """Model service."""
@@ -230,41 +232,3 @@ def model_info(self, config: Optional[Config] = None) -> BaseInfo:
         )
 
         return self._data_client.model_info()
-
-    def completions(
-        self,
-        messages: list,
-        model: Optional[str] = None,
-        stream: bool = False,
-        config: Optional[Config] = None,
-        **kwargs,
-    ):
-        self.model_info(config)
-        assert self._data_client
-
-        return self._data_client.completions(
-            **kwargs,
-            messages=messages,
-            model=model,
-            stream=stream,
-            config=config,
-        )
-
-    def responses(
-        self,
-        messages: list,
-        model: Optional[str] = None,
-        stream: bool = False,
-        config: Optional[Config] = None,
-        **kwargs,
-    ):
-        self.model_info(config)
-        assert self._data_client
-
-        return self._data_client.responses(
-            **kwargs,
-            messages=messages,
-            model=model,
-            stream=stream,
-            config=config,
-        )
diff --git a/agentrun/model/__model_service_async_template.py b/agentrun/model/__model_service_async_template.py
index e94331d..a3cfdcb 100644
--- a/agentrun/model/__model_service_async_template.py
+++ b/agentrun/model/__model_service_async_template.py
@@ -6,7 +6,8 @@
 from typing import List, Optional
 
-from agentrun.model.api.data import BaseInfo, ModelCompletionAPI
+from agentrun.model.api.data import BaseInfo
+from agentrun.model.api.model_api import ModelAPI
 from agentrun.utils.config import Config
 from agentrun.utils.model import PageableInput
 from agentrun.utils.resource import ResourceBase
@@ -27,6 +28,7 @@ class ModelService(
     ModelServiceImmutableProps,
     ModelServiceMutableProps,
     ModelServicesSystemProps,
+    ModelAPI,
     ResourceBase,
 ):
     """Model service."""
@@ -230,38 +232,3 @@ def model_info(self, config: Optional[Config] = None) -> BaseInfo:
             model=default_model,
             headers=cfg.get_headers(),
         )
-
-    def completions(
-        self,
-        messages: list,
-        model: Optional[str] = None,
-        stream: bool = False,
-        **kwargs,
-    ):
-        info = self.model_info(config=kwargs.get("config"))
-
-        m = ModelCompletionAPI(
-            api_key=info.api_key or "",
-            base_url=info.base_url or "",
-            model=model or info.model or self.model_service_name or "",
-        )
-
-        return m.completions(**kwargs, messages=messages, stream=stream)
-
-    def responses(
-        self,
-        messages: list,
-        model: Optional[str] = None,
-        stream: bool = False,
-        **kwargs,
-    ):
-        info = self.model_info(config=kwargs.get("config"))
-
-        m = ModelCompletionAPI(
-            api_key=info.api_key or "",
-            base_url=info.base_url or "",
-            model=model or info.model or self.model_service_name or "",
-            provider=(self.provider or "openai").lower(),
-        )
-
-        return m.responses(**kwargs, messages=messages, stream=stream)
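With the per-class `completions`/`responses` implementations removed, both templates now inherit the whole inference surface from the shared `ModelAPI` mixin; a subclass only has to implement `model_info()`. A minimal sketch of that contract (the `StaticModel` class, key, and model name are hypothetical; `BaseInfo` is assumed to accept the `api_key`/`base_url`/`model` fields used throughout this diff):

```python
from agentrun.model.api.data import BaseInfo
from agentrun.model.api.model_api import ModelAPI


class StaticModel(ModelAPI):
    """Hypothetical ModelAPI subclass: model_info() is the only abstract
    method, and supplying it unlocks completion(), acompletion(),
    responses(), embedding(), and their async variants."""

    def model_info(self) -> BaseInfo:
        return BaseInfo(
            api_key="sk-xxxxx",
            base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
            model="qwen-plus",
        )


m = StaticModel()
# Routed through litellm.completion() with credentials from model_info().
resp = m.completion(messages=[{"role": "user", "content": "Hello"}])
```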
diff --git a/agentrun/model/api/model_api.py b/agentrun/model/api/model_api.py
new file mode 100644
index 0000000..001097b
--- /dev/null
+++ b/agentrun/model/api/model_api.py
@@ -0,0 +1,156 @@
+from abc import ABC, abstractmethod
+from typing import Optional, TYPE_CHECKING, Union
+
+from .data import BaseInfo
+
+if TYPE_CHECKING:
+    from litellm import ResponseInputParam
+
+
+class ModelAPI(ABC):
+
+    @abstractmethod
+    def model_info(self) -> BaseInfo:
+        ...
+
+    def completions(
+        self,
+        **kwargs,
+    ):
+        """
+        Deprecated. Use completion() instead.
+        """
+        import warnings
+
+        warnings.warn(
+            "completions() is deprecated, use completion() instead",
+            DeprecationWarning,
+            stacklevel=2,
+        )
+        return self.completion(**kwargs)
+
+    def completion(
+        self,
+        messages=[],
+        model: Optional[str] = None,
+        custom_llm_provider: Optional[str] = None,
+        **kwargs,
+    ):
+        from litellm import completion
+
+        info = self.model_info()
+        return completion(
+            **kwargs,
+            api_key=info.api_key,
+            base_url=info.base_url,
+            model=model or info.model or "",
+            custom_llm_provider=custom_llm_provider
+            or info.provider
+            or "openai",
+            messages=messages,
+        )
+
+    async def acompletion(
+        self,
+        messages=[],
+        model: Optional[str] = None,
+        custom_llm_provider: Optional[str] = None,
+        **kwargs,
+    ):
+        from litellm import acompletion
+
+        info = self.model_info()
+        return await acompletion(
+            **kwargs,
+            api_key=info.api_key,
+            base_url=info.base_url,
+            model=model or info.model or "",
+            custom_llm_provider=custom_llm_provider
+            or info.provider
+            or "openai",
+            messages=messages,
+        )
+
+    def responses(
+        self,
+        input: Union[str, "ResponseInputParam"],
+        model: Optional[str] = None,
+        custom_llm_provider: Optional[str] = None,
+        **kwargs,
+    ):
+        from litellm import responses
+
+        info = self.model_info()
+        return responses(
+            **kwargs,
+            api_key=info.api_key,
+            base_url=info.base_url,
+            model=model or info.model or "",
+            custom_llm_provider=custom_llm_provider
+            or info.provider
+            or "openai",
+            input=input,
+        )
+
+    async def aresponses(
+        self,
+        input: Union[str, "ResponseInputParam"],
+        model: Optional[str] = None,
+        custom_llm_provider: Optional[str] = None,
+        **kwargs,
+    ):
+        from litellm import aresponses
+
+        info = self.model_info()
+        return await aresponses(
+            **kwargs,
+            api_key=info.api_key,
+            base_url=info.base_url,
+            model=model or info.model or "",
+            custom_llm_provider=custom_llm_provider
+            or info.provider
+            or "openai",
+            input=input,
+        )
+
+    def embedding(
+        self,
+        input=[],
+        model: Optional[str] = None,
+        custom_llm_provider: Optional[str] = None,
+        **kwargs,
+    ):
+        from litellm import embedding
+
+        info = self.model_info()
+        return embedding(
+            **kwargs,
+            api_key=info.api_key,
+            api_base=info.base_url,
+            model=model or info.model or "",
+            custom_llm_provider=custom_llm_provider
+            or info.provider
+            or "openai",
+            input=input,
+        )
+
+    async def aembedding(
+        self,
+        input=[],
+        model: Optional[str] = None,
+        custom_llm_provider: Optional[str] = None,
+        **kwargs,
+    ):
+        from litellm import aembedding
+
+        info = self.model_info()
+        return await aembedding(
+            **kwargs,
+            api_key=info.api_key,
+            api_base=info.base_url,
+            model=model or info.model or "",
+            custom_llm_provider=custom_llm_provider
+            or info.provider
+            or "openai",
+            input=input,
+        )
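`completions()` survives only as a deprecation shim over `completion()`, so legacy callers keep working while emitting a `DeprecationWarning`. A quick self-contained check of that behavior (the `_Stub` class is hypothetical; it overrides `completion()` so no litellm call is made):

```python
import warnings

from agentrun.model.api.model_api import ModelAPI


class _Stub(ModelAPI):
    # Hypothetical stub; model_info() is never reached because
    # completion() is overridden below.
    def model_info(self):
        ...

    def completion(self, **kwargs):
        return {"echo": kwargs}


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # The shim warns, then forwards **kwargs unchanged to completion().
    result = _Stub().completions(messages=[{"role": "user", "content": "Hi"}])

assert result == {"echo": {"messages": [{"role": "user", "content": "Hi"}]}}
assert any(issubclass(w.category, DeprecationWarning) for w in caught)
```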
diff --git a/agentrun/model/model_proxy.py b/agentrun/model/model_proxy.py
index 248d210..889ee2f 100644
--- a/agentrun/model/model_proxy.py
+++ b/agentrun/model/model_proxy.py
@@ -19,6 +19,7 @@
 import pydash
 
 from agentrun.model.api.data import BaseInfo, ModelDataAPI
+from agentrun.model.api.model_api import ModelAPI
 from agentrun.utils.config import Config
 from agentrun.utils.model import Status
 from agentrun.utils.resource import ResourceBase
@@ -40,6 +41,7 @@ class ModelProxy(
     ModelProxyImmutableProps,
     ModelProxyMutableProps,
     ModelProxySystemProps,
+    ModelAPI,
     ResourceBase,
 ):
     """Model service."""
@@ -399,41 +401,3 @@ def model_info(self, config: Optional[Config] = None) -> BaseInfo:
         )
 
         return self._data_client.model_info()
-
-    def completions(
-        self,
-        messages: list,
-        model: Optional[str] = None,
-        stream: bool = False,
-        config: Optional[Config] = None,
-        **kwargs,
-    ):
-        self.model_info(config)
-        assert self._data_client
-
-        return self._data_client.completions(
-            **kwargs,
-            messages=messages,
-            model=model,
-            stream=stream,
-            config=config,
-        )
-
-    def responses(
-        self,
-        messages: list,
-        model: Optional[str] = None,
-        stream: bool = False,
-        config: Optional[Config] = None,
-        **kwargs,
-    ):
-        self.model_info(config)
-        assert self._data_client
-
-        return self._data_client.responses(
-            **kwargs,
-            messages=messages,
-            model=model,
-            stream=stream,
-            config=config,
-        )
diff --git a/agentrun/model/model_service.py b/agentrun/model/model_service.py
index 24f9cce..270b355 100644
--- a/agentrun/model/model_service.py
+++ b/agentrun/model/model_service.py
@@ -16,7 +16,8 @@
 from typing import List, Optional
 
-from agentrun.model.api.data import BaseInfo, ModelCompletionAPI
+from agentrun.model.api.data import BaseInfo
+from agentrun.model.api.model_api import ModelAPI
 from agentrun.utils.config import Config
 from agentrun.utils.model import PageableInput
 from agentrun.utils.resource import ResourceBase
@@ -37,6 +38,7 @@ class ModelService(
     ModelServiceImmutableProps,
     ModelServiceMutableProps,
     ModelServicesSystemProps,
+    ModelAPI,
     ResourceBase,
 ):
     """Model service."""
@@ -401,38 +403,3 @@ def model_info(self, config: Optional[Config] = None) -> BaseInfo:
             model=default_model,
             headers=cfg.get_headers(),
         )
-
-    def completions(
-        self,
-        messages: list,
-        model: Optional[str] = None,
-        stream: bool = False,
-        **kwargs,
-    ):
-        info = self.model_info(config=kwargs.get("config"))
-
-        m = ModelCompletionAPI(
-            api_key=info.api_key or "",
-            base_url=info.base_url or "",
-            model=model or info.model or self.model_service_name or "",
-        )
-
-        return m.completions(**kwargs, messages=messages, stream=stream)
-
-    def responses(
-        self,
-        messages: list,
-        model: Optional[str] = None,
-        stream: bool = False,
-        **kwargs,
-    ):
-        info = self.model_info(config=kwargs.get("config"))
-
-        m = ModelCompletionAPI(
-            api_key=info.api_key or "",
-            base_url=info.base_url or "",
-            model=model or info.model or self.model_service_name or "",
-            provider=(self.provider or "openai").lower(),
-        )
-
-        return m.responses(**kwargs, messages=messages, stream=stream)
diff --git a/examples/embedding.py b/examples/embedding.py
new file mode 100644
index 0000000..f2b5e5f
--- /dev/null
+++ b/examples/embedding.py
@@ -0,0 +1,141 @@
+import os
+import re
+import time
+
+from agentrun import model
+from agentrun.model import (
+    BackendType,
+    ModelClient,
+    ModelService,
+    ModelServiceCreateInput,
+    ModelServiceListInput,
+    ModelServiceUpdateInput,
+)
+from agentrun.utils.exception import (
+    ResourceAlreadyExistError,
+    ResourceNotExistError,
+)
+from agentrun.utils.log import logger
+from agentrun.utils.model import Status
+
+base_url = os.getenv(
+    "BASE_URL", "https://dashscope.aliyuncs.com/compatible-mode/v1"
+)
+api_key = os.getenv("API_KEY", "sk-xxxxx")
+model_names = re.split(
+    r"\s|,", os.getenv("MODEL_NAMES", "text-embedding-v1").strip()
+)
+
+
+client = ModelClient()
+model_service_name = "sdk-test-embedding"
+
+
+def create_or_get_model_service():
+    """
+    Demonstrates how to create a resource, or fetch it if it already exists.
+    """
+    logger.info("Creating or fetching an existing resource")
+
+    try:
+        ms = client.create(
+            ModelServiceCreateInput(
+                model_service_name=model_service_name,
+                description="Test model service",
+                model_type=model.ModelType.EMBEDDING,
+                provider="openai",
+                provider_settings=model.ProviderSettings(
+                    api_key=api_key,
+                    base_url=base_url,
+                    model_names=model_names,
+                ),
+            )
+        )
+    except ResourceAlreadyExistError:
+        logger.info("Already exists; fetching the existing resource")
+        ms = client.get(
+            name=model_service_name, backend_type=BackendType.SERVICE
+        )
+
+    ms.wait_until_ready_or_failed()
+    if ms.status != Status.READY:
+        raise Exception(f"Unexpected status: {ms.status}")
+
+    logger.info("Resource is ready; current info: %s", ms)
+
+    return ms
+
+
+def update_model_service(ms: ModelService):
+    """
+    Demonstrates how to update a resource.
+    """
+    logger.info("Updating the description to the current time")
+
+    # client.update works as well
+    ms.update(
+        ModelServiceUpdateInput(description=f"Current timestamp: {time.time()}"),
+    )
+    ms.wait_until_ready_or_failed()
+    if ms.status != Status.READY:
+        raise Exception(f"Unexpected status: {ms.status}")
+
+    logger.info("Update succeeded; current info: %s", ms)
+
+
+def list_model_services():
+    """
+    Demonstrates how to list resources.
+    """
+    logger.info("Listing resources")
+    ms_arr = client.list(ModelServiceListInput(model_type=model.ModelType.EMBEDDING))
+    logger.info(
+        "Found %d resources: %s",
+        len(ms_arr),
+        [c.model_service_name for c in ms_arr],
+    )
+
+
+def delete_model_service(ms: ModelService):
+    """
+    Demonstrates how to delete a resource.
+    """
+    logger.info("Cleaning up resources")
+    # client.delete / cred.delete plus status polling works as well
+    ms.delete_and_wait_until_finished()
+
+    logger.info("Trying to fetch the resource again")
+    try:
+        ms.refresh()
+    except ResourceNotExistError as e:
+        logger.info("Got a resource-not-found error; deletion succeeded: %s", e)
+
+
+def invoke_model_service(ms: ModelService):
+    logger.info("Invoking the model service for inference")
+
+    result = ms.embedding(input=["Hello", "What day is it today?"])
+    logger.info("Embedding result: %s", result)
+
+
+def model_example():
+    """
+    Demonstrates the basic features of the model module.
+    """
+    logger.info("==== Model module basics example ====")
+    logger.info("  base_url=%s", base_url)
+    logger.info("  api_key=%s", len(api_key) * "*")
+    logger.info("  model_names=%s", model_names)
+
+    list_model_services()
+    ms = create_or_get_model_service()
+    update_model_service(ms)
+
+    invoke_model_service(ms)
+
+    delete_model_service(ms)
+    list_model_services()
+
+
+if __name__ == "__main__":
+    model_example()
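The example above exercises the synchronous `embedding()` path; the same call is available in async code via `aembedding()`. A minimal sketch reusing the example's `client`, `model_service_name`, and `logger` globals, assuming the service is already READY:

```python
import asyncio


async def main():
    ms = client.get(
        name=model_service_name, backend_type=BackendType.SERVICE
    )
    # Awaits litellm.aembedding() under the hood.
    result = await ms.aembedding(input=["Hello", "What day is it today?"])
    logger.info("Async embedding result: %s", result)


asyncio.run(main())
```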
ModelProxy(model_proxy_name="test-proxy") - # Create a mock _data_client directly + # Create a mock _data_client to provide model_info mock_data_client = MagicMock() mock_info = BaseInfo(model="gpt-4", base_url="https://api.example.com") mock_data_client.model_info.return_value = mock_info - mock_data_client.responses.return_value = {} - # Bypass the model_info call by setting _data_client + # Set _data_client so model_info() returns our mock info proxy._data_client = mock_data_client - proxy.responses(messages=[{"role": "user", "content": "Hello"}]) + # Note: The responses method expects 'input' parameter (not 'messages') + # based on the ModelAPI.responses signature + proxy.responses(input="Hello") - mock_data_client.responses.assert_called_once() + mock_responses.assert_called_once()