king-translate / main.py
gleisonnanet's picture
Revert "load app"
ebe5014
raw
history blame
4.21 kB
import functools
import io
import os
import re
import time
from enum import Enum
from typing import List, Literal

import torch
import uvicorn
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import M2M100Tokenizer, M2M100ForConditionalGeneration, AutoTokenizer, AutoModelForSeq2SeqLM
# The FastAPI application; its interactive docs page is mounted at "/" (FastAPI's default is "/docs").
app = FastAPI(docs_url="/")

# Device for the M2M100 translation model and its encoded inputs:
# the first CUDA GPU when one is visible to torch, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
class TranslationRequest(BaseModel):
    # JSON request body accepted by POST /translate.
    user_input: str  # the text to translate
    source_lang: str  # M2M100 language code of user_input, e.g. "en" (assigned to tokenizer.src_lang)
    target_lang: str  # M2M100 language code to translate into, e.g. "fr" (forced as the first generated token)
@functools.lru_cache(maxsize=1)
def load_model(pretrained_model: str = "facebook/m2m100_1.2B", cache_dir: str = "models/"):
    """
    Load (once) the M2M100 tokenizer and model used for translation.

    The result is memoized with a single cache slot. Without it, every
    /translate request re-read the checkpoint from disk and moved it to
    ``device``. A load that raises is not cached, so the next call retries.
    The cache key is the arguments exactly as passed, so ``load_model()`` and
    ``load_model("facebook/m2m100_1.2B")`` occupy different slots.

    The same tokenizer object is handed to every caller. ``translate()`` sets
    its ``src_lang`` per request, so callers must not interleave use of it.

    Args:
        pretrained_model: model identifier or path passed to ``from_pretrained``.
        cache_dir: download/cache directory; a relative path is joined onto the
            current working directory (resolved only on the first, uncached load).

    Returns:
        ``(tokenizer, model)``, with the model moved to ``device`` and in eval mode.
    """
    model_dir = os.path.join(os.getcwd(), cache_dir)
    tokenizer = M2M100Tokenizer.from_pretrained(pretrained_model, cache_dir=model_dir)
    model = M2M100ForConditionalGeneration.from_pretrained(pretrained_model, cache_dir=model_dir).to(device)
    model.eval()  # inference only: disable dropout
    return tokenizer, model
@app.post("/translate")
async def translate(request: TranslationRequest):
    """
    Translate ``user_input`` from ``source_lang`` to ``target_lang`` with M2M100.

    Returns ``{"translation": <str>}`` on success, or ``{"error": <str>}`` when the
    model cannot be loaded or a language code is not supported.

    language support
    Afrikaans (af), Amharic (am), Arabic (ar), Asturian (ast), Azerbaijani (az), Bashkir (ba), Belarusian (be), Bulgarian (bg), Bengali (bn), Breton (br), Bosnian (bs), Catalan; Valencian (ca), Cebuano (ceb), Czech (cs), Welsh (cy), Danish (da), German (de), Greek (el), English (en), Spanish (es), Estonian (et), Persian (fa), Fulah (ff), Finnish (fi), French (fr), Western Frisian (fy), Irish (ga), Gaelic; Scottish Gaelic (gd), Galician (gl), Gujarati (gu), Hausa (ha), Hebrew (he), Hindi (hi), Croatian (hr), Haitian; Haitian Creole (ht), Hungarian (hu), Armenian (hy), Indonesian (id), Igbo (ig), Iloko (ilo), Icelandic (is), Italian (it), Japanese (ja), Javanese (jv), Georgian (ka), Kazakh (kk), Central Khmer (km), Kannada (kn), Korean (ko), Luxembourgish; Letzeburgesch (lb), Ganda (lg), Lingala (ln), Lao (lo), Lithuanian (lt), Latvian (lv), Malagasy (mg), Macedonian (mk), Malayalam (ml), Mongolian (mn), Marathi (mr), Malay (ms), Burmese (my), Nepali (ne), Dutch; Flemish (nl), Norwegian (no), Northern Sotho (ns), Occitan (post 1500) (oc), Oriya (or), Panjabi; Punjabi (pa), Polish (pl), Pushto; Pashto (ps), Portuguese (pt), Romanian; Moldavian; Moldovan (ro), Russian (ru), Sindhi (sd), Sinhala; Sinhalese (si), Slovak (sk), Slovenian (sl), Somali (so), Albanian (sq), Serbian (sr), Swati (ss), Sundanese (su), Swedish (sv), Swahili (sw), Tamil (ta), Thai (th), Tagalog (tl), Tswana (tn), Turkish (tr), Ukrainian (uk), Urdu (ur), Uzbek (uz), Vietnamese (vi), Wolof (wo), Xhosa (xh), Yiddish (yi), Yoruba (yo), Chinese (zh), Zulu (zu)
    """
    # Model-loading failures are reported to the client rather than raised.
    try:
        tokenizer, model = load_model()
    except Exception as exc:
        return {"error": str(exc)}

    src_lang = request.source_lang
    trg_lang = request.target_lang

    # The M2M100 tokenizer looks language codes up in a dict, so an unknown
    # code raises KeyError (previously an unhandled HTTP 500). Validate both
    # up front and answer with the same error shape as above.
    for role, code in (("source", src_lang), ("target", trg_lang)):
        try:
            tokenizer.get_lang_id(code)
        except KeyError:
            return {"error": f"Unsupported {role} language code: {code!r}"}

    # NOTE(review): src_lang is stored on the tokenizer object itself. If
    # load_model() hands out a shared instance, requests must not interleave
    # between this assignment and tokenization. This coroutine has no await
    # points, so on a single event loop they cannot. The flip side is that
    # generation below blocks the event loop until it finishes.
    tokenizer.src_lang = src_lang
    with torch.no_grad():
        encoded_input = tokenizer(request.user_input, return_tensors="pt").to(device)
        generated_tokens = model.generate(
            **encoded_input,
            # Force the first decoded token to the target-language id.
            forced_bos_token_id=tokenizer.get_lang_id(trg_lang),
        )
    translated_text = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
    return {"translation": translated_text}
# Summarization endpoint (served at /chat) using the mT5 multilingual XLSum model
def WHITESPACE_HANDLER(k: str) -> str:
    """
    Normalize whitespace in ``k`` before tokenization.

    Leading/trailing whitespace is stripped and every internal run of
    whitespace (newlines and tabs included) collapses to a single space.
    The pattern is a raw string: ``'\\s+'`` in a plain literal is an invalid
    escape sequence (SyntaxWarning on Python 3.12+). A single ``\\s+`` pass
    already covers newlines, so a separate newline pass is unnecessary.
    """
    return re.sub(r"\s+", " ", k.strip())
# Summarization checkpoint behind the /chat endpoint, loaded once at import time.
# NOTE: the module-level name `tokenizer` below is this mT5 tokenizer;
# translate() uses its own local `tokenizer` from load_model().
chat_model_name = "csebuetnlp/mT5_multilingual_XLSum"
tokenizer, modelchat = (
    AutoTokenizer.from_pretrained(chat_model_name),
    AutoModelForSeq2SeqLM.from_pretrained(chat_model_name),
)
@app.get("/chat")
async def read_root(text: str):
    """
    Summarize ``text`` with the mT5 multilingual XLSum model.

    Despite the route name, this performs abstractive summarization.
    Whitespace in the input is normalized, the input is truncated to
    512 tokens, and a summary of at most 500 tokens is decoded with 4-beam
    search, with repeated bigrams disallowed.

    Returns ``{"summary": <str>}``.
    """
    # A single sequence needs no padding. The previous padding="max_length"
    # made every request encode all 512 positions, however short the text.
    encoded = tokenizer(
        [WHITESPACE_HANDLER(text)],
        return_tensors="pt",
        truncation=True,
        max_length=512,
    )
    # Pass the attention mask explicitly instead of relying on generate()
    # to reconstruct it from pad tokens.
    output_ids = modelchat.generate(
        input_ids=encoded["input_ids"],
        attention_mask=encoded["attention_mask"],
        max_length=500,
        no_repeat_ngram_size=2,
        num_beams=4,
    )[0]
    summary = tokenizer.decode(
        output_ids,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=True,
    )
    return {"summary": summary}
if __name__ == "__main__":
    # Serve the API directly with uvicorn on all network interfaces, fixed port 7860.
    uvicorn.run(app, port=7860, host="0.0.0.0")