import os

import nltk
import uvicorn
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoTokenizer, pipeline

from chunker import chunk_by_token_limit

# Ensure the NLTK "punkt" tokenizer data is available.
nltk.download("punkt", quiet=True)

app = FastAPI()

# Hugging Face Hub auth token, read from the HF_TOKEN environment variable.
HF_AUTH_TOKEN = os.getenv("HF_TOKEN")
MODEL_NAME = "VincentMuriuki/legal-summarizer"

# Load the summarization pipeline and its tokenizer once at startup.
summarizer = pipeline("summarization", model=MODEL_NAME, use_auth_token=HF_AUTH_TOKEN)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_auth_token=HF_AUTH_TOKEN)
class SummarizeInput(BaseModel):
    text: str


class ChunkInput(BaseModel):
    text: str
    max_tokens: int = 1024


@app.post("/summarize")
def summarize_text(data: SummarizeInput):
    # Deterministic summary (no sampling), bounded to 30-150 generated tokens.
    summary = summarizer(data.text, max_length=150, min_length=30, do_sample=False)
    return {"summary": summary[0]["summary_text"]}


@app.post("/chunk")
def chunk_text(data: ChunkInput):
    # Split the input text into pieces that stay within the requested token limit.
    chunks = chunk_by_token_limit(data.text, data.max_tokens, tokenizer)
    return {"chunks": chunks}


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)
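
# --- Minimal usage sketch (not part of the app) ---
# Illustrative client calls, assuming the server above is running locally on
# port 7860 and that the `requests` package is installed; the payload text is
# a placeholder.
#
# import requests
#
# r = requests.post(
#     "http://localhost:7860/summarize",
#     json={"text": "This Agreement is entered into by and between ..."},
# )
# print(r.json()["summary"])
#
# r = requests.post(
#     "http://localhost:7860/chunk",
#     json={"text": "Long contract text ...", "max_tokens": 512},
# )
# print(r.json()["chunks"])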