from fastapi import FastAPI
from pydantic import BaseModel
from transformers import pipeline, AutoTokenizer
import nltk
import os
import uvicorn

# Sentence tokenizer data; 'punkt_tab' is used by newer NLTK releases and the
# extra download is harmless on older ones.
nltk.download('punkt', quiet=True)
nltk.download('punkt_tab', quiet=True)

app = FastAPI()

HF_AUTH_TOKEN = os.getenv("HF_TOKEN")
MODEL_NAME = "VincentMuriuki/legal-summarizer"

# Load the summarization pipeline and its matching tokenizer once at startup
summarizer = pipeline("summarization", model=MODEL_NAME, token=HF_AUTH_TOKEN)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, token=HF_AUTH_TOKEN)


class SummarizeInput(BaseModel):
    text: str


class ChunkInput(BaseModel):
    text: str
    max_tokens: int = 512  # Default chunk size


# Summarize endpoint
@app.post("/summarize")
def summarize_text(data: SummarizeInput):
    summary = summarizer(data.text, max_length=150, min_length=30, do_sample=False)
    return {"summary": summary[0]["summary_text"]}


# Chunk endpoint: split text into sentence-aligned chunks of at most max_tokens tokens
@app.post("/chunk")
def chunk_text(data: ChunkInput):
    sentences = nltk.sent_tokenize(data.text)
    chunks = []
    current_chunk = ""
    current_token_count = 0

    for sentence in sentences:
        token_count = len(tokenizer.tokenize(sentence))
        if current_token_count + token_count > data.max_tokens:
            # Adding this sentence would exceed the limit: close the current
            # chunk and start a new one with this sentence.
            if current_chunk:
                chunks.append(current_chunk.strip())
            current_chunk = sentence
            current_token_count = token_count
        else:
            current_chunk = f"{current_chunk} {sentence}".strip()
            current_token_count += token_count

    if current_chunk:
        chunks.append(current_chunk.strip())

    return {"chunks": chunks}


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)
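
# --- Example client usage (a sketch, not part of the service itself) ---
# Assuming the server above is running locally on port 7860, the two endpoints
# can be exercised with any HTTP client; the snippet below uses the `requests`
# library (an extra dependency, not imported by this app):
#
#   import requests
#
#   base = "http://localhost:7860"
#   long_text = "..."  # the document to process
#
#   # Split the document into token-bounded chunks, then summarize each chunk.
#   chunks = requests.post(
#       f"{base}/chunk", json={"text": long_text, "max_tokens": 512}
#   ).json()["chunks"]
#   summaries = [
#       requests.post(f"{base}/summarize", json={"text": c}).json()["summary"]
#       for c in chunks
#   ]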