import io
import os
import re
import time
from enum import Enum
from functools import lru_cache
from typing import List, Literal

import torch
import uvicorn
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    M2M100ForConditionalGeneration,
    M2M100Tokenizer,
)

# The interactive API docs are served at the site root instead of /docs.
app = FastAPI(docs_url="/")
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


class TranslationRequest(BaseModel):
    """Request body accepted by the ``/translate`` endpoint."""

    # Text to translate.
    user_input: str
    # Language code of ``user_input`` (e.g. "en").
    source_lang: str
    # Language code to translate into (e.g. "fr").
    target_lang: str


# Memoized so the weights are read and moved to the device once per process
# rather than once per request. maxsize=1 keeps at most one model resident.
@lru_cache(maxsize=1)
def load_model(pretrained_model: str = "facebook/m2m100_1.2B", cache_dir: str = "models/"):
    """Load the M2M100 tokenizer and model, placing the model on ``device``.

    Args:
        pretrained_model: Model identifier passed to ``from_pretrained``.
        cache_dir: Directory, relative to the current working directory,
            handed to ``from_pretrained`` as its cache location.

    Returns:
        A ``(tokenizer, model)`` tuple; the model is in eval mode. Repeated
        calls with the same arguments return the same cached objects.
    """
    model_dir = os.path.join(os.getcwd(), cache_dir)
    m2m_tokenizer = M2M100Tokenizer.from_pretrained(pretrained_model, cache_dir=model_dir)
    m2m_model = M2M100ForConditionalGeneration.from_pretrained(
        pretrained_model, cache_dir=model_dir
    ).to(device)
    m2m_model.eval()
    return m2m_tokenizer, m2m_model


@app.post("/translate")
async def translate(request: TranslationRequest):
    """
    Translate `user_input` from `source_lang` to `target_lang`.

    Returns `{"translation": ...}` on success or `{"error": ...}` on failure.

    language support

    Afrikaans (af), Amharic (am), Arabic (ar), Asturian (ast), Azerbaijani (az),
    Bashkir (ba), Belarusian (be), Bulgarian (bg), Bengali (bn), Breton (br),
    Bosnian (bs), Catalan; Valencian (ca), Cebuano (ceb), Czech (cs), Welsh (cy),
    Danish (da), German (de), Greek (el), English (en), Spanish (es),
    Estonian (et), Persian (fa), Fulah (ff), Finnish (fi), French (fr),
    Western Frisian (fy), Irish (ga), Gaelic; Scottish Gaelic (gd), Galician (gl),
    Gujarati (gu), Hausa (ha), Hebrew (he), Hindi (hi), Croatian (hr),
    Haitian; Haitian Creole (ht), Hungarian (hu), Armenian (hy), Indonesian (id),
    Igbo (ig), Iloko (ilo), Icelandic (is), Italian (it), Japanese (ja),
    Javanese (jv), Georgian (ka), Kazakh (kk), Central Khmer (km), Kannada (kn),
    Korean (ko), Luxembourgish; Letzeburgesch (lb), Ganda (lg), Lingala (ln),
    Lao (lo), Lithuanian (lt), Latvian (lv), Malagasy (mg), Macedonian (mk),
    Malayalam (ml), Mongolian (mn), Marathi (mr), Malay (ms), Burmese (my),
    Nepali (ne), Dutch; Flemish (nl), Norwegian (no), Northern Sotho (ns),
    Occitan (post 1500) (oc), Oriya (or), Panjabi; Punjabi (pa), Polish (pl),
    Pushto; Pashto (ps), Portuguese (pt), Romanian; Moldavian; Moldovan (ro),
    Russian (ru), Sindhi (sd), Sinhala; Sinhalese (si), Slovak (sk),
    Slovenian (sl), Somali (so), Albanian (sq), Serbian (sr), Swati (ss),
    Sundanese (su), Swedish (sv), Swahili (sw), Tamil (ta), Thai (th),
    Tagalog (tl), Tswana (tn), Turkish (tr), Ukrainian (uk), Urdu (ur),
    Uzbek (uz), Vietnamese (vi), Wolof (wo), Xhosa (xh), Yiddish (yi),
    Yoruba (yo), Chinese (zh), Zulu (zu)
    """
    try:
        m2m_tokenizer, m2m_model = load_model()
    except Exception as exc:
        return {"error": str(exc)}

    # Resolve both codes before touching tokenizer state, so an unknown code
    # is reported to the client and never leaves the shared tokenizer with a
    # half-applied source language.
    try:
        m2m_tokenizer.get_lang_id(request.source_lang)
        target_lang_id = m2m_tokenizer.get_lang_id(request.target_lang)
    except KeyError as exc:
        return {"error": f"Unsupported language code: {exc}"}

    try:
        # NOTE: the cached tokenizer is shared between requests. This handler
        # has no await points, so it should run to completion before another
        # request is served; revisit this if the endpoint is made a plain
        # `def` (thread pool) or gains awaits.
        m2m_tokenizer.src_lang = request.source_lang
        with torch.no_grad():
            encoded_input = m2m_tokenizer(request.user_input, return_tensors="pt").to(device)
            generated_tokens = m2m_model.generate(
                **encoded_input,
                forced_bos_token_id=target_lang_id,
            )
        translated_text = m2m_tokenizer.batch_decode(
            generated_tokens, skip_special_tokens=True
        )[0]
    except Exception as exc:
        # Same error shape the endpoint already uses for load failures.
        return {"error": str(exc)}

    return {"translation": translated_text}


# chat
_WHITESPACE_RUN = re.compile(r"\s+")


def _collapse_whitespace(text: str) -> str:
    """Strip *text* and collapse every whitespace run (newlines included) to one space."""
    return _WHITESPACE_RUN.sub(" ", text.strip())


# Kept under its original name for any code that references it.
WHITESPACE_HANDLER = _collapse_whitespace

chat_model_name = "csebuetnlp/mT5_multilingual_XLSum"
tokenizer = AutoTokenizer.from_pretrained(chat_model_name)
modelchat = AutoModelForSeq2SeqLM.from_pretrained(chat_model_name)


@app.get("/chat")
async def read_root(text: str):
    """Generate a summary of ``text`` with the mT5 XLSum model.

    The input is whitespace-normalized, then padded/truncated to 512 tokens.

    Returns:
        ``{"summary": <generated text>}``.
    """
    input_ids = tokenizer(
        [WHITESPACE_HANDLER(text)],
        return_tensors="pt",
        padding="max_length",
        truncation=True,
        max_length=512,
    )["input_ids"]

    # The original code carried a commented-out alternative of max_length=84.
    output_ids = modelchat.generate(
        input_ids=input_ids,
        max_length=500,
        no_repeat_ngram_size=2,
        num_beams=4,
    )[0]

    summary = tokenizer.decode(
        output_ids,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=True,
    )
    return {"summary": summary}


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)