import io
import os
import time
from enum import Enum
from functools import lru_cache
from typing import List, Literal

import torch
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer

app = FastAPI(docs_url="/")

# Run inference on the first GPU when available, otherwise on CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


class TranslationRequest(BaseModel):
    """Request payload for the /translate endpoint."""

    user_input: str  # text to translate
    source_lang: str  # M2M100 language code of the input (e.g. "en")
    target_lang: str  # M2M100 language code of the desired output (e.g. "fr")


@lru_cache(maxsize=4)
def load_model(pretrained_model: str = "facebook/m2m100_1.2B", cache_dir: str = "models/"):
    """Load and memoize the M2M100 tokenizer and model.

    NOTE: the original code called this on every request, re-reading the
    multi-GB checkpoint from disk each time. ``lru_cache`` makes the load
    happen once per (pretrained_model, cache_dir) pair per process.

    Args:
        pretrained_model: HuggingFace model identifier to download/load.
        cache_dir: directory (relative to CWD) used as the HF cache.

    Returns:
        A ``(tokenizer, model)`` tuple, with the model moved to ``device``
        and switched to eval mode.
    """
    model_dir = os.path.join(os.getcwd(), cache_dir)
    tokenizer = M2M100Tokenizer.from_pretrained(pretrained_model, cache_dir=model_dir)
    model = M2M100ForConditionalGeneration.from_pretrained(
        pretrained_model, cache_dir=model_dir
    ).to(device)
    model.eval()
    return tokenizer, model


# Deliberately a sync (not async) handler: generation is blocking CPU/GPU
# work, and FastAPI runs sync endpoints in its threadpool, so the event
# loop is not stalled for the duration of inference.
@app.post("/translate")
def translate(request: TranslationRequest):
    """Translate ``request.user_input`` from source_lang to target_lang.

    language support

    Afrikaans (af), Amharic (am), Arabic (ar), Asturian (ast),
    Azerbaijani (az), Bashkir (ba), Belarusian (be), Bulgarian (bg),
    Bengali (bn), Breton (br), Bosnian (bs), Catalan; Valencian (ca),
    Cebuano (ceb), Czech (cs), Welsh (cy), Danish (da), German (de),
    Greek (el), English (en), Spanish (es), Estonian (et), Persian (fa),
    Fulah (ff), Finnish (fi), French (fr), Western Frisian (fy), Irish (ga),
    Gaelic; Scottish Gaelic (gd), Galician (gl), Gujarati (gu), Hausa (ha),
    Hebrew (he), Hindi (hi), Croatian (hr), Haitian; Haitian Creole (ht),
    Hungarian (hu), Armenian (hy), Indonesian (id), Igbo (ig), Iloko (ilo),
    Icelandic (is), Italian (it), Japanese (ja), Javanese (jv),
    Georgian (ka), Kazakh (kk), Central Khmer (km), Kannada (kn),
    Korean (ko), Luxembourgish; Letzeburgesch (lb), Ganda (lg),
    Lingala (ln), Lao (lo), Lithuanian (lt), Latvian (lv), Malagasy (mg),
    Macedonian (mk), Malayalam (ml), Mongolian (mn), Marathi (mr),
    Malay (ms), Burmese (my), Nepali (ne), Dutch; Flemish (nl),
    Norwegian (no), Northern Sotho (ns), Occitan (post 1500) (oc),
    Oriya (or), Panjabi; Punjabi (pa), Polish (pl), Pushto; Pashto (ps),
    Portuguese (pt), Romanian; Moldavian; Moldovan (ro), Russian (ru),
    Sindhi (sd), Sinhala; Sinhalese (si), Slovak (sk), Slovenian (sl),
    Somali (so), Albanian (sq), Serbian (sr), Swati (ss), Sundanese (su),
    Swedish (sv), Swahili (sw), Tamil (ta), Thai (th), Tagalog (tl),
    Tswana (tn), Turkish (tr), Ukrainian (uk), Urdu (ur), Uzbek (uz),
    Vietnamese (vi), Wolof (wo), Xhosa (xh), Yiddish (yi), Yoruba (yo),
    Chinese (zh), Zulu (zu)
    """
    tokenizer, model = load_model()  # cached after the first request
    tokenizer.src_lang = request.source_lang

    with torch.no_grad():  # inference only — no autograd bookkeeping
        encoded_input = tokenizer(request.user_input, return_tensors="pt").to(device)
        generated_tokens = model.generate(
            **encoded_input,
            forced_bos_token_id=tokenizer.get_lang_id(request.target_lang),
        )
        translated_text = tokenizer.batch_decode(
            generated_tokens, skip_special_tokens=True
        )[0]

    # Dict construction cannot raise; the original try/except here was dead code.
    return {"translation": translated_text}


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)