-
Notifications
You must be signed in to change notification settings - Fork 0
/
app.py
83 lines (76 loc) · 2.75 KB
/
app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
from fastapi import FastAPI
from pydantic import BaseModel
from typing import Any, Dict, Union
import base64
from urllib.parse import urljoin
import requests
app = FastAPI()
# Parallel tuples of supported HuggingFace pipeline names and the base URLs of
# their deployed model servers (hp[i] is served at mdu[i]).
# NOTE(review): hp and mdu are never referenced in this file — presumably kept
# as a menu for switching the active deployment by editing the two variables
# below; confirm before removing.
hp = ("text-generation", "zero-shot-classification", "object-detection", "token-class")
mdu = ("https://text-generation-intern-mohsina.demo1.truefoundry.com/", "https://zero-shot-classification-intern-mohsina.demo1.truefoundry.com/", "https://object-detection-intern-mohsina.demo1.truefoundry.com/", "https://token-class-intern-mohsina.demo1.truefoundry.com/")
# Active pipeline and the base URL of its deployed inference server; /predict
# forwards every request to this deployment.
hf_pipeline = "token-class"
model_deployed_url = "https://token-class-intern-mohsina.demo1.truefoundry.com/"
class Input(BaseModel):
    """Request body for /predict.

    ``inputs`` is pipeline-specific (see convert_to_v2_format): a plain string,
    a dict, or a file path depending on the active pipeline, hence ``Any``.
    """
    inputs: Any
def convert_to_v2_format(inputs: Any, pipeline: str) -> Dict:
    """Convert a raw pipeline payload into an Open Inference (V2) request body.

    Args:
        inputs: Pipeline-specific payload:
            - "zero-shot-classification": dict with "sequence" (str) and
              "candidate_labels" (list of str).
            - "object-detection": path to a local image file to base64-encode.
            - "text-generation" / "token-class": a plain string.
        pipeline: One of the supported HuggingFace pipeline names.

    Returns:
        Dict shaped as a V2 ``/infer`` request body.

    Raises:
        ValueError: If *pipeline* is not supported.
        KeyError: If a required key is missing from a dict *inputs*.
        OSError: If the object-detection image path cannot be read.
    """
    if pipeline == "zero-shot-classification":
        return {
            "inputs": [
                {
                    "name": "array_inputs",
                    "shape": [1, 1],
                    "datatype": "BYTES",
                    "data": [inputs['sequence']],
                },
                {
                    "name": "candidate_labels",
                    "shape": [1, len(inputs['candidate_labels'])],
                    "datatype": "BYTES",
                    "data": inputs['candidate_labels'],
                },
            ]
        }
    if pipeline == "object-detection":
        # SECURITY(review): *inputs* is used as a local filesystem path taken
        # from the request body — this lets a caller read arbitrary files on
        # the server. Confirm only trusted callers reach this, or accept the
        # image bytes/base64 directly instead of a path.
        with open(inputs, "rb") as image_file:
            encoded_image = base64.b64encode(image_file.read()).decode('utf-8')
        return {
            "inputs": [
                {
                    "name": "inputs",
                    "shape": [1, 1],
                    "datatype": "BYTES",
                    "data": [encoded_image],
                    "parameters": {"content_type": "base64"},
                }
            ]
        }
    # "text-generation" and "token-class" produce identical single-string
    # tensors; only the tensor name differs, so dispatch on a small map
    # instead of duplicating the branch.
    tensor_name_by_pipeline = {"text-generation": "array_inputs", "token-class": "args"}
    if pipeline in tensor_name_by_pipeline:
        return {
            "inputs": [
                {
                    "name": tensor_name_by_pipeline[pipeline],
                    "shape": [1, 1],
                    "datatype": "BYTES",
                    "data": [inputs],
                }
            ]
        }
    raise ValueError(f"Unsupported pipeline: {pipeline}")
@app.post("/predict")
def predict(input: Input):
    """Proxy a prediction request to the deployed model server.

    Converts ``input.inputs`` into the V2 inference format for the active
    pipeline (module-level ``hf_pipeline``), POSTs it to the deployment's
    ``v2/models/<pipeline>/infer`` endpoint, and returns the server's JSON
    response unchanged.
    """
    v2_format_data = convert_to_v2_format(input.inputs, hf_pipeline)
    response = requests.post(
        urljoin(model_deployed_url, f'v2/models/{hf_pipeline}/infer'),
        json=v2_format_data,
        # Without a timeout, requests waits forever and an unresponsive model
        # server would hang this worker indefinitely.
        timeout=30,
    )
    return response.json()