from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch

app = FastAPI()

# Load the ViT5 model and tokenizer once at startup
model_name = "VietAI/vit5-base"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

# Input format
class TextInput(BaseModel):
    text: str

@app.get("/")
def read_root():
    return {"message": "ViT5 summarization API is running!"}

@app.post("/summarize")
def summarize(input: TextInput):
    try:
        # ViT5 expects a task prefix in front of the source text
        input_text = f"summarize: {input.text}"
        inputs = tokenizer.encode(input_text, return_tensors="pt", max_length=512, truncation=True)
        summary_ids = model.generate(
            inputs,
            max_length=128,
            min_length=20,
            num_beams=4,
            no_repeat_ngram_size=3,
            repetition_penalty=2.5,
            length_penalty=1.0,
            early_stopping=True
        )
        output = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
        return {"summary": output}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
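Once the app is running, the summarization endpoint can be called from a small client script. This is a minimal sketch, assuming the file above is saved as app.py, served locally with `uvicorn app:app --port 8000`, and that the POST route is /summarize as in the decorator above; adjust the base URL to the Space's public URL when deployed.

import requests

# Assumed local setup: uvicorn app:app --port 8000
API_URL = "http://localhost:8000"

# Sample Vietnamese input (ViT5 is a Vietnamese model)
payload = {"text": "Đoạn văn bản tiếng Việt cần được tóm tắt."}

response = requests.post(f"{API_URL}/summarize", json=payload, timeout=120)
response.raise_for_status()
print(response.json()["summary"])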