"""HTTP API that summarizes Vietnamese text with the VietAI ViT5 model."""

from fastapi import FastAPI, Request
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch

app = FastAPI()

# Load the tokenizer and the model once at import time and place the model
# on the GPU when one is available, otherwise on the CPU.
model_name = "VietAI/vit5-base-vietnews-summarization"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)


class SummaryRequest(BaseModel):
    """Request body for ``POST /summarize``."""

    # Raw text to be summarized.
    text: str


@app.get("/")
def read_root():
    """Return a static message indicating that the API is running."""
    return {"message": "VietAI viT5 summarization API is running."}


@app.post("/summarize")
def summarize(request: SummaryRequest):
    """Summarize the text carried by *request*.

    The text is stripped of surrounding whitespace; if nothing is left, an
    empty summary is returned without calling the model. Otherwise the text
    is wrapped with the ``vietnews: `` prefix, tokenized (truncated to 512
    tokens), and decoded with the module-level model.

    Returns:
        A dict with a single ``"summary"`` key holding the generated text.
    """
    cleaned = request.text.strip()
    if not cleaned:
        return {"summary": ""}

    model_input = "vietnews: " + cleaned + " "
    encoded = tokenizer(
        model_input,
        return_tensors="pt",
        truncation=True,
        max_length=512,
    )

    # Decoding settings: keep the summary short and use plain greedy
    # decoding (no sampling, a single beam), which is the fastest option.
    decoding_options = {
        "max_length": 128,
        "do_sample": False,
        "num_beams": 1,
    }
    generated = model.generate(
        input_ids=encoded["input_ids"].to(device),
        attention_mask=encoded["attention_mask"].to(device),
        **decoding_options,
    )

    # Only one sequence was submitted, so the first output is the summary.
    decoded = tokenizer.decode(
        generated[0],
        skip_special_tokens=True,
        clean_up_tokenization_spaces=True,
    )
    return {"summary": decoded}