diff --git a/.gitignore b/.gitignore index 548dbbfb..af8593a7 100644 --- a/.gitignore +++ b/.gitignore @@ -51,6 +51,7 @@ coverage.xml .pytest_cache/ # Unit test / production asr_test.pcm +components_error_info.txt # Translations *.mo diff --git a/appbuilder/__init__.py b/appbuilder/__init__.py index e5e4e998..7162b192 100644 --- a/appbuilder/__init__.py +++ b/appbuilder/__init__.py @@ -114,7 +114,56 @@ def get_default_header(): from .core.components.image_understand.component import ImageUnderstand from .core.components.mix_card_ocr.component import MixCardOCR -from .tests.component_test import AppbuilderTestToolEval, AutomaticTestToolEval +__COMPONENTS__ = [ + "RagWithBaiduSearchPro", + "RAGWithBaiduSearch", + "Excel2Figure", + "MRC", + "OralQueryGeneration", + "QAPairMining", + "SimilarQuestion", + "StyleWriting", + "StyleRewrite", + "TagExtraction", + "Nl2pandasComponent", + "QueryRewrite", + "DialogSummary", + "HallucinationDetection", + "Playground", + "ASR", + "GeneralOCR", + "ObjectRecognition", + "Text2Image", + "LandmarkRecognition", + "TTS", + "ExtractTableFromDoc", + "DocParser", + "DocSplitter", + "BESRetriever", + "BESVectorStoreIndex", + "BaiduVDBVectorStoreIndex", + "BaiduVDBRetriever", + "TableParams", + "Reranker", + "PPTGenerationFromInstruction", + "PPTGenerationFromPaper", + "PPTGenerationFromFile", + "DishRecognition", + "Translation", + "AnimalRecognition", + "DocCropEnhance", + "QRcodeOCR", + "TableOCR", + "DocFormatConverter", + "Embedding", + "Matching", + "NL2Sql", + "SelectTable", + "PlantRecognition", + "HandwriteOCR", + "ImageUnderstand", + "MixCardOCR", +] # NOQA from appbuilder.core.message import Message from appbuilder.core.agent import AgentRuntime @@ -149,7 +198,6 @@ def get_default_header(): __all__ = [ 'logger', - 'BadRequestException', 'ForbiddenException', 'NotFoundException', @@ -157,68 +205,13 @@ def get_default_header(): 'InternalServerErrorException', 'HTTPConnectionException', 'AppBuilderServerException', - - 'StyleWriting', - 'MRC', - 'Playground', - 'OralQueryGeneration', - 'QAPairMining', - 'SimilarQuestion', - 'IsComplexQuery', - 'QueryDecomposition', - 'TagExtraction', - 'StyleRewrite', - 'QueryRewrite', - 'DialogSummary', - 'ASR', - 'GeneralOCR', - 'ObjectRecognition', - 'Text2Image', - 'LandmarkRecognition', - 'TTS', - "ExtractTableFromDoc", - "DocParser", - "ParserConfig", - "DocSplitter", - "BESRetriever", - "BESVectorStoreIndex", - "BaiduVDBVectorStoreIndex", - "BaiduVDBRetriever", - "TableParams", - "Reranker", - "HallucinationDetection", - - 'DishRecognition', - 'Translation', - 'Message', - 'AnimalRecognition', - 'DocCropEnhance', - 'QRcodeOCR', - 'TableOCR', - - 'Embedding', - - 'Matching', - - "PlantRecognition", - "HandwriteOCR", - "ImageUnderstand", - "MixCardOCR", - - 'PPTGenerationFromInstruction', - 'PPTGenerationFromPaper', - 'PPTGenerationFromFile', - 'AppbuilderTestToolEval', 'AutomaticTestToolEval', - "get_model_list", - "AppBuilderClient", "AgentBuilder", "get_app_list", "get_all_apps", - "KnowledgeBase", "CustomProcessRule", "DocumentSource", @@ -227,12 +220,10 @@ def get_default_header(): "DocumentSeparator", "DocumentPattern", "DocumentProcessOption", - "assistant", "StreamRunContext", "AssistantEventHandler", "AssistantStreamManager", - "AppBuilderTracer", "AppbuilderInstrumentor" -] +] + __COMPONENTS__ diff --git a/appbuilder/tests/component_check.py b/appbuilder/tests/component_check.py new file mode 100644 index 00000000..64ff4b06 --- /dev/null +++ b/appbuilder/tests/component_check.py @@ -0,0 +1,245 @@ +# Copyright (c) 2024 Baidu, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +import json +import inspect +from pydantic import BaseModel +from appbuilder.utils.func_utils import Singleton +from appbuilder.utils.json_schema_to_model import json_schema_to_pydantic_model + + +class CheckInfo(BaseModel): + check_rule_name: str + check_result: bool + check_detail: str + + +class RuleBase(object): + def __init__(self): + self.invalid = False + + def check(self, component_cls) -> CheckInfo: + raise NotImplementedError + + def reset_state(self): + self.invalid = False + + +class ComponentCheckBase(metaclass=Singleton): + def __init__(self): + self.rules = {} + + def register_rule(self, rule_name: str, rule_obj: RuleBase): + if not isinstance(rule_obj, RuleBase): + raise TypeError("rule_obj must be a subclass of RuleBase") + if rule_name in self.rules: + raise ValueError(f"Rule {rule_name} already exists.") + self.rules[rule_name] = rule_obj + + def remove_rule(self, rule_name: str): + del self.rules[rule_name] + + def notify(self, component_cls) -> tuple[bool, list]: + check_pass = True + check_details = {} + reasons = [] + for rule_name, rule_obj in self.rules.items(): + res = rule_obj.check(component_cls) + check_details[rule_name] = res + if res.check_result == False: + check_pass = False + reasons.append(res.check_detail) + + if check_pass: + return True, reasons + else: + return False, reasons + +def register_component_check_rule(rule_name: str, rule_cls: RuleBase): + component_checker = ComponentCheckBase() + component_checker.register_rule(rule_name, rule_cls()) + + + +class ManifestValidRule(RuleBase): + """ + 通过尝试将component的manifest转换为pydantic模型来检查manifest是否符合规范 + """ + def __init__(self): + super().__init__() + self.rule_name = "ManifestValidRule" + + def check(self, component_cls) -> CheckInfo: + check_pass_flag = True + invalid_details = [] + + + try: + if not hasattr(component_cls, "manifests"): + raise ValueError("No manifests found") + manifests = component_cls.manifests + # NOTE(暂时检查manifest中的第一个mainfest) + if not manifests or len(manifests) == 0: + raise ValueError("No manifests found") + manifest = manifests[0] + tool_name = manifest['name'] + tool_desc = manifest['description'] + schema = manifest["parameters"] + schema["title"] = tool_name + # 第一步,将json schema转换为pydantic模型 + pydantic_model = json_schema_to_pydantic_model(schema, tool_name) + check_to_json = pydantic_model.schema_json() + json_to_dict = json.loads(check_to_json) + except Exception as e: + print(e) + check_pass_flag = False + invalid_details.append(str(e)) + + if len(invalid_details) > 0: + invalid_details = ",".join(invalid_details) + else: + invalid_details = "" + return CheckInfo( + check_rule_name=self.rule_name, + check_result=check_pass_flag, + check_detail=invalid_details) + +Data_Type = { + 'string': str, + 'integer': int, + 'object': int, + 'array': list, + 'boolean': bool, + 'null': None, +} + +class MainfestMatchToolEvalRule(RuleBase): + def __init__(self): + super().__init__() + self.rule_name = "MainfestMatchToolEvalRule" + + + def check(self, component_cls) -> CheckInfo: + check_pass_flag = True + invalid_details = [] + + try: + if not hasattr(component_cls, "manifests"): + raise ValueError("No manifests found") + manifests = component_cls.manifests + # NOTE(暂时检查manifest中的第一个mainfest) + if not manifests or len(manifests) == 0: + raise ValueError("No manifests found") + manifest = manifests[0] + properties = manifest['parameters']['properties'] + required_params = [] + anyOf = manifest['parameters'].get('anyOf', None) + if anyOf: + for anyOf_dict in anyOf: + required_params += anyOf_dict['required'] + if not anyOf: + required_params += manifest['parameters']['required'] + + + # 交互检查 + tool_eval_input_params = [] + print("required_params: {}".format(required_params)) + signature = inspect.signature(component_cls.tool_eval) + ileagal_params = [] + for param_name, param in signature.parameters.items(): + if param_name == 'kwargs' or param_name == 'args' or param_name == 'self': + continue + if param_name not in required_params: + check_pass_flag = False + ileagal_params.append(param_name) + + if len(ileagal_params) > 0: + invalid_details.append("tool_eval 参数 {} 不在 mainfest 参数列表中".format(",".join(ileagal_params))) + + ileagal_params =[] + for required_param in required_params: + if required_param not in tool_eval_input_params: + check_pass_flag = False + ileagal_params.append(required_param) + if len(ileagal_params) > 0: + invalid_details.append("mainfest 参数 {} 不在 tool_eval 参数列表中".format(",".join(ileagal_params))) + + return CheckInfo( + check_rule_name=self.rule_name, + check_result=check_pass_flag, + check_detail=",".join(invalid_details)) + + except Exception as e: + check_pass_flag = False + invalid_details.append(str(e)) + return CheckInfo( + check_rule_name=self.rule_name, + check_result=check_pass_flag, + check_detail=",".join(invalid_details)) + + + + + + +class ToolEvalInputNameRule(RuleBase): + """ + 检查tool_eval的输入参数中,是否包含系统保留的输入名称 + """ + def __init__(self): + super().__init__() + self.rule_name = 'ToolEvalInputNameRule' + self.system_input_name = [ + "_sys_name", + "_sys_origin_query", + "_sys_user_instruction", + "_sys_file_names", + "_sys_file_urls", + "_sys_current_time", + "_sys_chat_history", + "_sys_used_tool", + "_sys_uid", + "_sys_traceid", + "_sys_conversation_id", + "_sys_gateway_endpoint", + "_sys_appbuiler_token", + "_sys_debug", + "_sys_custom_variables", + "_sys_thought_model_config", + "_sys_rag_model_config", + ] + + def check(self, component_cls) -> CheckInfo: + tool_eval_signature = inspect.signature(component_cls.__init__) + params = tool_eval_signature.parameters + invalid_details = [] + check_pass_flag = True + for param_name in params: + if param_name == 'self': + continue + if param_name in self.system_input_name: + invalid_details.append(param_name) + check_pass_flag = False + + + return CheckInfo( + check_rule_name=self.rule_name, + check_result=check_pass_flag, + check_detail="以下ToolEval方法参数名称是系统保留字段,请更换:{}".format(",".join(invalid_details)) if len(invalid_details) > 0 else "") + + + +register_component_check_rule("ManifestValidRule", ManifestValidRule) +register_component_check_rule("MainfestMatchToolEvalRule", MainfestMatchToolEvalRule) +register_component_check_rule("ToolEvalInputNameRule", ToolEvalInputNameRule) \ No newline at end of file diff --git a/appbuilder/tests/component_collector.py b/appbuilder/tests/component_collector.py new file mode 100644 index 00000000..00c015db --- /dev/null +++ b/appbuilder/tests/component_collector.py @@ -0,0 +1,102 @@ +# Copyright (c) 2024 Baidu, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +import appbuilder + +# SKIP名单中的组件无需检查,直接跳过 +SKIP_COMPONENTS = [ +] + +# 白名单中的组件因历史原因,检查失败,但可以正常使用,因此加入白名单 +COMPONENT_WHITE_LIST = [ + "RagWithBaiduSearchPro", + "RAGWithBaiduSearch", + "Excel2Figure", + "MRC", + "OralQueryGeneration", + "QAPairMining", + "SimilarQuestion", + "StyleWriting", + "StyleRewrite", + "TagExtraction", + "Nl2pandasComponent", + "QueryRewrite", + "DialogSummary", + "HallucinationDetection", + "Playground", + "ASR", + "GeneralOCR", + "ObjectRecognition", + "Text2Image", + "LandmarkRecognition", + "TTS", + "ExtractTableFromDoc", + "DocParser", + "DocSplitter", + "BESRetriever", + "BESVectorStoreIndex", + "BaiduVDBVectorStoreIndex", + "BaiduVDBRetriever", + "TableParams", + "Reranker", + "PPTGenerationFromInstruction", + "PPTGenerationFromPaper", + "PPTGenerationFromFile", + "DishRecognition", + "Translation", + "AnimalRecognition", + "DocCropEnhance", + "QRcodeOCR", + "TableOCR", + "DocFormatConverter", + "Embedding", + "Matching", + "NL2Sql", + "SelectTable", + "PlantRecognition", + "HandwriteOCR", + "ImageUnderstand", + "MixCardOCR", +] + + +def get_component_white_list(): + return COMPONENT_WHITE_LIST + +def get_all_components(): + from appbuilder import __COMPONENTS__ + + components = {} + for component in __COMPONENTS__: + if component in SKIP_COMPONENTS: + continue + + try: + component_obj = eval("appbuilder."+component) + components[component]= { + "obj": component_obj, + "import_error": "" + } + except Exception as e: + print("Component: {} import with error: {}".format(component, str(e))) + components[component]= { + "obj": None, + "import_error": str(e) + } + + return components + +if __name__ == '__main__': + all_components = get_all_components() + print(all_components) \ No newline at end of file diff --git a/appbuilder/tests/component_test.py b/appbuilder/tests/component_test.py deleted file mode 100644 index 7f4a141e..00000000 --- a/appbuilder/tests/component_test.py +++ /dev/null @@ -1,296 +0,0 @@ -import requests -import types -import re -import inspect - -from typing import TypeVar, Generic, Union, Type -from appbuilder.core._exception import * -from unittest.mock import Mock -from appbuilder.core import components -from appbuilder.core._session import InnerSession - - -Data_Type = { - 'string': str, - 'integer': int, - 'object': int, - 'array': list, - 'boolean': bool, - 'null': None, -} - -class AppbuilderTestToolEval: - """ - 功能:Components组件模拟post本地运行。 - - 使用方法: - - ```python - # 实例化一个 - image_understand = appbuilder.ImageUnderstand() - - # 设计一个符合规范的tool_eval input(dict数据类型) - tool_eval_input = { - 'streaming': True, - 'traceid': 'traceid', - 'name':"image_understand", - 'img_url':'img_url_str', - 'origin_query':"" - } - - # 设计一个组件API接口预期的response - mock_response_data = { - 'result': {'task_id': '1821485837570181996'}, - 'log_id': 1821485837570181996, - } - mock_response = Mock() - mock_response.status_code = 200 - mock_response.headers = {'Content-Type': 'application/json'} - def mock_json(): - return mock_response_data - mock_response.json = mock_json - - # 实例化一个AppbuilderTestToolEval对象,实现组件本地的自动化测试 - appbuilder.AppbuilderTestToolEval(appbuilder_components=image_understand, - tool_eval_input=tool_eval_input, - response=mock_response) - ``` - """ - def __init__(self, appbuilder_components:components, tool_eval_input:dict, response:requests.Response): - """ - 初始化函数。 - - Args: - appbuilder_components (components): 应用构建器组件对象。 - tool_eval_input (dict): tool_eval的传入参数。 - response (dict): api预期的response返回值。 - - Returns: - None - """ - self.component = appbuilder_components - self.tool_eval_input = tool_eval_input - self.response = response - self.test_manifests() - self.test_tool_eval_input() - self.test_tool_eval_generator() - if hasattr(self.component, '__module__'): - module_name = self.component.__module__ - if re.match(r'appbuilder\.', module_name): - self.test_tool_eval_reponse_raise() - self.test_tool_eval_text_str() - - def test_manifests(self): - """ - 校验组件成员变量manifests是否符合规范。 - - Args: - 无参数。 - - Returns: - 无返回值。 - Raises: - AppbuilderBuildexException: 校验不通过时抛出异常。 - """ - manifests = self.component.manifests - try: - assert isinstance(manifests, list) - assert len(manifests) > 0 - assert isinstance(manifests[0],dict) - assert isinstance(manifests[0]['name'], str) - assert isinstance(manifests[0]['description'], str) - assert isinstance(manifests[0]['parameters'], dict) - except Exception as e: - raise AppbuilderBuildexException(f'请检查{self.component}组件是否存在成员变量manifests或manifests成员变量定义规范, 错误信息:{e}') - - def test_tool_eval_input(self): - """ - 校验tool_eval的传入参数是否合法。 - - Args: - 无参数。 - - Returns: - 无返回值。 - - Raises: - AppbuilderBuildexException: 校验不通过时抛出异常。 - - """ - if not self.tool_eval_input.get('streaming',None): - raise AppbuilderBuildexException(f'请检查{self.component}组件tool_eval的传入参数是否定义streaming') - if hasattr(self.component, '__module__'): - module_name = self.component.__module__ - if re.match(r'appbuilder\.', module_name): - if not self.tool_eval_input.get('traceid',None): - raise AppbuilderBuildexException(f'请检查{self.component}组件tool_eval的传入参数是否有traceid') - try: - manifests = self.component.manifests[0] - parameters = manifests['parameters'] - properties = parameters['properties'] - except: - raise AppbuilderBuildexException(f'请检查{self.component}组件是否存在成员变量manifests或manifests成员变量定义规范') - anyOf = parameters.get('anyOf',None) - if anyOf: - anyOf_test = False - for anyOf_requried_dict in anyOf: - anyOf_requried = anyOf_requried_dict.get('required',None) - if anyOf_requried: - success_number = 0 - for anyOf_requried_data in anyOf_requried: - try: - input_data = self.tool_eval_input[anyOf_requried_data] - input_data_type = Data_Type[properties[anyOf_requried_data]['type']] - if anyOf_requried_data in self.tool_eval_input and isinstance(input_data, input_data_type): - success_number += 1 - except: - pass - if success_number == len(anyOf_requried): - anyOf_test = True - if not anyOf_test: - raise AppbuilderBuildexException(f'请检查{self.component}组件tool_eval的传入参数是否正确或manifests的参数定义是否正确') - - if not anyOf: - un_anyOf_test = False - requried = parameters.get('required',None) - if requried: - success_number = 0 - for requried_data in requried: - try: - input_data = self.tool_eval_input[requried_data] - input_data_type = Data_Type[properties[requried_data]['type']] - if requried_data in self.tool_eval_input and isinstance(input_data, input_data_type): - success_number += 1 - except: - pass - if success_number == len(requried): - un_anyOf_test = True - if not un_anyOf_test: - raise AppbuilderBuildexException(f'请检查{self.component}组件tool_eval的传入参数是否正确或manifests的参数定义是否正确') - - def test_tool_eval_reponse_raise(self): - """ - Args: - 无参数 - - Returns: - 无返回值 - - Raises: - AppbuilderBuildexException: 如果响应头状态码对应的异常类型与捕获到的异常类型不一致,则抛出此异常。 - - 功能:测试tool_eval方法在不同响应头状态码下的异常抛出情况。 - - 首先,设置响应头状态码为bad_request,并模拟InnerSession.post方法的返回值。 - 然后,定义一个状态码与异常类型的映射字典test_status_code_dict,用于测试不同状态码下抛出的异常类型是否正确。 - 接着,遍历test_status_code_dict字典,将状态码和异常类型分别赋值给self.response.status_code和error变量,并重新模拟InnerSession.post方法的返回值。 - 在每次循环中,调用self.component.tool_eval方法,并捕获可能抛出的异常。 - 如果捕获到的异常类型与test_status_code_dict字典中对应状态码的异常类型一致,则继续下一次循环; - 否则,抛出AppbuilderBuildexException异常,提示用户检查self.component组件tool_eval方法的response返回值是否添加了check_response_header检测。 - """ - # test_response_head_status - self.response.status_code = requests.codes.bad_request - InnerSession.post = Mock(return_value=self.response) - test_status_code_dict = { - requests.codes.bad_request: BadRequestException, - requests.codes.forbidden: ForbiddenException, - requests.codes.not_found: ForbiddenException, - requests.codes.precondition_required: PreconditionFailedException, - requests.codes.internal_server_error: InternalServerErrorException - } - for status_code,error in test_status_code_dict.items(): - self.response.status_code = status_code - InnerSession.post = Mock(return_value=self.response) - try: - self.component.tool_eval(**self.tool_eval_input) - except Exception as e: - if isinstance(e,error): - raise AppbuilderBuildexException(f'请检查{self.component}组件tool_eval的response返回值是否添加check_response_header检测') - - def test_tool_eval_generator(self): - """ - 测试组件tool_eval方法返回是否为生成器 - - Args: - 无 - - Returns: - 无 - - Raises: - AppbuilderBuildexException: 如果组件tool_eval的返回值不为生成器时抛出异常 - """ - self.response.status_code = requests.codes.ok - InnerSession.post = Mock(return_value=self.response) - result_generator = self.component.tool_eval(**self.tool_eval_input) - if not result_generator: - raise AppbuilderBuildexException(f'请检查{self.component}组件tool_eval的返回值是否为生成器') - if not isinstance(result_generator, types.GeneratorType): - raise AppbuilderBuildexException(f'请检查{self.component}组件tool_eval的返回值是否为生成器') - - def test_tool_eval_text_str(self): - """ - 测试tool_eval方法返回值的文本是否为字符串类型 - - Args: - 无 - - Returns: - 无返回值,该函数主要进行断言测试 - - Raises: - AppbuilderBuildexException: 当tool_eval方法返回的文本不是字符串类型时抛出异常 - """ - self.response.status_code = requests.codes.ok - InnerSession.post = Mock(return_value=self.response) - result_generator = self.component.tool_eval(**self.tool_eval_input) - for res in result_generator: - if not isinstance(res.get("text",""),str): - raise AppbuilderBuildexException(f'请检查{self.component}组件tool_eval的返回值是否为字符串') - -class AutomaticTestToolEval: - def __init__(self, appbuilder_components:components): - self.components = appbuilder_components - self.test_input() - - def test_input(self): - manifest = self.components.manifests[0] - properties = manifest['parameters']['properties'] - required_params = [] - anyOf = manifest['parameters'].get('anyOf', None) - if anyOf: - for anyOf_dict in anyOf: - required_params += anyOf_dict['required'] - if not anyOf: - required_params += manifest['parameters']['required'] - required_param_dict = { - 'name':str, - 'streaming':bool - } - - for param in required_params: - required_param_dict[param] = Data_Type[properties[param]['type']] - required_params = [] - for param in required_param_dict.keys(): - required_params.append(param) - - # 交互检查 - tool_eval_input_params = [] - signature = inspect.signature(self.components.tool_eval) - for param_name, param in signature.parameters.items(): - if param_name == 'kwargs': - continue - if param_name in required_params: - if required_param_dict[param_name] == param.annotation: - tool_eval_input_params.append(param_name) - else: - raise AppbuilderBuildexException(f'请检查tool_eval的传入参数{param_name}是否符合成员变量manifest的参数类型要求') - else: - raise AppbuilderBuildexException(f'请检查tool_eval的传入参数{param_name}是否在成员变量manifest要求内') - - for required_param in required_params: - if required_param not in tool_eval_input_params: - raise AppbuilderBuildexException(f'请检查成员变量manifest要求的tool_eval的传入参数{required_param}是否在其中') - - - \ No newline at end of file diff --git a/appbuilder/tests/run_python_test.sh b/appbuilder/tests/run_python_test.sh index 2f0e6de4..00f082fc 100644 --- a/appbuilder/tests/run_python_test.sh +++ b/appbuilder/tests/run_python_test.sh @@ -97,13 +97,14 @@ echo "单测运行结果: $run_result" echo "单测覆盖率结果: $cover_result" echo "--------------------------" +echo "--------------------------" +echo "Components组件检查规范性检测结果: " +python3 print_components_error_info.py +echo "--------------------------" + # 若单测失败,则退出 if [ $run_result -ne 0 ]; then echo "单测运行失败,请检查错误日志,修复单测后重试" && exit 1; fi if [ $cover_result -ne 0 ]; then echo "增量代码的单元测试覆盖率低于90%,请完善单元测试后重试" && exit 1; fi -echo "--------------------------" -echo "Components组件检查规范性检测结果: " -python3 print_components_error_info.py -echo "--------------------------" \ No newline at end of file diff --git a/appbuilder/tests/test_all_components.py b/appbuilder/tests/test_all_components.py new file mode 100644 index 00000000..49b1b0af --- /dev/null +++ b/appbuilder/tests/test_all_components.py @@ -0,0 +1,96 @@ +# Copyright (c) 2024 Baidu, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +import unittest +import inspect +import appbuilder +import importlib.util +import numpy as np +import pandas as pd + +from appbuilder.core.component import Component +from appbuilder.core.components.llms.base import CompletionBaseComponent +from appbuilder.core._exception import AppbuilderBuildexException +from component_collector import get_all_components, get_component_white_list +from appbuilder.tests.component_check import ComponentCheckBase + + +def write_error_data(error_df,error_stats): + txt_file_path = 'components_error_info.txt' + with open(txt_file_path, 'w') as file: + file.write("Component Name\tError Message\n") + for _, row in error_df.iterrows(): + file.write(f"{row['Component Name']}\t{row['Error Message']}\n") + file.write("\n错误统计信息:\n") + for error, count in error_stats.items(): + file.write(f"错误信息: {error}, 出现次数: {count}\n") + print(f"\n错误信息已写入: {txt_file_path}") + +# @unittest.skipUnless(os.getenv("TEST_CASE", "UNKNOWN") == "CPU_PARALLEL", "") +class TestComponentManifestsAndToolEval(unittest.TestCase): + def setUp(self) -> None: + self.all_components = get_all_components() + self.whitelist_components = get_component_white_list() + self.component_check_base = ComponentCheckBase() + + def test_component(self): + error_data = [] + error_stats ={} + for name, import_res in self.all_components.items(): + + if import_res["import_error"] != "": + error_data.append({"Component Name": name, "Error Message": str(e)}) + print("组件名称:{} 错误信息:{}".format(name, import_res["import_error"])) + continue + + component_obj = import_res["obj"] + try: + pass_check, reasons = self.component_check_base.notify(component_obj) + reasons = list(set(reasons)) + if not pass_check: + error_data.append({"Component Name": name, "Error Message": ", ".join(reasons)}) + print("组件名称:{} 错误信息:{}".format(name, ", ".join(reasons))) + except Exception as e: + error_data.append({"Component Name": name, "Error Message": str(e)}) + print("组件名称:{} 错误信息:{}".format(name, str(e))) + + + error_df = pd.DataFrame(error_data) if len(error_data) > 0 else None + + if error_df is not None: + print("\n错误信息表格:") + print(error_df) + # 使用 NumPy 进行统计 + unique_errors, counts = np.unique(error_df["Error Message"], return_counts=True) + error_stats = dict(zip(unique_errors, counts)) + print("\n错误统计信息:") + for error, count in error_stats.items(): + print(f"错误信息: {error}, 出现次数: {count}") + # 将报错信息写入文件 + write_error_data(error_df, error_stats) + + # 判断报错组件是否位于白名单中 + component_names = error_df["Component Name"].tolist() + for component_name in component_names: + if component_name in self.whitelist_components: + print("{}zu白名单中,暂时忽略报错。".format(component_name)) + else: + raise AppbuilderBuildexException(f"组件 {component_name} 未在白名单中,请检查是否需要添加到白名单。") + + else: + print("\n所有组件测试通过,无错误信息。") + + +if __name__ == '__main__': + unittest.main() \ No newline at end of file diff --git a/appbuilder/tests/test_core_commponents_tool_eval.py b/appbuilder/tests/test_core_commponents_tool_eval.py deleted file mode 100644 index 8605ec7e..00000000 --- a/appbuilder/tests/test_core_commponents_tool_eval.py +++ /dev/null @@ -1,184 +0,0 @@ -# Copyright (c) 2024 Baidu, Inc. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import os -import unittest -import inspect -import appbuilder -import importlib.util -import numpy as np -import pandas as pd - -from appbuilder.core.component import Component -from appbuilder.core.components.llms.base import CompletionBaseComponent -from appbuilder import AutomaticTestToolEval -from appbuilder.core._exception import AppbuilderBuildexException - -def check_ancestor(cls): - parent_cls = Component - excluded_classes = ('Component', 'MatchingBaseComponent', 'EmbeddingBaseComponent', 'CompletionBaseComponent') - if cls.__name__ in excluded_classes: - return False - if issubclass(cls, CompletionBaseComponent): - return False - if issubclass(cls, parent_cls): - if parent_cls in excluded_classes: - return False - return True - for base in cls.__bases__: - if check_ancestor(base): - return True - return False - - -def find_tool_eval_components(): - current_file_path = os.path.abspath(__file__) - print(current_file_path) - components = [] - added_components = set() - base_path = current_file_path.split('/') - base_path = base_path[:-2]+['core', 'components'] - base_path = '/'.join(base_path) - print(base_path) - - for root, _, files in os.walk(base_path): - for file in files: - if file.endswith(".py"): - module_path = os.path.join(root, file) - module_name = module_path.replace(base_path, '').replace('/', '.').replace('\\', '.').strip('.') - - # 动态加载模块 - spec = importlib.util.spec_from_file_location(module_name, module_path) - if spec is None: - continue - module = importlib.util.module_from_spec(spec) - try: - spec.loader.exec_module(module) - except Exception as e: - continue - - # 查找继承自 Component 的类 - for name, obj in inspect.getmembers(module, inspect.isclass): - has_tool_eval = 'tool_eval' in obj.__dict__ and callable(getattr(obj, 'tool_eval', None)) - if has_tool_eval and obj.__name__ not in added_components and check_ancestor(obj): - added_components.add(obj.__name__) - components.append((name, obj)) - - return components - - -def read_whitelist_components(): - with open('whitelist_components.txt', 'r') as f: - lines = [line.strip() for line in f] - return lines - - -def write_error_data(error_df,error_stats): - txt_file_path = 'components_error_info.txt' - with open(txt_file_path, 'w') as file: - file.write("Component Name\tError Message\n") - for _, row in error_df.iterrows(): - file.write(f"{row['Component Name']}\t{row['Error Message']}\n") - file.write("\n错误统计信息:\n") - for error, count in error_stats.items(): - file.write(f"错误信息: {error}, 出现次数: {count}\n") - print(f"\n错误信息已写入: {txt_file_path}") - -@unittest.skipUnless(os.getenv("TEST_CASE", "UNKNOWN") == "CPU_PARALLEL", "") -class TestComponentManifestsAndToolEval(unittest.TestCase): - def setUp(self) -> None: - self.tool_eval_components = find_tool_eval_components() - self.whitelist_components = read_whitelist_components() - - def test_manifests(self): - """ - 要求必填,格式: list[dict],dict字段为 - * "name":str,要求不重复 - * "description":str,对于组件tool_eval函数功能的描述 - * "parameters":json_schema,对于tool_eval函数入参的描述,json_schema格式要求见https://json-schema.org/understanding-json-schema - """ - print("完成manifests测试的组件:") - for name, cls in self.tool_eval_components: - init_signature = inspect.signature(cls.__init__) - params = init_signature.parameters - mock_args = {} - for parameter_name, param in params.items(): - # 跳过 'self' 参数 - if parameter_name == 'self': - continue - if parameter_name == 'model' or name == 'model_name': - mock_args[parameter_name] = appbuilder.get_model_list()[0] - app = cls(**mock_args) - manifests = app.manifests - - assert isinstance(manifests, list) - assert len(manifests) > 0 - assert isinstance(manifests[0],dict) - assert isinstance(manifests[0]['name'], str) - assert isinstance(manifests[0]['description'], str) - assert isinstance(manifests[0]['parameters'], dict) - print("组件名称:{}".format(name)) - - def test_tool_eval(self): - """ - 测试tool_eval组件,收集报错信息,生成并存储报错信息表格,并进行统计和可视化。 - """ - print("完成tool_eval测试的组件:") - error_data = [] - - for name, cls in self.tool_eval_components: - init_signature = inspect.signature(cls.__init__) - params = init_signature.parameters - mock_args = {} - for parameter_name, param in params.items(): - # 跳过 'self' 参数 - if parameter_name == 'self': - continue - if parameter_name == 'model' or name == 'model_name': - mock_args[parameter_name] = appbuilder.get_model_list()[0] - app = cls(**mock_args) - try: - AutomaticTestToolEval(app) - print("组件名称:{} 通过测试".format(name)) - except Exception as e: - error_data.append({"Component Name": name, "Error Message": str(e)}) - print("组件名称:{} 错误信息:{}".format(name, e)) - - # 将错误信息表格存储在本地变量中 - error_df = pd.DataFrame(error_data) if error_data else None - - if error_df is not None: - print("\n错误信息表格:") - print(error_df) - # 使用 NumPy 进行统计 - unique_errors, counts = np.unique(error_df["Error Message"], return_counts=True) - error_stats = dict(zip(unique_errors, counts)) - print("\n错误统计信息:") - for error, count in error_stats.items(): - print(f"错误信息: {error}, 出现次数: {count}") - else: - print("\n所有组件测试通过,无错误信息。") - - # 将报错信息写入文件 - write_error_data(error_df, error_stats) - - # 判断报错组件是否位于白名单中 - component_names = error_df["Component Name"].tolist() - for component_name in component_names: - if component_name in self.whitelist_components: - print("{}zu白名单中,暂时忽略报错。".format(component_name)) - else: - raise AppbuilderBuildexException(f"组件 {component_name} 未在白名单中,请检查是否需要添加到白名单。") - -if __name__ == '__main__': - unittest.main() diff --git a/appbuilder/tests/test_core_components.py b/appbuilder/tests/test_core_components.py deleted file mode 100644 index 3a7b80f2..00000000 --- a/appbuilder/tests/test_core_components.py +++ /dev/null @@ -1,119 +0,0 @@ -# Copyright (c) 2024 Baidu, Inc. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import os -import unittest -import json -import copy -import appbuilder -import collections.abc - - -from appbuilder.core.components.asr.component import ASR -from appbuilder.core.components.dish_recognize.component import DishRecognition -from appbuilder.core.components.dish_recognize.model import DishRecognitionRequest -from appbuilder.core.message import Message -from appbuilder.core.components.llms.base import LLMMessage -from appbuilder.core._exception import AppBuilderServerException,InvalidRequestArgumentError - - -# 创建ShortSpeechRecognitionRequest对象 -class Request: - def __init__(self, format: str, rate: int, dev_pid: int, cuid: str, speech: bytes): - """ - 初始化函数,用于设置参数。 - - Args: - format (str): 音频格式,例如 "wav"。 - rate (int): 采样率,单位是 Hz。 - dev_pid (int): 设备 PID。 - cuid (str): 客户端 ID。 - speech (bytes): 语音字节流。 - - Returns: - None. 无返回值。 - """ - self.format = format - self.rate = rate - self.dev_pid = dev_pid - self.cuid = cuid - self.speech = speech - -# 创建一个response类,模拟requests.Response -class Response: - def __init__(self, status_code, headers, text): - self.status_code = status_code - self.headers = headers - self.text = text - - def json(self): - return json.loads(self.text) - -@unittest.skipUnless(os.getenv("TEST_CASE", "UNKNOWN") == "CPU_PARALLEL", "") -class TestCoreComponents(unittest.TestCase): - def test_asr_component_ASR(self): - # test_retry != self.http_client.retry.total - asr=ASR() - asr.http_client.retry.total=1 - request = Request( - format="wav", rate=16000, dev_pid=15372, cuid="test", speech=b"") - with self.assertRaises(AppBuilderServerException): - asr._recognize(request=request) - - def test_dish_recognize_component(self): - dr=DishRecognition() - request=DishRecognitionRequest(image=b"test") - dr.http_client.retry.total=1 - with self.assertRaises(AppBuilderServerException): - dr._recognize(request=request) - - - def test_llms_base(self): - # test LLMMessage deepcopy - lm=LLMMessage() - lm.__dict__={ - "content":collections.abc.Iterator, - 'test':'test' - } - new_lm=copy.deepcopy(lm) - - def test_components_raise(self): - # test ASR - asr=appbuilder.ASR() - tool=asr.tool_eval(name='test',streaming=False,file_urls={'test_1':'test'},file_name='test') - with self.assertRaises(InvalidRequestArgumentError): - next(tool) - - # test GeneralOCR - go=appbuilder.GeneralOCR() - tool=go.tool_eval(name='test',streaming=False,file_urls={'test_1':'test'},img_name='test') - with self.assertRaises(InvalidRequestArgumentError): - next(tool) - - # test HandwriteOCR - hwo=appbuilder.HandwriteOCR() - from appbuilder.core.components.handwrite_ocr.model import HandwriteOCRRequest - hwor=HandwriteOCRRequest() - with self.assertRaises(ValueError): - hwo._recognize(request=hwor) - - # test_llms_base_ResultProcessor - from appbuilder.core.components.llms.base import ResultProcessor,CompletionBaseComponent - with self.assertRaises(TypeError): - ResultProcessor.process(key='test',result_list=[]) - - - - -if __name__ == "__main__": - unittest.main() \ No newline at end of file diff --git a/appbuilder/tests/test_core_componentstest.py b/appbuilder/tests/test_core_componentstest.py deleted file mode 100644 index 1426d128..00000000 --- a/appbuilder/tests/test_core_componentstest.py +++ /dev/null @@ -1,104 +0,0 @@ -# Copyright (c) 2024 Baidu, Inc. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import os -import unittest -import appbuilder -import requests - -from unittest.mock import Mock -from appbuilder.core._exception import AppbuilderBuildexException - - -@unittest.skipUnless(os.getenv("TEST_CASE", "UNKNOWN") == "CPU_PARALLEL", "") -class TestComponents(unittest.TestCase): - def setUp(self): - self.image_understand = appbuilder.ImageUnderstand() - self.table_ocr = appbuilder.TableOCR() - self.img_url = 'img_url' - - - def test_components_raise(self): - tool_eval_input = {} - response = requests.Response() - # 检查tool_eval必须有streaming参数 - with self.assertRaises(AppbuilderBuildexException) as e: - appbuilder.AppbuilderTestToolEval(appbuilder_components=self.image_understand, - tool_eval_input=tool_eval_input, - response=response) - exception = e.exception - self.assertIn('是否定义streaming', str(exception)) - - # 检查组件tool_eval的传入参数是否有traceid - tool_eval_input = { - 'streaming': True, - } - with self.assertRaises(AppbuilderBuildexException) as e: - appbuilder.AppbuilderTestToolEval(appbuilder_components=self.image_understand, - tool_eval_input=tool_eval_input, - response=response) - exception = e.exception - self.assertIn('传入参数是否有traceid', str(exception)) - - # 检查组件tool_eval的传入参数是否正确或manifests的参数定义是否正确(required为anyOf) - tool_eval_input = { - 'streaming': True, - 'traceid': 'traceid', - } - with self.assertRaises(AppbuilderBuildexException) as e: - appbuilder.AppbuilderTestToolEval(appbuilder_components=self.image_understand, - tool_eval_input=tool_eval_input, - response=response) - exception = e.exception - self.assertIn('组件tool_eval的传入参数是否正确或manifests的参数定义是否正确', str(exception)) - - # 检查组件tool_eval的传入参数是否正确或manifests的参数定义是否正确(required不为anyOf) - tool_eval_input = { - 'streaming': True, - 'traceid': 'traceid', - } - with self.assertRaises(AppbuilderBuildexException) as e: - appbuilder.AppbuilderTestToolEval(appbuilder_components=self.table_ocr, - tool_eval_input=tool_eval_input, - response=response) - exception = e.exception - self.assertIn('组件tool_eval的传入参数是否正确或manifests的参数定义是否正确', str(exception)) - - def test_components_tool_eval_image_understand(self): - mock_response_data = { - 'result': {'task_id': '1821485837570181996'}, - 'log_id': 1821485837570181996, - } - mock_response = Mock() - mock_response.status_code = 200 - mock_response.headers = {'Content-Type': 'application/json'} - def mock_json(): - return mock_response_data - mock_response.json = mock_json - - tool_eval_input = { - 'streaming': True, - 'traceid': 'traceid', - 'name':"image_understand", - 'img_url':self.img_url, - 'origin_query':"" - } - - appbuilder.AppbuilderTestToolEval(appbuilder_components=self.image_understand, - tool_eval_input=tool_eval_input, - response=mock_response) - - -if __name__ == "__main__": - unittest.main() - \ No newline at end of file diff --git a/appbuilder/utils/func_utils.py b/appbuilder/utils/func_utils.py index 52d1425b..91abd8fa 100644 --- a/appbuilder/utils/func_utils.py +++ b/appbuilder/utils/func_utils.py @@ -39,4 +39,13 @@ def new_func(*args, **kwargs): warnings.simplefilter('default', DeprecationWarning) # reset filter return func(*args, **kwargs) return new_func - return decorator \ No newline at end of file + return decorator + +class Singleton(type): + _instances = {} + + def __call__(cls, *args, **kwargs): + if cls not in cls._instances: + cls._instances[cls] = super( + Singleton, cls).__call__(*args, **kwargs) + return cls._instances[cls] \ No newline at end of file