diff --git a/appbuilder/core/components/asr/component.py b/appbuilder/core/components/asr/component.py index 94ea8d70..e5ce8507 100644 --- a/appbuilder/core/components/asr/component.py +++ b/appbuilder/core/components/asr/component.py @@ -14,6 +14,7 @@ r"""ASR component. """ +import os import uuid import json @@ -58,7 +59,7 @@ class ASR(Component): "parameters": { "type": "object", "properties": { - "url": { + "file_url": { "type": "string", "description": "输入语音文件的url,根据url获取到语音文件" }, @@ -67,8 +68,17 @@ class ASR(Component): "description": "待识别语音文件名,用于生成获取语音的url" } }, - "required": [ - "url" + "anyOf": [ + { + "required": [ + "file_url" + ] + }, + { + "required": [ + "file_name" + ] + } ] } } @@ -155,20 +165,20 @@ def tool_eval(self, name: str, streaming: bool, **kwargs): """ asr for function call """ - url_key = kwargs.get("url", None) - file_urls = kwargs.get("file_urls", {}) - if not url_key: - url_key = kwargs.get("file_name", None) - if utils.is_url(url_key): - url = url_key - else: - url = file_urls.get(url_key, None) - if not url: - raise InvalidRequestArgumentError(f"file {url_key} url does not exist") + file_url = kwargs.get("file_url", None) + if not file_url: + file_urls = kwargs.get("file_urls", {}) + file_path = kwargs.get("file_name", None) + if not file_path: + raise InvalidRequestArgumentError("file name is not set") + file_name = os.path.basename(file_path) + file_url = file_urls.get(file_name, None) + if not file_url: + raise InvalidRequestArgumentError(f"file {file_name} url does not exist") req = ShortSpeechRecognitionRequest() req.cuid = str(uuid.uuid4()) req.dev_pid = "80001" - req.speech = requests.get(url).content + req.speech = requests.get(file_url).content req.format = "pcm" req.rate = 16000 result = proto.Message.to_dict(self._recognize(req)) diff --git a/appbuilder/core/components/general_ocr/component.py b/appbuilder/core/components/general_ocr/component.py index 606ce148..82a7be5d 100644 --- 
a/appbuilder/core/components/general_ocr/component.py +++ b/appbuilder/core/components/general_ocr/component.py @@ -13,6 +13,7 @@ r"""general ocr component.""" import base64 import json +import os.path from appbuilder.core import utils from appbuilder.core._client import HTTPClient @@ -55,17 +56,26 @@ class GeneralOCR(Component): "parameters": { "type": "object", "properties": { - "url": { + "img_url": { "type": "string", "description": "待识别图片的url,根据该url能够获取图片" }, - "file_name": { + "img_name": { "type": "string", "description": "待识别图片的文件名,用于生成图片url" - } + }, }, - "required": [ - "url" + "anyOf": [ + { + "required": [ + "img_url" + ] + }, + { + "required": [ + "img_name" + ] + } ] } } @@ -141,17 +151,17 @@ def tool_eval(self, name: str, streaming: bool, **kwargs): """ general_ocr for function call """ - url_key = kwargs.get("url", None) - file_urls = kwargs.get("file_urls", {}) - if not url_key: - url_key = kwargs.get("file_name", None) - if utils.is_url(url_key): - url = url_key - else: - url = file_urls.get(url_key, None) - if not url: - raise InvalidRequestArgumentError(f"file {url_key} url does not exist") - req = GeneralOCRRequest(url=url) + img_url = kwargs.get("img_url", None) + if not img_url: + file_urls = kwargs.get("file_urls", {}) + img_path = kwargs.get("img_name", None) + if not img_path: + raise InvalidRequestArgumentError("file name is not set") + img_name = os.path.basename(img_path) + img_url = file_urls.get(img_name, None) + if not img_url: + raise InvalidRequestArgumentError(f"file {img_name} url does not exist") + req = GeneralOCRRequest(url=img_url) result = proto.Message.to_dict(self._recognize(req)) results = { "识别结果": " \n".join(item["words"] for item in result["words_result"]) diff --git a/appbuilder/core/components/object_recognize/component.py b/appbuilder/core/components/object_recognize/component.py index 16c6ae5e..405fafb7 100644 --- a/appbuilder/core/components/object_recognize/component.py +++ 
b/appbuilder/core/components/object_recognize/component.py @@ -14,6 +14,7 @@ import base64 import json +import os from appbuilder.core import utils from appbuilder.core._client import HTTPClient @@ -52,17 +53,26 @@ class ObjectRecognition(Component): "parameters": { "type": "object", "properties": { - "url": { + "img_url": { "type": "string", "description": "待识别图片的url,根据该url能够获取图片" }, - "file_name": { + "img_name": { "type": "string", "description": "待识别图片的文件名,用于生成图片url" } }, - "required": [ - "url" + "anyOf": [ + { + "required": [ + "img_url" + ] + }, + { + "required": [ + "img_name" + ] + } ] } } @@ -139,18 +149,18 @@ def tool_eval(self, name: str, streaming: bool, **kwargs): """ object_recognize for function call """ - url_key = kwargs.get("url", None) - file_urls = kwargs.get("file_urls", {}) - if not url_key: - url_key = kwargs.get("file_name", None) - if utils.is_url(url_key): - url = url_key - else: - url = file_urls.get(url_key, None) - if not url: - raise InvalidRequestArgumentError(f"file {url_key} url does not exist") + img_url = kwargs.get("img_url", None) + if not img_url: + file_urls = kwargs.get("file_urls", {}) + img_path = kwargs.get("img_name", None) + if not img_path: + raise InvalidRequestArgumentError("file name is not set") + img_name = os.path.basename(img_path) + img_url = file_urls.get(img_name, None) + if not img_url: + raise InvalidRequestArgumentError(f"file {img_name} url does not exist") score_threshold = kwargs.get("score_threshold", 0.5) - req = ObjectRecognitionRequest(url=url) + req = ObjectRecognitionRequest(url=img_url) result = proto.Message.to_dict(self._recognize(req)) results = [] for item in result["result"]: diff --git a/appbuilder/tests/test_asr.py b/appbuilder/tests/test_asr.py index 744d6bf4..a6ab77eb 100644 --- a/appbuilder/tests/test_asr.py +++ b/appbuilder/tests/test_asr.py @@ -132,7 +132,7 @@ def test_check_service_error(self): def test_tool_eval_valid(self): """测试 tool 方法对有效请求的处理。""" - result = 
self.asr.tool_eval(name="asr", streaming=True, url=self.audio_file_url) + result = self.asr.tool_eval(name="asr", streaming=True, file_url=self.audio_file_url) res = [item for item in result] self.assertNotEqual(len(res), 0) diff --git a/appbuilder/tests/test_general_ocr.py b/appbuilder/tests/test_general_ocr.py index 147f2271..b0e822ec 100644 --- a/appbuilder/tests/test_general_ocr.py +++ b/appbuilder/tests/test_general_ocr.py @@ -138,7 +138,7 @@ def test_tool_eval_valid(self): "authorization=bce-auth-v1%2FALTAKGa8m4qCUasgoljdEDAzLm%2F2024-01-" \ "11T10%3A59%3A17Z%2F-1%2Fhost%2F081bf7bcccbda5207c82a4de074628b04ae" \ "857a27513734d765495f89ffa5f73" - result = self.general_ocr.tool_eval(name="general_ocr", streaming=True, url=image_url) + result = self.general_ocr.tool_eval(name="general_ocr", streaming=True, img_url=image_url) res = [item for item in result] self.assertNotEqual(len(res), 0) diff --git a/appbuilder/tests/test_object_recognize.py b/appbuilder/tests/test_object_recognize.py index 1962bb93..e022de08 100644 --- a/appbuilder/tests/test_object_recognize.py +++ b/appbuilder/tests/test_object_recognize.py @@ -138,7 +138,7 @@ def test_tool_eval_valid(self): "authorization=bce-auth-v1%2FALTAKGa8m4qCUasgoljdEDAzLm%2F2024-01-" \ "11T11%3A00%3A19Z%2F-1%2Fhost%2F2c31bf29205f61e58df661dc80af31a1dc" \ "1ba1de0a8f072bc5a87102bd32f9e3" - result = self.object_recognition.tool_eval(name="object_recognition", streaming=True, url=image_url) + result = self.object_recognition.tool_eval(name="object_recognition", streaming=True, img_url=image_url) res = [item for item in result] self.assertNotEqual(len(res), 0) diff --git a/appbuilder/tests/test_translate.py b/appbuilder/tests/test_translate.py index 461367c3..5e643b6c 100644 --- a/appbuilder/tests/test_translate.py +++ b/appbuilder/tests/test_translate.py @@ -22,7 +22,7 @@ def test_run_invalid_request(self): def test_tool_eval_valid(self): """测试 tool 方法对有效请求的处理。""" - result = self.translation.tool_eval(name="translation", 
streaming=True, q="你好", to_lang="en") + result = self.translation.tool_eval(name="translation", streaming=True, q="你好\n中国", to_lang="en") res = [item for item in result] self.assertNotEqual(len(res), 0)