king-translate / main.py
gleisonnanet's picture
start teste
fb2eca4
raw
history blame
1.56 kB
import io
import time
from enum import Enum
from typing import List, Literal

import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer
# Every tensor and the model itself are moved to this device; inference runs on CPU only.
device = torch.device("cpu")

# ASGI application object; the HTTP routes in this module are registered on it.
app = FastAPI()
class TranslationRequest(BaseModel):
user_input: str
source_lang: str
target_lang: str
# (tokenizer, model) pairs already loaded in this process, keyed by
# (pretrained_model, cache_dir). Request handlers call load_model() per request,
# so without this every call would re-read the multi-GB checkpoint from disk.
_MODEL_CACHE = {}


def load_model(pretrained_model: str = "facebook/m2m100_1.2B", cache_dir: str = "models/"):
    """Return the ``(tokenizer, model)`` pair for an M2M100 checkpoint, loading it at most once.

    On the first call for a given ``(pretrained_model, cache_dir)`` the tokenizer and
    model are created with ``from_pretrained``. The model is moved to ``device`` and
    switched to eval mode. Later calls with the same arguments return the same cached
    objects. Callers therefore share one tokenizer and must set per-request state such
    as ``tokenizer.src_lang`` before every use.

    Args:
        pretrained_model: Model id or local path passed to ``from_pretrained``.
        cache_dir: Directory ``from_pretrained`` uses for downloaded files.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    key = (pretrained_model, cache_dir)
    cached = _MODEL_CACHE.get(key)
    if cached is not None:
        return cached

    tokenizer = M2M100Tokenizer.from_pretrained(pretrained_model, cache_dir=cache_dir)
    model = M2M100ForConditionalGeneration.from_pretrained(
        pretrained_model, cache_dir=cache_dir
    ).to(device)
    # Inference only: disable dropout and other training-time behavior.
    model.eval()

    _MODEL_CACHE[key] = (tokenizer, model)
    return tokenizer, model
@app.post("/translate")
async def translate(request: TranslationRequest):
    """Translate ``request.user_input`` from ``source_lang`` into ``target_lang``.

    Returns:
        A dict with ``"translation"`` (the decoded text of the first generated
        sequence) and ``"computation_time"`` (seconds spent in this handler,
        including model loading, rounded to 3 decimals).

    Raises:
        HTTPException: status 400 if either language code is not supported by
            the tokenizer.
    """
    time_start = time.perf_counter()
    tokenizer, model = load_model()
    src_lang = request.source_lang
    trg_lang = request.target_lang

    # M2M100Tokenizer resolves language codes through a dict lookup, so an
    # unknown code raises KeyError. Report it as a client error, not a 500.
    try:
        tokenizer.src_lang = src_lang
    except KeyError:
        raise HTTPException(
            status_code=400, detail=f"Unsupported source_lang: {src_lang!r}"
        ) from None
    try:
        target_lang_id = tokenizer.get_lang_id(trg_lang)
    except KeyError:
        raise HTTPException(
            status_code=400, detail=f"Unsupported target_lang: {trg_lang!r}"
        ) from None

    # NOTE(review): generation is CPU-bound and runs directly on the event loop,
    # so concurrent requests are handled one at a time. Offloading it to a worker
    # thread would also require not mutating a tokenizer shared across requests
    # (src_lang above).
    with torch.no_grad():
        encoded_input = tokenizer(request.user_input, return_tensors="pt").to(device)
        # Forcing the first generated token to the target-language id selects
        # the output language for the multilingual decoder.
        generated_tokens = model.generate(**encoded_input, forced_bos_token_id=target_lang_id)
    translated_text = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]

    # perf_counter is monotonic, so the duration is unaffected by system clock changes.
    elapsed = time.perf_counter() - time_start
    return {"translation": translated_text, "computation_time": round(elapsed, 3)}
if __name__ == "__main__":
    # Running this module directly starts a uvicorn server for `app`,
    # listening on all network interfaces at port 8000.
    import uvicorn

    server_options = {"host": "0.0.0.0", "port": 8000}
    uvicorn.run(app, **server_options)