import os

import torch
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import M2M100Tokenizer, M2M100ForConditionalGeneration

app = FastAPI(docs_url="/")

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

class TranslationRequest(BaseModel):
    user_input: str
    source_lang: str
    target_lang: str

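# Example payload matching this schema (language codes follow the list in the
# translate() docstring below):
# {"user_input": "Hello world", "source_lang": "en", "target_lang": "fr"}
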
def load_model(pretrained_model: str = "facebook/m2m100_1.2B", cache_dir: str = "models/"):
    # Download the tokenizer and model (or reuse the local cache under models/),
    # move the model to the selected device, and switch to inference mode.
    model_dir = os.path.join(os.getcwd(), cache_dir)
    tokenizer = M2M100Tokenizer.from_pretrained(pretrained_model, cache_dir=model_dir)
    model = M2M100ForConditionalGeneration.from_pretrained(pretrained_model, cache_dir=model_dir).to(device)
    model.eval()
    return tokenizer, model

@app.post("/translate")  # route path assumed from the function name; the original listing omits the decorator
async def translate(request: TranslationRequest):
    """
    Translate `user_input` from `source_lang` into `target_lang`.

    Supported languages:
    Afrikaans (af), Amharic (am), Arabic (ar), Asturian (ast), Azerbaijani (az), Bashkir (ba), Belarusian (be), Bulgarian (bg), Bengali (bn), Breton (br), Bosnian (bs), Catalan; Valencian (ca), Cebuano (ceb), Czech (cs), Welsh (cy), Danish (da), German (de), Greek (el), English (en), Spanish (es), Estonian (et), Persian (fa), Fulah (ff), Finnish (fi), French (fr), Western Frisian (fy), Irish (ga), Gaelic; Scottish Gaelic (gd), Galician (gl), Gujarati (gu), Hausa (ha), Hebrew (he), Hindi (hi), Croatian (hr), Haitian; Haitian Creole (ht), Hungarian (hu), Armenian (hy), Indonesian (id), Igbo (ig), Iloko (ilo), Icelandic (is), Italian (it), Japanese (ja), Javanese (jv), Georgian (ka), Kazakh (kk), Central Khmer (km), Kannada (kn), Korean (ko), Luxembourgish; Letzeburgesch (lb), Ganda (lg), Lingala (ln), Lao (lo), Lithuanian (lt), Latvian (lv), Malagasy (mg), Macedonian (mk), Malayalam (ml), Mongolian (mn), Marathi (mr), Malay (ms), Burmese (my), Nepali (ne), Dutch; Flemish (nl), Norwegian (no), Northern Sotho (ns), Occitan (post 1500) (oc), Oriya (or), Panjabi; Punjabi (pa), Polish (pl), Pushto; Pashto (ps), Portuguese (pt), Romanian; Moldavian; Moldovan (ro), Russian (ru), Sindhi (sd), Sinhala; Sinhalese (si), Slovak (sk), Slovenian (sl), Somali (so), Albanian (sq), Serbian (sr), Swati (ss), Sundanese (su), Swedish (sv), Swahili (sw), Tamil (ta), Thai (th), Tagalog (tl), Tswana (tn), Turkish (tr), Ukrainian (uk), Urdu (ur), Uzbek (uz), Vietnamese (vi), Wolof (wo), Xhosa (xh), Yiddish (yi), Yoruba (yo), Chinese (zh), Zulu (zu)
    """
    # The model is (re)loaded on every request; for lower latency, load it once
    # at startup and reuse the returned tokenizer/model across requests.
    tokenizer, model = load_model()
    src_lang = request.source_lang
    trg_lang = request.target_lang
    tokenizer.src_lang = src_lang
    try:
        with torch.no_grad():
            encoded_input = tokenizer(request.user_input, return_tensors="pt").to(device)
            generated_tokens = model.generate(
                **encoded_input, forced_bos_token_id=tokenizer.get_lang_id(trg_lang)
            )
            translated_text = tokenizer.batch_decode(
                generated_tokens, skip_special_tokens=True
            )[0]
    except Exception as exc:
        return {"error": str(exc)}
    return {"translation": translated_text}

if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)
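
Once the Space is running, the endpoint can be exercised with any HTTP client. Below is a minimal client sketch using the requests library; it assumes the service is reachable on localhost:7860 and that the route is mounted at /translate as in the decorator above, so adjust host, port, and path to match your deployment.

import requests

# Hypothetical example payload; language codes come from the list in the
# translate() docstring (e.g. "en" for English, "de" for German).
payload = {
    "user_input": "The weather is nice today.",
    "source_lang": "en",
    "target_lang": "de",
}

response = requests.post("http://localhost:7860/translate", json=payload)
response.raise_for_status()
print(response.json())  # {"translation": "..."} on success, {"error": "..."} otherwise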