# Copied from a Hugging Face Spaces page (Space status shown: "Sleeping").
import torch
from fastapi import FastAPI, Request
from pydantic import BaseModel
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
app = FastAPI()

# Load the model and its tokenizer once, at import time, so every request
# reuses the same instances.
model_name = "VietAI/vit5-base"

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

tokenizer = AutoTokenizer.from_pretrained(model_name)
# nn.Module.to() moves the weights in place and returns the same module.
model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device)
# Input schema: the JSON request body for the summarization handler.
# (Kept docstring-free on purpose: Pydantic would copy a class docstring
# into the generated JSON/OpenAPI schema description.)
class SummaryRequest(BaseModel):
    text: str  # raw text to summarize; summarize() strips surrounding whitespace
# NOTE(review): no @app.get(...) decorator is visible in this paste, so as
# written this handler is not registered on `app` — confirm against the
# deployed file.
def read_root():
    """Return a fixed liveness message for the summarization API."""
    status_message = "VietAI viT5 summarization API is running."
    return dict(message=status_message)
# NOTE(review): no @app.post(...) decorator is visible in this paste, so as
# written this handler is not registered on `app` — confirm against the
# deployed file.
def summarize(request: SummaryRequest):
    """Summarize ``request.text`` with the module-level ViT5 model.

    Surrounding whitespace is stripped first. Blank input short-circuits to an
    empty summary without calling the model. Otherwise the steps are:

    1. Prefix the text with the ``vietnews:`` task tag.
    2. Tokenize it, truncating to at most 512 tokens.
    3. Decode greedily with ``max_length=128``.

    Returns:
        dict: ``{"summary": <generated text>}`` (``""`` for blank input).
    """
    cleaned = request.text.strip()
    if cleaned:
        # Task prefix plus an explicit end-of-sequence marker.
        # NOTE(review): the T5 tokenizer may also append its own </s>.
        # Confirm whether the resulting doubled EOS is intended.
        prompt = f"vietnews: {cleaned} </s>"
        batch = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
        generated = model.generate(
            input_ids=batch["input_ids"].to(device),
            attention_mask=batch["attention_mask"].to(device),
            max_length=128,   # keep the summary short
            do_sample=False,  # deterministic: no sampling
            num_beams=1,      # single beam == greedy decoding (fastest)
        )
        text_out = tokenizer.decode(
            generated[0],
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True,
        )
        return {"summary": text_out}
    return {"summary": ""}