File size: 1,199 Bytes
3d99217
 
 
49a53d3
bf808a4
cec4ed4
01728a1
a166f8e
 
dc09052
 
 
 
 
49a53d3
cec4ed4
01728a1
a166f8e
cec4ed4
bf808a4
a166f8e
 
cec4ed4
3d99217
 
 
 
 
a166f8e
3d99217
 
 
 
 
 
 
 
a166f8e
3d99217
a166f8e
9da1804
49a53d3
a166f8e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import pipeline, AutoTokenizer
import nltk
import os
import uvicorn

from chunker import chunk_by_token_limit  

# Directory used for NLTK resources; created at import time if it is missing.
NLTK_DATA_DIR = "/app/nltk_data"
os.makedirs(NLTK_DATA_DIR, exist_ok=True)
# Add the directory to NLTK's resource search path.
nltk.data.path.append(NLTK_DATA_DIR)

# Download the "punkt" resource into that directory when the module is imported.
# NOTE(review): punkt is not used directly in this file; presumably the
# `chunker` module relies on it for sentence splitting — confirm.
nltk.download("punkt", download_dir=NLTK_DATA_DIR, quiet=True)

# Application object on which the route handlers in this module are registered.
app = FastAPI()

# Hugging Face access token taken from the HF_TOKEN environment variable
# (None when the variable is not set).
HF_AUTH_TOKEN = os.getenv("HF_TOKEN")

# The summarization pipeline and the tokenizer are both built from MODEL_NAME
# once, at import time, and reused by the request handlers.
# NOTE(review): `use_auth_token` appears to be deprecated in favour of `token`
# in newer transformers releases — confirm against the installed version
# before changing it.
MODEL_NAME = "VincentMuriuki/legal-summarizer"
summarizer = pipeline("summarization", model=MODEL_NAME, use_auth_token=HF_AUTH_TOKEN)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_auth_token=HF_AUTH_TOKEN)

# Request body accepted by POST /summarize.
class SummarizeInput(BaseModel):
    text: str  # text passed to the summarization pipeline

# Request body accepted by POST /chunk.
class ChunkInput(BaseModel):
    text: str  # text to be split into chunks
    max_tokens: int = 1024  # limit forwarded to chunk_by_token_limit; 1024 when omitted

# POST /summarize: run the module-level `summarizer` pipeline over the request
# text and return the first result's "summary_text" under the "summary" key.
@app.post("/summarize")
def summarize_text(data: SummarizeInput):
    # Options forwarded unchanged to the pipeline call.
    generation_options = {"max_length": 150, "min_length": 30, "do_sample": False}
    results = summarizer(data.text, **generation_options)
    # NOTE(review): the text is passed as-is with no truncation argument;
    # confirm how the model behaves on inputs longer than its limit.
    first_result = results[0]
    return {"summary": first_result["summary_text"]}

# POST /chunk: hand the request text, the requested token limit and the
# module-level tokenizer to `chunk_by_token_limit`, and return whatever it
# produces under the "chunks" key.
@app.post("/chunk")
def chunk_text(data: ChunkInput):
    source_text, token_limit = data.text, data.max_tokens
    return {"chunks": chunk_by_token_limit(source_text, token_limit, tokenizer)}

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)