File size: 3,202 Bytes
fb2eca4
 
 
 
 
 
 
 
 
03b7efe
 
 
fb2eca4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2005977
 
 
 
fb2eca4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3a04ade
03b7efe
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import io
import time
from enum import Enum
from typing import List, Literal

import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import M2M100Tokenizer, M2M100ForConditionalGeneration

# ASGI application object; docs_url="/" serves the interactive API docs at the site root.
app = FastAPI(docs_url="/")
# Run inference on the first CUDA GPU when one is available, otherwise on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


class TranslationRequest(BaseModel):
    """Request body accepted by the ``/translate`` endpoint."""

    # Text to be translated.
    user_input: str
    # Language code of ``user_input``; assigned to the tokenizer's ``src_lang``.
    source_lang: str
    # Language code to translate into; resolved via the tokenizer's ``get_lang_id``.
    target_lang: str


# Already-loaded (tokenizer, model) pairs keyed by (pretrained_model, cache_dir),
# so the expensive checkpoint load happens once per process, not once per call.
_LOADED_MODELS = {}


def load_model(pretrained_model: str = "facebook/m2m100_1.2B", cache_dir: str = "models/"):
    """Return the M2M100 tokenizer and model for *pretrained_model*.

    The first call for a given ``(pretrained_model, cache_dir)`` pair loads the
    tokenizer and the model via ``from_pretrained``, moves the model to the
    module-level ``device`` and switches it to evaluation mode. Later calls
    with the same arguments return the very same objects instead of reloading
    the weights.

    Args:
        pretrained_model: Model identifier passed to ``from_pretrained``.
        cache_dir: Directory passed to ``from_pretrained`` as its cache location.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    cache_key = (pretrained_model, cache_dir)
    cached = _LOADED_MODELS.get(cache_key)
    if cached is not None:
        return cached

    tokenizer = M2M100Tokenizer.from_pretrained(pretrained_model, cache_dir=cache_dir)
    model = M2M100ForConditionalGeneration.from_pretrained(pretrained_model, cache_dir=cache_dir).to(device)
    # Inference only: disable dropout and other training-time behaviour.
    model.eval()

    _LOADED_MODELS[cache_key] = (tokenizer, model)
    return tokenizer, model


@app.post("/translate")
async def translate(request: TranslationRequest):
    """
    Translate ``user_input`` from ``source_lang`` into ``target_lang``.

    The response is a JSON object with the keys ``translation`` (the translated
    text) and ``computation_time`` (seconds, rounded to three decimals).
    A language code the tokenizer does not know is rejected with HTTP 422.

    language support
    Afrikaans (af), Amharic (am), Arabic (ar), Asturian (ast), Azerbaijani (az), Bashkir (ba), Belarusian (be), Bulgarian (bg), Bengali (bn), Breton (br), Bosnian (bs), Catalan; Valencian (ca), Cebuano (ceb), Czech (cs), Welsh (cy), Danish (da), German (de), Greek (el), English (en), Spanish (es), Estonian (et), Persian (fa), Fulah (ff), Finnish (fi), French (fr), Western Frisian (fy), Irish (ga), Gaelic; Scottish Gaelic (gd), Galician (gl), Gujarati (gu), Hausa (ha), Hebrew (he), Hindi (hi), Croatian (hr), Haitian; Haitian Creole (ht), Hungarian (hu), Armenian (hy), Indonesian (id), Igbo (ig), Iloko (ilo), Icelandic (is), Italian (it), Japanese (ja), Javanese (jv), Georgian (ka), Kazakh (kk), Central Khmer (km), Kannada (kn), Korean (ko), Luxembourgish; Letzeburgesch (lb), Ganda (lg), Lingala (ln), Lao (lo), Lithuanian (lt), Latvian (lv), Malagasy (mg), Macedonian (mk), Malayalam (ml), Mongolian (mn), Marathi (mr), Malay (ms), Burmese (my), Nepali (ne), Dutch; Flemish (nl), Norwegian (no), Northern Sotho (ns), Occitan (post 1500) (oc), Oriya (or), Panjabi; Punjabi (pa), Polish (pl), Pushto; Pashto (ps), Portuguese (pt), Romanian; Moldavian; Moldovan (ro), Russian (ru), Sindhi (sd), Sinhala; Sinhalese (si), Slovak (sk), Slovenian (sl), Somali (so), Albanian (sq), Serbian (sr), Swati (ss), Sundanese (su), Swedish (sv), Swahili (sw), Tamil (ta), Thai (th), Tagalog (tl), Tswana (tn), Turkish (tr), Ukrainian (uk), Urdu (ur), Uzbek (uz), Vietnamese (vi), Wolof (wo), Xhosa (xh), Yiddish (yi), Yoruba (yo), Chinese (zh), Zulu (zu)
    """
    # perf_counter is monotonic, so the measured duration cannot be skewed by
    # system clock adjustments the way a time.time() difference can.
    time_start = time.perf_counter()
    tokenizer, model = load_model()

    # Resolve both language codes before doing any work: the tokenizer raises
    # KeyError for a code it does not know, which would otherwise reach the
    # client as an opaque HTTP 500 instead of a validation error.
    lang_ids = {}
    for field_name in ("source_lang", "target_lang"):
        lang_code = getattr(request, field_name)
        try:
            lang_ids[field_name] = tokenizer.get_lang_id(lang_code)
        except KeyError:
            raise HTTPException(
                status_code=422,
                detail=f"Unsupported {field_name}: {lang_code!r}",
            ) from None

    # The tokenizer prefixes the input with the source-language token.
    tokenizer.src_lang = request.source_lang
    with torch.no_grad():
        encoded_input = tokenizer(request.user_input, return_tensors="pt").to(device)
        # Forcing the first generated token to the target-language id selects
        # the output language.
        generated_tokens = model.generate(
            **encoded_input, forced_bos_token_id=lang_ids["target_lang"]
        )
        # A single input string yields a batch of one; take its only entry.
        translated_text = tokenizer.batch_decode(
            generated_tokens, skip_special_tokens=True
        )[0]

    elapsed = time.perf_counter() - time_start
    return {"translation": translated_text, "computation_time": round(elapsed, 3)}


# Script entry point: start the API server when this file is executed directly.
if __name__ == "__main__":
    # Imported here so uvicorn is only required when running as a script.
    import uvicorn

    # host="0.0.0.0" listens on every network interface; port 80 is the default HTTP port.
    uvicorn.run(app, host="0.0.0.0", port=80)