sagar008 committed
Commit 49a53d3 · verified · 1 Parent(s): 96e2541

Update app.py

Files changed (1)
app.py  +24 -10
app.py CHANGED
@@ -1,12 +1,15 @@
 from fastapi import FastAPI
 from pydantic import BaseModel
 from transformers import pipeline, AutoTokenizer
+import nltk
 import os
 import uvicorn
 
+nltk.download('punkt', quiet=True)
+
 app = FastAPI()
 
-HF_AUTH_TOKEN = os.getenv("HF_TOKEN")
+HF_AUTH_TOKEN = os.getenv("HF_TOKEN")
 
 MODEL_NAME = "VincentMuriuki/legal-summarizer"
 summarizer = pipeline("summarization", model=MODEL_NAME, token=HF_AUTH_TOKEN)
@@ -17,8 +20,9 @@ class SummarizeInput(BaseModel):
 
 class ChunkInput(BaseModel):
     text: str
-    max_tokens: int = 512
+    max_tokens: int = 512  # Default chunk size
 
+# Summarize endpoint
 @app.post("/summarize")
 def summarize_text(data: SummarizeInput):
     summary = summarizer(data.text, max_length=150, min_length=30, do_sample=False)
@@ -26,15 +30,25 @@ def summarize_text(data: SummarizeInput):
 
 @app.post("/chunk")
 def chunk_text(data: ChunkInput):
-    tokens = tokenizer.encode(data.text, truncation=False)
+    sentences = nltk.sent_tokenize(data.text)
     chunks = []
-
-    for i in range(0, len(tokens), data.max_tokens):
-        chunk_tokens = tokens[i:i + data.max_tokens]
-        chunk_text = tokenizer.decode(chunk_tokens, skip_special_tokens=True)
-        chunks.append(chunk_text.strip())
+    current_chunk = ""
+    current_token_count = 0
+
+    for sentence in sentences:
+        token_count = len(tokenizer.tokenize(sentence))
+        if current_token_count + token_count > data.max_tokens:
+            if current_chunk:
+                chunks.append(current_chunk.strip())
+            current_chunk = sentence
+            current_token_count = token_count
+        else:
+            current_chunk = f"{current_chunk} {sentence}".strip()
+            current_token_count += token_count
+
+    if current_chunk:
+        chunks.append(current_chunk.strip())
 
     return {"chunks": chunks}
-
 if __name__ == "__main__":
-    uvicorn.run(app, host="0.0.0.0", port=7860)
+    uvicorn.run(app, host="0.0.0.0", port=7860)
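
For context, a minimal usage sketch (not part of the commit) showing how the updated endpoints might be exercised against a locally running instance. It assumes the app was started with "python app.py" so it listens on port 7860 as in the __main__ block, that the requests package is installed, and that the base URL and sample text below are made up for illustration; the /chunk response shape comes from the diff above, everything else is an assumption.

# Hypothetical client sketch; only the endpoint paths and the {"chunks": [...]}
# response of /chunk are taken from the diff, the rest is assumed.
import requests

BASE_URL = "http://localhost:7860"  # assumption: local run on the port used in app.py

sample_text = " ".join(
    ["This agreement is made between the parties on the effective date."] * 200
)

# /chunk now groups whole sentences so each chunk stays within max_tokens tokenizer tokens.
chunk_resp = requests.post(
    f"{BASE_URL}/chunk",
    json={"text": sample_text, "max_tokens": 512},
)
chunks = chunk_resp.json()["chunks"]
print(f"{len(chunks)} chunk(s) returned")

# /summarize condenses a single chunk; its response shape is defined in unchanged code,
# so the raw JSON is printed here rather than assuming specific fields.
summary_resp = requests.post(f"{BASE_URL}/summarize", json={"text": chunks[0]})
print(summary_resp.json())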