import os

import torch
import uvicorn
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer

app = FastAPI(docs_url="/")

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


class TranslationRequest(BaseModel):
    user_input: str
    source_lang: str
    target_lang: str
    model: str = "facebook/m2m100_418M"  # facebook/m2m100_418M | facebook/m2m100_1.2B


def load_model(model_name: str = "facebook/m2m100_418M", cache_dir: str = "models/"):
    """Load the tokenizer and model, caching the weights under cache_dir."""
    model_dir = os.path.join(os.getcwd(), cache_dir)
    tokenizer = M2M100Tokenizer.from_pretrained(model_name, cache_dir=model_dir)
    model = M2M100ForConditionalGeneration.from_pretrained(
        model_name, cache_dir=model_dir
    ).to(device)
    model.eval()
    return tokenizer, model


# Loading the model is slow, so warm it up at application start: this downloads
# the weights into the local cache so the first request does not time out.
load_model()


@app.post("/translate")
async def translate(request: TranslationRequest):
    """
    models: facebook/m2m100_418M | facebook/m2m100_1.2B

    language support:
    Afrikaans (af), Amharic (am), Arabic (ar), Asturian (ast), Azerbaijani (az),
    Bashkir (ba), Belarusian (be), Bulgarian (bg), Bengali (bn), Breton (br),
    Bosnian (bs), Catalan; Valencian (ca), Cebuano (ceb), Czech (cs), Welsh (cy),
    Danish (da), German (de), Greek (el), English (en), Spanish (es), Estonian (et),
    Persian (fa), Fulah (ff), Finnish (fi), French (fr), Western Frisian (fy),
    Irish (ga), Gaelic; Scottish Gaelic (gd), Galician (gl), Gujarati (gu),
    Hausa (ha), Hebrew (he), Hindi (hi), Croatian (hr), Haitian; Haitian Creole (ht),
    Hungarian (hu), Armenian (hy), Indonesian (id), Igbo (ig), Iloko (ilo),
    Icelandic (is), Italian (it), Japanese (ja), Javanese (jv), Georgian (ka),
    Kazakh (kk), Central Khmer (km), Kannada (kn), Korean (ko),
    Luxembourgish; Letzeburgesch (lb), Ganda (lg), Lingala (ln), Lao (lo),
    Lithuanian (lt), Latvian (lv), Malagasy (mg), Macedonian (mk), Malayalam (ml),
    Mongolian (mn), Marathi (mr), Malay (ms), Burmese (my), Nepali (ne),
    Dutch; Flemish (nl), Norwegian (no), Northern Sotho (ns),
    Occitan (post 1500) (oc), Oriya (or), Panjabi; Punjabi (pa), Polish (pl),
    Pushto; Pashto (ps), Portuguese (pt), Romanian; Moldavian; Moldovan (ro),
    Russian (ru), Sindhi (sd), Sinhala; Sinhalese (si), Slovak (sk), Slovenian (sl),
    Somali (so), Albanian (sq), Serbian (sr), Swati (ss), Sundanese (su),
    Swedish (sv), Swahili (sw), Tamil (ta), Thai (th), Tagalog (tl), Tswana (tn),
    Turkish (tr), Ukrainian (uk), Urdu (ur), Uzbek (uz), Vietnamese (vi),
    Wolof (wo), Xhosa (xh), Yiddish (yi), Yoruba (yo), Chinese (zh), Zulu (zu)
    """
    try:
        # Reloads from the local cache warmed at startup.
        tokenizer, model = load_model(model_name=request.model)
    except Exception as e:
        return {"error": str(e)}

    # Tell the tokenizer which language the input is in, then force the
    # decoder to start generating in the target language.
    tokenizer.src_lang = request.source_lang

    with torch.no_grad():
        encoded_input = tokenizer(request.user_input, return_tensors="pt").to(device)
        generated_tokens = model.generate(
            **encoded_input,
            forced_bos_token_id=tokenizer.get_lang_id(request.target_lang),
        )
        translated_text = tokenizer.batch_decode(
            generated_tokens, skip_special_tokens=True
        )[0]

    return {"translation": translated_text}


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)
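
# Example usage, a minimal sketch assuming the server is running locally on
# port 7860 (the exact translation returned depends on the model):
#
#   curl -X POST http://localhost:7860/translate \
#        -H "Content-Type: application/json" \
#        -d '{"user_input": "Hello, world!", "source_lang": "en", "target_lang": "pt"}'
#
# which should return a JSON body like {"translation": "Olá, mundo!"}.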