from fastapi import FastAPI
from pydantic import BaseModel
from transformers import pipeline, AutoTokenizer
import nltk
import os
import uvicorn
# Download the Punkt sentence tokenizer data used by nltk.sent_tokenize.
nltk.download('punkt', quiet=True)

app = FastAPI()

# Hugging Face Hub auth token (needed if the model repo is gated or private).
HF_AUTH_TOKEN = os.getenv("HF_TOKEN")
MODEL_NAME = "VincentMuriuki/legal-summarizer"

# Load the summarization pipeline and its tokenizer once at startup.
summarizer = pipeline("summarization", model=MODEL_NAME, token=HF_AUTH_TOKEN)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, token=HF_AUTH_TOKEN)
# Request schemas
class SummarizeInput(BaseModel):
    text: str

class ChunkInput(BaseModel):
    text: str
    max_tokens: int = 512  # Default chunk size, measured in tokenizer tokens
# Summarize endpoint: run the legal summarizer over the input text.
@app.post("/summarize")
def summarize_text(data: SummarizeInput):
    summary = summarizer(data.text, max_length=150, min_length=30, do_sample=False)
    return {"summary": summary[0]["summary_text"]}
# Chunk endpoint: split the text into sentence-aligned chunks that each stay
# within max_tokens according to the model's tokenizer.
@app.post("/chunk")
def chunk_text(data: ChunkInput):
    sentences = nltk.sent_tokenize(data.text)
    chunks = []
    current_chunk = ""
    current_token_count = 0

    for sentence in sentences:
        token_count = len(tokenizer.tokenize(sentence))
        if current_token_count + token_count > data.max_tokens:
            # Adding this sentence would overflow the current chunk:
            # flush it and start a new chunk with this sentence.
            if current_chunk:
                chunks.append(current_chunk.strip())
            current_chunk = sentence
            current_token_count = token_count
        else:
            current_chunk = f"{current_chunk} {sentence}".strip()
            current_token_count += token_count

    # Append the final, partially filled chunk.
    if current_chunk:
        chunks.append(current_chunk.strip())

    return {"chunks": chunks}
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)
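
# Example requests, as a quick local sanity check. This assumes the server above
# is running on port 7860; the payload text is illustrative only.
#
#   curl -X POST http://localhost:7860/summarize \
#        -H "Content-Type: application/json" \
#        -d '{"text": "This Agreement is made between ..."}'
#
#   curl -X POST http://localhost:7860/chunk \
#        -H "Content-Type: application/json" \
#        -d '{"text": "Long contract text ...", "max_tokens": 256}'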