ViT5BaseNode / app.py
VietCat's picture
switch to fastapi
c5a0bf8
raw
history blame
1.17 kB
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch
# ASGI application object; the route decorators further down register their
# handlers on it.
app = FastAPI()
# Load the tokenizer and seq2seq model once, at import time, so that every
# request reuses the same in-memory instances instead of reloading them.
# NOTE(review): "VietAI/vit5-base" looks like a general pretrained checkpoint
# rather than one fine-tuned for summarization — confirm this is the intended
# model for the /summarize endpoint.
model_name = "VietAI/vit5-base"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
# Request schema
class TextInput(BaseModel):
    """JSON body accepted by the summarization endpoint."""

    # Raw text that the endpoint will summarize.
    text: str
@app.get("/")
def read_root():
    """Return a fixed status message confirming the API is reachable."""
    status = {"message": "ViT5 summarization API is running!"}
    return status
@app.post("/summarize")
def summarize(input: TextInput):
    """Generate a summary of ``input.text`` with the module-level model.

    Args:
        input: Request body carrying the text to summarize. (The name
            shadows the ``input`` builtin; it is kept unchanged so existing
            keyword callers keep working.)

    Returns:
        A dict of the form ``{"summary": <generated text>}``.

    Raises:
        HTTPException: 400 when the text is empty or whitespace-only;
            500 (with the underlying error message as ``detail``) when
            tokenization, generation or decoding fails.
    """
    # Reject blank input up front: with min_length=20 the generator would
    # otherwise be pushed to produce output from the bare prompt prefix.
    # The check sits outside the try block so the 400 is not re-wrapped
    # as a 500 by the broad handler below.
    if not input.text.strip():
        raise HTTPException(status_code=400, detail="Input text must not be empty.")

    try:
        input_text = f"summarize: {input.text}"
        # Inputs longer than 512 tokens are truncated rather than rejected.
        inputs = tokenizer.encode(
            input_text,
            return_tensors="pt",
            max_length=512,
            truncation=True,
        )
        # Beam search with repetition controls; output capped at 128 tokens.
        summary_ids = model.generate(
            inputs,
            max_length=128,
            min_length=20,
            num_beams=4,
            no_repeat_ngram_size=3,
            repetition_penalty=2.5,
            length_penalty=1.0,
            early_stopping=True,
        )
        # generate() returns a batch; only one sequence was submitted.
        output = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
        return {"summary": output}
    except Exception as e:
        # Endpoint boundary: surface any failure as a 500 while keeping the
        # original exception chained for server-side tracebacks.
        raise HTTPException(status_code=500, detail=str(e)) from e