Skip to content

Commit

Permalink
Merge pull request #26 from cjopengler/dbg_gbi_refactor_client
Browse files Browse the repository at this point in the history
升级 gbi 使用 Component 的 http_client
  • Loading branch information
seiriosPlus authored Dec 27, 2023
2 parents 1a8910d + 98ce96a commit cca1f30
Show file tree
Hide file tree
Showing 5 changed files with 52 additions and 36 deletions.
23 changes: 11 additions & 12 deletions appbuilder/core/components/gbi/nl2sql/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@

r"""GBI nl2sql component.
"""
import uuid
import json
from typing import Dict, List, Optional
from pydantic import BaseModel, Field, ValidationError

Expand Down Expand Up @@ -90,11 +88,12 @@ def run(self,
1. query: 用户问题
2. session: gbi session 的历史 列表, 参考 SessionRecord
3. column_constraint: 列选约束 参考 ColumnItem 具体定义
timeout: 超时时间
retry: 重试次数
Returns:
NL2SqlResult 的 message
"""


try:
inputs = self.meta(**message.content)
except ValidationError as e:
Expand Down Expand Up @@ -134,11 +133,11 @@ def _run_nl2sql(self, query: str, session: List[SessionRecord], table_schemas: L
"""

headers = self.auth_header()
headers = self.http_client.auth_header()
headers["Content-Type"] = "application/json"

if retry != self.retry.total:
self.retry.total = retry
if retry != self.http_client.retry.total:
self.http_client.retry.total = retry

payload = {"query": query,
"table_schemas": table_schemas,
Expand All @@ -148,13 +147,13 @@ def _run_nl2sql(self, query: str, session: List[SessionRecord], table_schemas: L
"knowledge": knowledge,
"prompt_template": prompt_template}

server_url = self.service_url(prefix="", sub_path=self.server_sub_path)
response = self.s.post(url=server_url, headers=headers,
json=payload, timeout=timeout)
super().check_response_header(response)
server_url = self.http_client.service_url(prefix="", sub_path=self.server_sub_path)
response = self.http_client.session.post(url=server_url, headers=headers,
json=payload, timeout=timeout)
self.http_client.check_response_header(response)
data = response.json()
super().check_response_json(data)
self.http_client.check_response_json(data)

request_id = self.response_request_id(response)
request_id = self.http_client.response_request_id(response)
response.request_id = request_id
return response
29 changes: 13 additions & 16 deletions appbuilder/core/components/gbi/select_table/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@

r"""GBI nl2sql component.
"""
import uuid
import json
from typing import Dict, List, Optional
from pydantic import BaseModel, Field, ValidationError

Expand Down Expand Up @@ -67,8 +65,6 @@ def __init__(self, model_name: str, table_descriptions: Dict[str, str],
问题:{query}
回答:
```
secret_key:
gateway:
"""
super().__init__(meta=SelectTableArgs)
if model_name not in SUPPORTED_MODEL_NAME:
Expand All @@ -79,12 +75,14 @@ def __init__(self, model_name: str, table_descriptions: Dict[str, str],
self.prompt_template = prompt_template

def run(self,
message: Message, timeout: int = 60,retry: int = 0) -> Message[List[str]]:
message: Message, timeout: int = 60, retry: int = 0) -> Message[List[str]]:
"""
Args:
message: message.content 字典包含 key:
1. query - 用户的问题输入
2. session - 对话历史, 可选
timeout: 超时时间
retry: 重试次数
Returns: 识别的表名的列表 ["table_name"]
"""
Expand Down Expand Up @@ -122,26 +120,25 @@ def _run_select_table(self, query: str, session: List[SessionRecord],
obj:`ShortSpeechRecognitionResponse`: 接口返回的输出消息。
"""

headers = self.auth_header()
headers = self.http_client.auth_header()
headers["Content_Type"] = "application/json"

if retry != self.retry.total:
self.retry.total = retry
if retry != self.http_client.retry.total:
self.http_client.retry.total = retry

payload = {"query": query,
"table_descriptions": table_descriptions,
"session": [session_record.to_json() for session_record in session],
"session": [session_record.dict() for session_record in session],
"model_name": model_name,
"prompt_template": prompt_template}

server_url = self.service_url(sub_path=self.server_sub_path)
response = self.s.post(url=server_url, headers=headers,
json=payload, timeout=timeout)
super().check_response_header(response)
server_url = self.http_client.service_url(sub_path=self.server_sub_path)
response = self.http_client.session.post(url=server_url, headers=headers,
json=payload, timeout=timeout)
self.http_client.check_response_header(response)
data = response.json()
super().check_response_json(data)
self.http_client.check_response_json(data)

request_id = self.response_request_id(response)
request_id = self.http_client.response_request_id(response)
response.request_id = request_id
return response

8 changes: 4 additions & 4 deletions appbuilder/tests/test_gbi_nl2sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,10 +145,10 @@ def test_run_with_session(self):
session = list()
session_record = SessionRecord(query="列出商品类别是水果的的利润率",
answer=NL2SqlResult(
llm_result="根据问题分析得到 sql 如下: \n "
"```sql\nSELECT * FROM `超市营收明细` "
"WHERE `商品类别` = '水果'\n```",
sql="SELECT * FROM `超市营收明细` WHERE `商品类别` = '水果'"))
llm_result="根据问题分析得到 sql 如下: \n "
"```sql\nSELECT * FROM `超市营收明细` "
"WHERE `商品类别` = '水果'\n```",
sql="SELECT * FROM `超市营收明细` WHERE `商品类别` = '水果'"))
session.append(session_record)

query = "列出所有的商品类别"
Expand Down
26 changes: 23 additions & 3 deletions appbuilder/tests/test_gbi_select_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import appbuilder
from appbuilder.core.message import Message
from appbuilder.core.components.gbi.basic import SessionRecord

from appbuilder.core.components.gbi.basic import NL2SqlResult

SUPER_MARKET_SCHEMA = """
```
Expand Down Expand Up @@ -68,6 +68,7 @@
回答:
"""


class TestGBISelectTable(unittest.TestCase):

def setUp(self):
Expand All @@ -79,7 +80,7 @@ def setUp(self):
self.select_table_node = \
appbuilder.SelectTable(model_name=model_name,
table_descriptions={"supper_market_info": "超市营收明细表,包含超市各种信息等",
"product_sales_info": "产品销售表"})
"product_sales_info": "产品销售表"})

def test_run_with_default_param(self):
"""测试 run 方法使用有效参数"""
Expand All @@ -95,7 +96,6 @@ def test_run_with_prompt_template(self):
"""测试 run 方法中 prompt template 模版"""
query = "列出超市中的所有数据"
msg = Message({"query": query})
result_message = self.select_table_node(message=msg)
self.select_table_node.prompt_template = PROMPT_TEMPLATE
result_message = self.select_table_node(msg)

Expand All @@ -104,6 +104,26 @@ def test_run_with_prompt_template(self):
self.assertEqual(result_message.content[0], "supper_market_info")
self.select_table_node.prompt_template = ""

def test_run_with_session(self):
"""测试 run 方法中 prompt template 模版"""

session = list()
session_record = SessionRecord(query="列出商品类别是水果的的利润率",
answer=NL2SqlResult(
llm_result="根据问题分析得到 sql 如下: \n "
"```sql\nSELECT * FROM `超市营收明细` "
"WHERE `商品类别` = '水果'\n```",
sql="SELECT * FROM `超市营收明细` WHERE `商品类别` = '水果'"))
session.append(session_record)

query = "列出超市中的所有数据"
msg = Message({"query": query, "session": session})
result_message = self.select_table_node(msg)

self.assertIsNotNone(result_message)
self.assertEqual(len(result_message.content), 1)
self.assertEqual(result_message.content[0], "supper_market_info")


if __name__ == '__main__':
unittest.main()
2 changes: 1 addition & 1 deletion cookbooks/gbi.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
"# GBI\n",
"\n",
"## 目标\n",
"通过 GBI sdk 接口完成选表和问表的能力。\n",
"通过 GBI sdk 接口完成选表和问表的能力。 \n",
"\n",
"## 准备工作\n",
"### 平台注册\n",
Expand Down

0 comments on commit cca1f30

Please sign in to comment.