File size: 1,230 Bytes
3d99217
 
 
4828140
01728a1
3d99217
4828140
01728a1
 
 
 
3d99217
 
01728a1
2609cab
 
3d99217
 
 
 
 
01728a1
3d99217
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import os

from fastapi import FastAPI
from pydantic import BaseModel, Field
from transformers import AutoTokenizer, pipeline

app = FastAPI()

# Hugging Face cache directory for downloaded model and tokenizer files.
CACHE_DIR = "/home/user/.cache/huggingface"

# NOTE: transformers / huggingface_hub resolve these variables when they are
# first imported, and ``transformers`` is imported above this point, so these
# assignments do not change where *this* process caches files. They are kept
# for any child processes that inherit the environment; the cache location for
# this process is passed explicitly via ``cache_dir`` below.
os.environ["TRANSFORMERS_CACHE"] = CACHE_DIR
os.environ["HF_HOME"] = CACHE_DIR

# Access token read from the HF_TOKEN environment variable; None when unset.
HF_AUTH_TOKEN = os.environ.get("HF_TOKEN")

MODEL_NAME = "VincentMuriuki/legal-summarizer"

# Load the tokenizer once and hand the same instance to the pipeline, so the
# /chunk endpoint counts tokens with exactly the tokenizer /summarize uses and
# the tokenizer files are not loaded twice.
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_NAME,
    use_auth_token=HF_AUTH_TOKEN,
    cache_dir=CACHE_DIR,
)
summarizer = pipeline(
    "summarization",
    model=MODEL_NAME,
    tokenizer=tokenizer,
    use_auth_token=HF_AUTH_TOKEN,
    # model_kwargs are forwarded to the config/model ``from_pretrained`` calls.
    model_kwargs={"cache_dir": CACHE_DIR},
)


class SummarizeInput(BaseModel):
    """JSON request body accepted by ``POST /summarize``."""

    # Document text handed to the summarization pipeline.
    text: str

class ChunkInput(BaseModel):
    """JSON request body accepted by ``POST /chunk``."""

    # Document text to split into token-bounded chunks.
    text: str
    # Maximum number of tokens per chunk. Must be positive: the chunking loop
    # uses it as a ``range`` step, where 0 raises ValueError (HTTP 500) and a
    # negative value silently produces no chunks. ``gt=0`` rejects such values
    # up front with a 422 validation error instead.
    max_tokens: int = Field(default=512, gt=0)

@app.post("/summarize")
def summarize_text(data: SummarizeInput):
    """Summarize the request text with the module-level ``summarizer`` pipeline.

    Generation is deterministic (``do_sample=False``) and bounded to between
    30 and 150 tokens of output.

    Returns:
        ``{"summary": <str>}`` taken from the first pipeline result.
    """
    generation_options = {"max_length": 150, "min_length": 30, "do_sample": False}
    results = summarizer(data.text, **generation_options)
    first_result = results[0]
    return {"summary": first_result["summary_text"]}

@app.post("/chunk")
def chunk_text(data: ChunkInput):
    """Split the request text into consecutive pieces of at most ``max_tokens`` tokens.

    The text is tokenized with the module-level ``tokenizer``, the token ids
    are sliced into windows of ``data.max_tokens``, and each window is decoded
    back to text with special tokens dropped and surrounding whitespace stripped.

    Returns:
        ``{"chunks": [str, ...]}`` in document order; an empty list when the
        text produces no tokens.
    """
    # add_special_tokens=False: with the default (True), BOS/EOS-style markers
    # are counted against the first/last window's budget and then discarded
    # by decode(skip_special_tokens=True), so those chunks would carry fewer
    # content tokens than requested. The length limit is handled by the
    # windowing below, hence no truncation.
    token_ids = tokenizer.encode(
        data.text, add_special_tokens=False, truncation=False
    )
    step = data.max_tokens

    # NOTE(review): decode -> re-encode is not guaranteed to round-trip to the
    # same token count, and a later encode with default settings will add
    # special tokens again — callers feeding chunks to a model should leave
    # some headroom below its input limit.
    pieces = (
        tokenizer.decode(token_ids[start:start + step], skip_special_tokens=True)
        for start in range(0, len(token_ids), step)
    )
    return {"chunks": [piece.strip() for piece in pieces]}