-
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathapp.py
42 lines (33 loc) · 1.37 KB
/
app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import flask
from flask import Flask, request, render_template
import json
import main
app = Flask(__name__)
@app.route('/')
def index():
return render_template('index.html')
@app.route('/get_end_predictions', methods=['post'])
def get_prediction_eos():
try:
input_text = ' '.join(request.json['input_text'].split())
input_text += ' <mask>'
top_k = request.json['top_k']
res = main.get_all_predictions(input_text, top_clean=int(top_k))
return app.response_class(response=json.dumps(res), status=200, mimetype='application/json')
except Exception as error:
err = str(error)
print(err)
return app.response_class(response=json.dumps(err), status=500, mimetype='application/json')
@app.route('/get_mask_predictions', methods=['post'])
def get_prediction_mask():
try:
input_text = ' '.join(request.json['input_text'].split())
top_k = request.json['top_k']
res = main.get_all_predictions(input_text, top_clean=int(top_k))
return app.response_class(response=json.dumps(res), status=200, mimetype='application/json')
except Exception as error:
err = str(error)
print(err)
return app.response_class(response=json.dumps(err), status=500, mimetype='application/json')
if __name__ == '__main__':
app.run(host='0.0.0.0', debug=True, port=8000, use_reloader=False)