From 909886c2bc30d1363ecf3b4b7cc85893c87c77ac Mon Sep 17 00:00:00 2001 From: Chengmo Date: Tue, 11 Jun 2024 20:29:21 +0800 Subject: [PATCH] revert oral generation (#363) --- .../llms/oral_query_generation/README.md | 65 ++---- .../llms/oral_query_generation/component.py | 221 ++---------------- .../tests/test_oral_query_generation.py | 125 ---------- .../test_qa_llm_oral_query_generation.py | 2 +- 4 files changed, 43 insertions(+), 370 deletions(-) delete mode 100644 appbuilder/tests/test_oral_query_generation.py diff --git a/appbuilder/core/components/llms/oral_query_generation/README.md b/appbuilder/core/components/llms/oral_query_generation/README.md index c82b5038..714d0bbf 100644 --- a/appbuilder/core/components/llms/oral_query_generation/README.md +++ b/appbuilder/core/components/llms/oral_query_generation/README.md @@ -1,16 +1,16 @@ -# 口语化Query生成(Oral Query Generation) +# 口语化Query生成(OralQueryGeneration) ## 简介 -口语化Query生成组件(Oral Query Generation)可以基于输入文本生成与文档内容相关的Query。 +口语化Query生成组件(OralQueryGeneration)可以基于输入文本生成与文档内容相关的Query。 ### 功能介绍 基于输入文本生成与文档内容相关的Query。 ### 特色优势 -生成的query划分为问题和短语两种类型,可分别用于不同场景。 +生成高质量Query。 ### 应用场景 -可用于推荐问题生成、标签生成、文档索引增强等。 +可用于文档生成推荐问题、文档索引增强等。 ## 基本用法 ### 快速开始 @@ -21,17 +21,13 @@ import appbuilder # 请前往千帆AppBuilder官网创建密钥,流程详见:https://cloud.baidu.com/doc/AppBuilder/s/Olq6grrt6#1%E3%80%81%E5%88%9B%E5%BB%BA%E5%AF%86%E9%92%A5 os.environ["APPBUILDER_TOKEN"] = "..." - text = ('文档标题:在OPPO Reno5上使用视频超级防抖\n' '文档摘要:OPPO Reno5上的视频超级防抖,视频超级防抖3.0,多代视频防抖算法积累,这一代依旧超级防抖超级稳。 开启视频超级' '防抖 开启路径:打开「相机 > 视频 > 点击屏幕上方的“超级防抖”标识」 后置视频同时支持超级防抖和超级防抖Pro功能,开启超级' '防抖后手机屏幕将出现超级防抖Pro开关,点击即可开启或关闭。 除此之外,前置视频同样加持防抖算法,边走边拍也能稳定聚焦脸部' ',实时视频分享您的生活。') - -#! 该组件推荐使用ERNIE Speed-AppBuilder模型。 oral_query_generation = appbuilder.OralQueryGeneration(model='ERNIE Speed-AppBuilder') -result = oral_query_generation(appbuilder.Message(text), query_type='全部', output_format='str') - +result = oral_query_generation(appbuilder.Message(text)) print(result) ``` @@ -47,9 +43,9 @@ os.environ["APPBUILDER_TOKEN"] = "bce-YOURTOKEN" | 参数名称 | 参数类型 | 是否必须 | 描述 | 示例值 | | ------- | ------- | -------- | -------- | -------- | -| `model` | str | 是 | 模型名称,用于指定要使用的千帆模型。推荐使用ERNIE Speed-AppBuilder模型。 | ERNIE Speed-AppBuilder | -| `secret_key` | str | 否 | 用户鉴权token,默认从环境变量中获取: `os.getenv("APPBUILDER_TOKEN", "")` | bce-v3/XXX | -| `gateway` | str | 否 | 后端网关服务地址,默认从环境变量中获取: `os.getenv("GATEWAY_URL", "")` | https://appbuilder.baidu.com | +| `model` | str | 否 | 模型名称,用于指定要使用的千帆模型。 | ERNIE Speed-AppBuilder | +| `secret_key` | str | 否 | 用户鉴权token,默认从环境变量中获取: `os.getenv("APPBUILDER_TOKEN", "")` | | +| `gateway` | str | 否 | 后端网关服务地址,默认从环境变量中获取: `os.getenv("GATEWAY_URL", "")` | | | `lazy_certification` | bool | 否 | 延迟认证,为True时在第一次运行时认证。默认为False。 | False | ### 调用参数 @@ -57,11 +53,8 @@ os.environ["APPBUILDER_TOKEN"] = "bce-YOURTOKEN" | 参数名称 | 参数类型 | 是否必须 | 描述 | 示例值 | | ------- | ------- | -------- | -------- | -------- | | `message` | obj | 是 | 输入消息,用于模型的主要输入内容。 | Message(content='...') | -| `query_type` | str | 否 | 待生成的query类型,包括问题、短语和全部(问题+短语)。默认为全部。 | 全部 | -| `output_format` | str | 否 | 输出格式,包括json和str。默认为str。 | str | | `stream` | bool | 否 | 指定是否以流式形式返回响应。默认为 False。 | False | | `temperature` | float | 否 | 模型配置的温度参数,用于调整模型的生成概率。取值范围为 0.0 到 1.0,其中较低的值使生成更确定性,较高的值使生成更多样性。默认值为 1e-10。 | 0.1 | -| `top_p` | float | 否 | 影响输出文本的多样性,取值越大,生成文本的多样性越强。取值范围为 0.0 到 1.0,其中较低的值使生成更确定性,较高的值使生成更多样性。默认值为 0.0。 | 0.0 | ### 响应参数 | 参数名称 | 参数类型 | 描述 | 示例值 | @@ -70,38 +63,26 @@ os.environ["APPBUILDER_TOKEN"] = "bce-YOURTOKEN" ### 响应示例 ``` -Message(name=msg, content=1. OPPO Reno5上有什么特殊的功能? -2. 视频超级防抖是什么? -3. 视频超级防抖有什么作用? -4. 如何在OPPO Reno5上开启视频超级防抖? -5. 视频超级防抖Pro是什么? -6. 开启视频超级防抖后,屏幕上会出现什么? -7. 前置视频有防抖算法吗? -8. OPPO Reno5上的视频超级防抖 -9. 视频超级防抖3.0 -10. 多代视频防抖算法积累的作用 -11. 开启视频超级防抖的方法 -12. 视频超级防抖Pro的功能 -13. 开启视频超级防抖后,屏幕上会出现的东西 -14. 前置视频防抖算法的作用, mtype=dict, extra={}) +Message(name=msg, content=1、OPPO Reno5上的超级防抖 +2、怎么开启OPPO Reno5的超级防抖 +3、OPPO Reno5的超级防抖Pro +4、前置视频上的超级防抖 +5、跑步时拍照OPPO Reno5的超级防抖 +6、OPPO Reno5的多代视频防抖 +7、OPPO Reno5的超级防抖算法 +8、OPPO Reno5的超级防抖怎么用 +9、OPPO Reno5的超级防抖如何开启 +10、OPPO Reno5的超级防抖Pro如何使用, mtype=dict, extra={}) ``` ## 高级用法 ## 更新记录和贡献 -### 2024.5.22 -#### [Updated] -- 升级能力,主要升级内容如下: - - 生成的query要求能够使用输入文本进行回答。 - - 生成的query划分为问题和短语类型。 - - 生成的query数量不再限制为10个。 -- 在调用组件时,支持输出问题、短语或全部(问题 + 短语);支持输出格式为json或者str(兼容之前版本的输出格式)。 +### 2023.12.07 +#### [Added] +- 增加口语化Query生成能力。 +- 增加口语化Query生成单元测试。 ### 2024.1.24 #### [Updated] -- 更新README。 - -### 2023.12.07 -#### [Added] -- 增加口语化Query生成组件。 -- 增加口语化Query生成组件单元测试。 \ No newline at end of file +- 更新README。 \ No newline at end of file diff --git a/appbuilder/core/components/llms/oral_query_generation/component.py b/appbuilder/core/components/llms/oral_query_generation/component.py index 6dc94a1d..8e45252a 100644 --- a/appbuilder/core/components/llms/oral_query_generation/component.py +++ b/appbuilder/core/components/llms/oral_query_generation/component.py @@ -15,84 +15,30 @@ """ -import json -import re - from pydantic import BaseModel, Field from typing import Optional -from enum import Enum -from appbuilder.core.components.llms.base import CompletionBaseComponent, ModelArgsConfig +from appbuilder.core.components.llms.base import CompletionBaseComponent from appbuilder.core.message import Message from appbuilder.core.component import ComponentArguments -from appbuilder.core._exception import AppBuilderServerException -from appbuilder.utils.logger_util import logger - - -class QueryTypeChoices(Enum): - questions = '问题' - phrases = '短语' - questions_and_phrases = '全部' - - def to_chinese(self): - """ - 将QueryTypeChoices枚举类中的值转换为中文描述。 - - Args: - 无参数 - - Returns: - 返回一个字典,键是QueryTypeChoices枚举类的成员,值为对应的中文描述字符串。 - - """ - descriptions = { - QueryTypeChoices.questions: '问题', - QueryTypeChoices.phrases: '短语', - QueryTypeChoices.questions_and_phrases: '全部' - } - return descriptions[self] - - -class OutputFormatChoices(Enum): - json_format = 'json' - str_format = 'str' - - def to_chinese(self): - """ - 将OutputFormatChoices枚举类中的值转换为中文描述。 - - Args: - 无参数 - - Returns: - 返回一个字典,键是OutputFormatChoices枚举类的成员,值为对应的中文描述字符串。 - - """ - descriptions = { - OutputFormatChoices.json_format: 'json', - OutputFormatChoices.str_format: 'str' - } - return descriptions[self] class OralQueryGenerationArgs(ComponentArguments): """口语化Query生成配置 """ - text: str = Field(..., - valiable_name='text', - description='输入文本,用于生成Query') - query_type: QueryTypeChoices = Field(..., - variable_name='query_type', - description='待生成的query类型,可选值为问题、短语和全部(问题+短语)。') - output_format: QueryTypeChoices = Field(..., - variable_name='output_format', - description='输出格式,可选值为json、str。') + """ + message: Message = Field(..., + valiable_name='query', + description='输入文本,用于生成Query') + """ + query: str = Field(..., + valiable_name='query', + description='输入文本,用于生成Query') class OralQueryGeneration(CompletionBaseComponent): """ 口语化Query生成,可用于问答场景下对文档增强索引。 - *注:该组件推荐使用ERNIE Speed-AppBuilder模型。* Examples: @@ -109,40 +55,13 @@ class OralQueryGeneration(CompletionBaseComponent): '防抖后手机屏幕将出现超级防抖Pro开关,点击即可开启或关闭。 除此之外,前置视频同样加持防抖算法,边走边拍也能稳定聚焦脸部' ',实时视频分享您的生活。') oral_query_generation = appbuilder.OralQueryGeneration(model='ERNIE Speed-AppBuilder') - answer = oral_query_generation(appbuilder.Message(text), query_type='全部', output_format='str') + answer = oral_query_generation(appbuilder.Message(text)) print(answer.content) """ - name = 'query_generation' + name = 'oral_query_generation' version = 'v1' meta = OralQueryGenerationArgs - manifests = [ - { - "name": "query_generation", - "description": "输入文本、待生成的query类型和输出格式,生成query,并按照要求的格式进行输出。", - "parameters": { - "type": "object", - "properties": { - "text": { - "text": "string", - "description": "输入文本,组件会根据该输入文本生成query。" - }, - "query_type": { - "text": "string", - "description": "待生成的query类型,可选问题、短语以及全部(问题 + 短语)。" - }, - "output_format": { - "text": "string", - "description": "输出格式,可选json或str,str格式与老版本输出格式相同。" - } - }, - "required": [ - "text" - ] - } - } - ] - def __init__( self, model=None, @@ -153,7 +72,7 @@ def __init__( """初始化口语化Query生成模型。 Args: - model (str|None): 模型名称,用于指定要使用的千帆模型。推荐使用ERNIE Speed-AppBuilder模型。 + model (str|None): 模型名称,用于指定要使用的千帆模型。 secret_key (str, 可选): 用户鉴权token, 默认从环境变量中获取: os.getenv("APPBUILDER_TOKEN", ""). gateway (str, 可选): 后端网关服务地址,默认从环境变量中获取: os.getenv("GATEWAY_URL", "") lazy_certification (bool, 可选): 延迟认证,为True时在第一次运行时认证. Defaults to False. @@ -164,120 +83,18 @@ def __init__( """ super().__init__( OralQueryGenerationArgs, model=model, secret_key=secret_key, gateway=gateway, lazy_certification=lazy_certification) - - def regenerate_output(self, model_output, output_format): - """ - 兼容老版本的输出格式 - """ - if not isinstance(model_output, str): - return model_output - - match_obj = re.search(r'```json\n(.+)\n```', model_output, flags=re.DOTALL) - - regenerated_output = None - if match_obj: - regenerated_output = json.loads(match_obj.group(1)) - else: - dict_json_match_obj = re.search(r'\{(.|\n)+\}', model_output) - dict_json_text = dict_json_match_obj.group(0) if dict_json_match_obj else None - regenerated_output = json.loads(dict_json_text) if dict_json_text is not None else model_output - - if output_format == 'json' or not isinstance(regenerated_output, dict): - return json.dumps(regenerated_output, ensure_ascii=False, indent=4) - - queries = [] - for key in ['问题', '短语']: - queries += regenerated_output.pop(key, []) - for value in regenerated_output.values(): - queries += value - - regenerated_output = '\n'.join([f'{index}. {query}' for index, query in enumerate(queries, 1)]) - return regenerated_output - - def completion(self, version, base_url, request, timeout: float = None, - retry: int = 0): - r"""Send a byte array of an audio file to obtain the result of speech recognition.""" - - headers = self.http_client.auth_header() - headers["Content-Type"] = "application/json" - - stream = True if request.response_mode == "streaming" else False - - url = self.http_client.service_url("/app/query_generation", self.base_url) - logger.debug( - "request url: {}, method: {}, json: {}, headers: {}".format(url, - "POST", - request.params, - headers)) - response = self.http_client.session.post(url, json=request.params, headers=headers, timeout=timeout, - stream=stream) - - logger.debug( - "request url: {}, method: {}, json: {}, headers: {}, response: {}".format(url, "POST", - request.params, - headers, - response)) - return self.gene_response(response, stream) - - def run(self, message, query_type='全部', output_format='str', stream=False, temperature=1e-10, top_p=0.0): + def run(self, message, stream=False, temperature=1e-10, top_p=0.0): """ 使用给定的输入运行模型并返回结果。 - Args: - message (Message): 输入消息,用于传入query、context和answer。这是一个必需的参数。 - query_type (str, 可选): 待生成的query类型,包括问题、短语和全部(问题+短语)。默认为全部。 - output_format (str, 可选): 输出格式,包括json和str,stream为True时,只能以json形式输出。默认为str。 + 参数: + message (obj:`Message`): 输入消息,用于模型的主要输入内容。这是一个必需的参数。 stream (bool, 可选): 指定是否以流式形式返回响应。默认为 False。 temperature (float, 可选): 模型配置的温度参数,用于调整模型的生成概率。取值范围为 0.0 到 1.0,其中较低的值使生成更确定性,较高的值使生成更多样性。默认值为 1e-10。 - top_p (float, 可选): 影响输出文本的多样性,取值越大,生成文本的多样性越强。取值范围为 0.0 到 1.0,其中较低的值使生成更确定性,较高的值使生成更多样性。默认值为 0。 - - Returns: - result (Message): 模型运行后的输出消息。 - """ - text = message.content - assert text, 'Input text should be a valid string' - inputs = { - 'text': text, - 'query_type': query_type - } - response_mode = "streaming" if stream else "blocking" - user_id = message.id - model_config_inputs = ModelArgsConfig(**{"stream": stream, "temperature": temperature, "top_p": top_p}) - model_config = self.get_model_config(model_config_inputs) + top_p(float, optional): 影响输出文本的多样性,取值越大,生成文本的多样性越强。取值范围为 0.0 到 1.0,其中较低的值使生成更确定性,较高的值使生成更多样性。默认值为 0。 - request = self.gene_request('', inputs, response_mode, user_id, model_config) - response = self.completion(self.version, self.base_url, request) - - if response.error_no != 0: - raise AppBuilderServerException(service_err_code=response.error_no, service_err_message=response.error_msg) - - result = response.to_message() - result.content = self.regenerate_output(result.content, output_format) - - return result - - def tool_eval(self, name: str, stream: bool = False, **kwargs): - """ - tool_eval for function call + 返回: + obj:`Message`: 模型运行后的输出消息。 """ - text = kwargs.get('text', None) - query_type = kwargs.get('query_type', '全部') - output_format = kwargs.get('output_format', 'str') - if not text: - raise ValueError('param `text` is required') - msg = Message(text) - model_configs = kwargs.get('model_configs', {}) - temperature = model_configs.get('temperature', 1e-10) - top_p = model_configs.get('top_p', 0.0) - message = self.run(message=msg, - query_type=query_type, - output_format=output_format, - stream=stream, - temperature=temperature, - top_p=top_p) - if stream: - for data in message.content: - yield data - else: - return message.content + return super().run(query=message.content, stream=stream, temperature=temperature, top_p=top_p) diff --git a/appbuilder/tests/test_oral_query_generation.py b/appbuilder/tests/test_oral_query_generation.py deleted file mode 100644 index 0af7787a..00000000 --- a/appbuilder/tests/test_oral_query_generation.py +++ /dev/null @@ -1,125 +0,0 @@ -# Copyright (c) 2023 Baidu, Inc. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import os -import unittest -import appbuilder - - -TEST_TEXT = ('文档标题:在OPPO Reno5上使用视频超级防抖\n' - '文档摘要:OPPO Reno5上的视频超级防抖,视频超级防抖3.0,多代视频防抖算法积累,这一代依旧超级防抖超级稳。 开启视频超级' - '防抖 开启路径:打开「相机 > 视频 > 点击屏幕上方的“超级防抖”标识」 后置视频同时支持超级防抖和超级防抖Pro功能,开启超级' - '防抖后手机屏幕将出现超级防抖Pro开关,点击即可开启或关闭。 除此之外,前置视频同样加持防抖算法,边走边拍也能稳定聚焦脸部' - ',实时视频分享您的生活。') - - -@unittest.skipUnless(os.getenv("TEST_CASE", "UNKNOWN") == "CPU_SERIAL", "") -class TestOralQueryGenerationComponent(unittest.TestCase): - def setUp(self): - """ - 设置环境变量。 - - Args: - 无参数,默认值为空。 - - Returns: - 无返回值,方法中执行了环境变量的赋值操作。 - """ - - self.model_name = 'ERNIE Speed-AppBuilder' - secret_key = os.getenv('SECRET_KEY', None) - self.query_generation = appbuilder.OralQueryGeneration(model=self.model_name, secret_key=secret_key) - - def test_run_with_default_params(self): - """测试 run 方法使用默认参数 - """ - text = TEST_TEXT - msg = appbuilder.Message(text) - answer = self.query_generation(msg) - # print(answer) - self.assertIsNotNone(answer) - print(f'\n[result]\n{answer.content}\n') - - def test_run_with_question_output_and_json_output(self): - """测试 run 方法,输出query类型为问题,输出格式为json - """ - text = TEST_TEXT - msg = appbuilder.Message(text) - answer = self.query_generation(msg, query_type='问题', output_format='json') - # print(answer) - self.assertIsNotNone(answer) - print(f'\n[result]\n{answer.content}\n') - - def test_run_with_phrase_output_and_str_output(self): - """测试 run 方法,输出query类型为短语,输出格式为str - """ - text = TEST_TEXT - msg = appbuilder.Message(text) - answer = self.query_generation(msg, query_type='短语', output_format='str') - # print(answer) - self.assertIsNotNone(answer) - print(f'\n[result]\n{answer.content}\n') - - def test_run_with_stream_and_temperature(self): - """测试不同的 stream 和 temperature 参数值 - """ - text = TEST_TEXT - msg = appbuilder.Message(text) - answer = self.query_generation(msg, stream=False, temperature=0.5) - # print(answer) - self.assertIsNotNone(answer) - print(f'\n[result]\n{answer.content}\n') - - def test_tool_eval_with_default_params(self): - """测试 tool_eval 方法使用默认参数 - """ - text = TEST_TEXT - answer = self.query_generation.tool_eval(name='', stream=False, text=text) - # print(answer) - self.assertIsNotNone(answer) - print(f'\n[result]\n{answer}\n') - - def test_tool_eval_with_model_configs(self): - """测试 tool_eval 方法使用不同temperature和top_p参数值。 - """ - text = TEST_TEXT - model_configs = {'temperature': 0.5, 'top_p': 0.5} - answer = self.query_generation.tool_eval(name='', stream=True, text=text, model_configs=model_configs) - # print(answer) - print(f'\n[result]\n') - for ans in answer: - print(ans) - - def test_tool_eval_with_default_params(self): - """测试 tool_eval 方法使用默认参数 - """ - text = TEST_TEXT - answer = self.query_generation.tool_eval(name='', stream=False, text=text) - # print(answer) - self.assertIsNotNone(answer) - print(f'\n[result]\n{answer}\n') - - def test_tool_eval_with_model_configs(self): - """测试 tool_eval 方法使用不同temperature和top_p参数值。 - """ - text = TEST_TEXT - model_configs = {'temperature': 0.5, 'top_p': 0.5} - answer = self.query_generation.tool_eval(name='', stream=True, text=text, model_configs=model_configs) - # print(answer) - print(f'\n[result]\n') - for ans in answer: - print(ans) - - -if __name__ == '__main__': - unittest.main() \ No newline at end of file diff --git a/appbuilder/tests/test_qa_llm_oral_query_generation.py b/appbuilder/tests/test_qa_llm_oral_query_generation.py index 882ea4b5..6cb9e3a7 100644 --- a/appbuilder/tests/test_qa_llm_oral_query_generation.py +++ b/appbuilder/tests/test_qa_llm_oral_query_generation.py @@ -94,7 +94,7 @@ def test_normal_case(self, model_name, text, stream, temperature): 'through: appbuilder.get_model_list()' ), param( - "ERNIE-Bot 4.0", None, "ValueError", "text", "Input text should be a valid string" + "ERNIE-Bot 4.0", None, "ValueError", "query", "Input should be a valid string" ) ]) def test_abnormal_case(self, model_name, text, err_type, err_param, err_msg):