# FastAPI service exposing facebook/m2m100_1.2B for text translation.
# (Header reconstructed: original lines were Hugging Face Spaces page residue.)
import io
import time
from enum import Enum
from functools import lru_cache
from typing import List, Literal

import torch
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer
# Application instance; routes are registered against this object.
app = FastAPI()
# CPU-only inference target. NOTE(review): hard-coded — switch to "cuda"
# manually if a GPU is available; confirm deployment target.
device = torch.device("cpu")
class TranslationRequest(BaseModel):
    """Request body for the translation endpoint.

    Fields:
        user_input: The text to translate.
        source_lang: Source language code, passed to ``tokenizer.src_lang``
            (M2M100 language id, e.g. "en").
        target_lang: Target language code, resolved via
            ``tokenizer.get_lang_id`` to force the decoder language.
    """

    user_input: str
    source_lang: str
    target_lang: str
@lru_cache(maxsize=2)
def load_model(pretrained_model: str = "facebook/m2m100_1.2B", cache_dir: str = "models/"):
    """Load the M2M100 tokenizer and model, moved to ``device`` in eval mode.

    Cached with ``lru_cache`` so repeated calls (``translate`` invokes this on
    every request) reuse the already-loaded 1.2B-parameter model instead of
    reloading it from disk each time. The small ``maxsize`` bounds memory if
    alternative checkpoints are ever requested.

    Args:
        pretrained_model: Hugging Face model id or local checkpoint path.
        cache_dir: Directory where downloaded weights are cached.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    tokenizer = M2M100Tokenizer.from_pretrained(pretrained_model, cache_dir=cache_dir)
    model = M2M100ForConditionalGeneration.from_pretrained(pretrained_model, cache_dir=cache_dir).to(device)
    model.eval()  # inference only — disable dropout etc.
    return tokenizer, model
# NOTE(review): the original function was never registered as a route, so the
# app exposed no endpoint; the decorator below fixes that. Confirm the path.
@app.post("/translate")
async def translate(request: TranslationRequest):
    """Translate ``request.user_input`` from ``source_lang`` to ``target_lang``.

    Args:
        request: Validated request body (text plus M2M100 language codes).

    Returns:
        Dict with ``"translation"`` (the decoded output string) and
        ``"computation_time"`` (wall-clock seconds, rounded to 3 decimals).
    """
    time_start = time.time()
    # Cheap when load_model() is cached; otherwise this reloads the model
    # from disk on every request.
    tokenizer, model = load_model()
    src_lang = request.source_lang
    trg_lang = request.target_lang
    # M2M100 encodes the source language on the tokenizer itself.
    tokenizer.src_lang = src_lang
    with torch.no_grad():
        encoded_input = tokenizer(request.user_input, return_tensors="pt").to(device)
        # forced_bos_token_id makes the decoder start in the target language.
        generated_tokens = model.generate(
            **encoded_input, forced_bos_token_id=tokenizer.get_lang_id(trg_lang)
        )
        translated_text = tokenizer.batch_decode(
            generated_tokens, skip_special_tokens=True
        )[0]
    time_end = time.time()
    response = {"translation": translated_text, "computation_time": round((time_end - time_start), 3)}
    return response
if __name__ == "__main__":
    # Run the service directly (development entry point); in production the
    # app is typically served by an external ASGI runner instead.
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)