-
Notifications
You must be signed in to change notification settings - Fork 124
/
server.py
141 lines (120 loc) · 4.5 KB
/
server.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
import argparse
import os
import time
import pytoml
import requests
from aiohttp import web
from loguru import logger
from termcolor import colored
from .service import ErrorCode, SerialPipeline, ParallelPipeline, start_llm_server
from .primitive import Query
import asyncio
from fastapi import FastAPI, APIRouter
from fastapi.responses import StreamingResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import uvicorn
import json
from typing import List
app = FastAPI(docs_url="/")

# The chat pipeline that serves requests. It stays None unless the
# ``__main__`` section of this script assigns a ParallelPipeline or a
# SerialPipeline to it.
assistant = None


class Talk(BaseModel):
    """Request body accepted by the inference endpoints.

    Attributes are ``text`` (the user query) and ``image`` (optional, empty
    string by default); both are passed straight into ``Query``.
    """

    text: str
    image: str = ""
def format_refs(refs: List[str]):
    """Render a list of references as a Markdown bullet list.

    Duplicate entries are dropped while keeping the order in which they
    first appear, so the output is stable from run to run.

    Args:
        refs: file paths or URLs cited by an answer. ``None`` or an empty
            list yields an empty string.

    Returns:
        A ``**References:**`` header line, one ``* <ref>`` line per unique
        reference and a trailing blank line, all CRLF-terminated; or ``''``
        when there is nothing to list.
    """
    if not refs:
        return ''
    # dict.fromkeys de-duplicates while preserving first-seen order; the
    # former set() produced a hash-dependent, non-reproducible ordering.
    unique_refs = dict.fromkeys(refs)
    lines = ['**References:**']
    lines.extend('* {}'.format(file_or_url) for file_or_url in unique_refs)
    return '\r\n'.join(lines) + '\r\n\r\n'
@app.post("/huixiangdou_inference")
async def huixiangdou_inference(talk: Talk):
    """Answer one query and return the complete (non-streamed) result.

    When the module-level ``assistant`` is a ``SerialPipeline``, the response
    is a dict:

    - ``step``: a list with one ``{"state", "response", "refs"}`` entry per
      session the pipeline yields.
    - ``debug``: taken from the last session, and present only if at least
      one session was yielded.

    Otherwise, the assistant's async ``generate`` is consumed and the answer
    text is returned as a single string. The string is formed by
    concatenating every non-empty ``delta``. If a session carrying references
    arrives before any text, the string is prefixed with those references
    formatted by ``format_refs``.
    """
    query = Query(talk.text, talk.image)

    if isinstance(assistant, SerialPipeline):
        pipeline = {'step': []}
        # NOTE(review): this plain for-loop has no await, so the event loop
        # is blocked until SerialPipeline.generate is exhausted.
        for sess in assistant.generate(query=query):
            pipeline['step'].append({
                "state": str(sess.code),
                "response": sess.response,
                "refs": sess.references
            })
            pipeline['debug'] = sess.debug
        return pipeline

    sentence = ''
    async for sess in assistant.generate(query=query, enable_web_search=False):
        # References are emitted once, ahead of any generated text.
        if not sentence and sess.references:
            sentence = format_refs(sess.references)
        if sess.delta:
            sentence += sess.delta
    return sentence
@app.post("/huixiangdou_stream")
async def huixiangdou_stream(talk: Talk):
    """Stream the answer to one query as it is generated.

    For a ``SerialPipeline`` each chunk is a JSON dump of the cumulative
    ``{"step": [...], "debug": ...}`` dict after every yielded session.

    For any other assistant, each chunk is the cumulative answer text so far.
    That text is optionally prefixed with the references formatted by
    ``format_refs``.

    NOTE(review): every chunk repeats the full state rather than a delta,
    and chunks are written without a delimiter or SSE ``data:`` framing
    despite the ``text/event-stream`` media type. Clients therefore receive
    concatenated payloads. This is kept unchanged for wire compatibility
    with existing consumers.
    """
    query = Query(talk.text, talk.image)

    def event_stream():
        # Accumulated per request; each yield re-serializes the whole dict.
        pipeline = {'step': []}
        for sess in assistant.generate(query=query):
            pipeline['step'].append({
                "state": str(sess.code),
                "response": sess.response,
                "refs": sess.references
            })
            pipeline['debug'] = sess.debug
            yield json.dumps(pipeline)

    async def event_stream_async():
        sentence = ''
        async for sess in assistant.generate(query=query, enable_web_search=False):
            # References are emitted once, ahead of any generated text.
            if not sentence and sess.references:
                sentence = format_refs(sess.references)
            if sess.delta:
                sentence += sess.delta
            yield sentence

    if isinstance(assistant, SerialPipeline):
        return StreamingResponse(event_stream(), media_type="text/event-stream")
    return StreamingResponse(event_stream_async(), media_type="text/event-stream")
def parse_args(argv=None):
    """Parse the server's command-line options.

    Args:
        argv: optional list of argument strings (without the program name).
            ``None`` (the default) lets argparse read ``sys.argv[1:]``, which
            matches the original behavior.

    Returns:
        argparse.Namespace with ``work_dir``, ``config_path``, ``pipeline``
        and ``standalone`` attributes.
    """
    parser = argparse.ArgumentParser(description='SerialPipeline.')
    parser.add_argument('--work_dir',
                        type=str,
                        default='workdir',
                        help='Working directory.')
    parser.add_argument(
        '--config_path',
        default='config.ini',
        type=str,
        help='Configuration path. Default value is config.ini')
    parser.add_argument(
        '--pipeline',
        type=str,
        choices=['chat_with_repo', 'chat_in_group'],
        default='chat_with_repo',
        help='Select pipeline type for different scenarios, default value is `chat_with_repo`')
    # standalone already defaults to True, so passing --standalone is a no-op
    # kept for explicitness; --no-standalone is the switch that disables it.
    parser.add_argument('--standalone',
                        action='store_true',
                        default=True,
                        help='Auto deploy required Hybrid LLM Service.')
    parser.add_argument('--no-standalone',
                        action='store_false',
                        dest='standalone',  # same destination as --standalone
                        help='Do not auto deploy required Hybrid LLM Service.')
    return parser.parse_args(argv)
if __name__ == '__main__':
    args = parse_args()

    # standalone defaults to True; --no-standalone skips starting the
    # hybrid LLM service here.
    if args.standalone:
        start_llm_server(config_path=args.config_path)

    # Bind the module-level `assistant` read by the HTTP endpoints.
    # --pipeline is restricted by argparse `choices`, so exact equality
    # covers every accepted value (substring matching was looser than needed).
    if args.pipeline == 'chat_with_repo':
        assistant = ParallelPipeline(work_dir=args.work_dir, config_path=args.config_path)
    elif args.pipeline == 'chat_in_group':
        assistant = SerialPipeline(work_dir=args.work_dir, config_path=args.config_path)

    uvicorn.run(app, host='0.0.0.0', port=23333, log_level='info')