-
Notifications
You must be signed in to change notification settings - Fork 1
/
opus_mt.py
executable file
·89 lines (72 loc) · 2.81 KB
/
opus_mt.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
"""An Opus-MT microservice that translates text from English to German"""
from fastapi import FastAPI
from fastapi.openapi.docs import (
get_redoc_html,
get_swagger_ui_html,
get_swagger_ui_oauth2_redirect_html,
)
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import syntok.segmenter as segmenter
# The default /docs and /redoc pages are disabled here and re-registered
# further down so their JS/CSS assets can be served from the local "static"
# directory mounted at /static.
app = FastAPI(docs_url=None, redoc_url=None)
app.mount("/static", StaticFiles(directory="static"), name="static")

# English->German Opus-MT checkpoint. The tokenizer and model are loaded once,
# at import time, and shared by every request.
_OPUS_MT_CHECKPOINT = "Helsinki-NLP/opus-mt-en-de"
TOKENIZER = AutoTokenizer.from_pretrained(_OPUS_MT_CHECKPOINT)
MODEL = AutoModelForSeq2SeqLM.from_pretrained(_OPUS_MT_CHECKPOINT)
# Serve the Swagger UI with locally hosted assets from /static.
@app.get("/docs", include_in_schema=False)
async def custom_swagger_ui_html():
    """Render the Swagger UI page, loading its JS and CSS bundles from /static."""
    page_title = f"{app.title} - Swagger UI"
    return get_swagger_ui_html(
        openapi_url=app.openapi_url,
        title=page_title,
        oauth2_redirect_url=app.swagger_ui_oauth2_redirect_url,
        swagger_js_url="/static/swagger-ui-bundle.js",
        swagger_css_url="/static/swagger-ui.css",
    )
@app.get(app.swagger_ui_oauth2_redirect_url, include_in_schema=False)
async def swagger_ui_redirect():
    """Serve the OAuth2 redirect page that the Swagger UI page points to."""
    redirect_page = get_swagger_ui_oauth2_redirect_html()
    return redirect_page
@app.get("/redoc", include_in_schema=False)
async def redoc_html():
    """Render the ReDoc page, loading its standalone JS bundle from /static."""
    page_title = f"{app.title} - ReDoc"
    return get_redoc_html(
        openapi_url=app.openapi_url,
        title=page_title,
        redoc_js_url="/static/redoc.standalone.js",
    )
class Response(BaseModel):
    """Body returned by the /translate endpoint."""

    # The German translation, or an explanatory message when the requested
    # language pair is not supported by this model.
    translation: str
def translate_text(text: str) -> str:
    """Translate the text from English to German.

    The text is split into paragraphs and sentences with syntok. Each
    paragraph's sentences are translated together as one padded batch.

    Arguments:
    text -- The text to be translated

    Returns:
    The translated sentences joined by single spaces, with a newline
    (surrounded by spaces) between paragraphs. Returns an empty string when
    the input contains no sentences, e.g. empty or whitespace-only text.
    """
    translation = []
    for paragraph in segmenter.analyze(text):
        # Rebuild each sentence's surface text, including original spacing.
        src_text = [
            "".join(token.spacing + token.value for token in sentence)
            for sentence in paragraph
        ]
        if not src_text:
            # Nothing to translate; never hand the tokenizer an empty batch.
            continue
        translated = MODEL.generate(
            **TOKENIZER(src_text, return_tensors="pt", padding=True)
        )
        translation.extend(
            TOKENIZER.decode(t, skip_special_tokens=True) for t in translated
        )
        # Paragraph breaks are inserted into the output instead of being sent
        # to the model: translating "\n" leads to the model hallucinating a
        # sentence.
        translation.append("\n")
    # Drop the separator after the last paragraph. The guard matters because
    # empty or whitespace-only input yields no translated paragraphs, and an
    # unconditional delete would raise IndexError.
    if translation:
        del translation[-1]
    return " ".join(translation)
@app.get("/translate")
def translate(text: str, source: str = "en", target: str = "de") -> Response:
    """Translate a single document from English to German.

    NOTE: declared with plain ``def`` (not ``async def``). FastAPI runs such
    endpoints in a worker thread, so the blocking model call does not stall
    the event loop.

    Arguments:
    text -- The text to be translated
    source -- Source language tag; must start with "en" (case-insensitive)
    target -- Target language tag; must start with "de" (case-insensitive)

    Returns:
    A Response whose ``translation`` field holds the translated text as a
    single string. For an unsupported language pair, it holds an explanatory
    message instead.
    """
    # Language tags are case-insensitive (BCP 47), so "EN" matches like "en".
    # Prefix matching also accepts regional variants such as "en-US" / "de-AT".
    supported = source.lower().startswith("en") and target.lower().startswith("de")
    if not supported:
        # NOTE(review): unsupported pairs are reported in the body with a 200
        # status rather than an HTTP error; kept for client compatibility.
        return Response(translation="This model translates from English to German")
    return Response(translation=translate_text(text))