spacesedan committed
Commit 0bda3c0 · Parent: cef4a12

making changes

Files changed (3):
  1. Dockerfile +0 -10
  2. app.py +25 -27
  3. requirements.txt +0 -1
Dockerfile CHANGED

@@ -1,24 +1,14 @@
 FROM python:3.9
 
-# Create non-root user
 RUN useradd -m -u 1000 user
 USER user
 ENV PATH="/home/user/.local/bin:$PATH"
 
 WORKDIR /app
 
-# Copy and install dependencies
 COPY --chown=user requirements.txt .
 RUN pip install --no-cache-dir --upgrade -r requirements.txt
 
-# Download NLTK 'punkt' to a known path
-RUN python -m nltk.downloader -d /home/user/nltk_data punkt
-
-# Set env so NLTK can find the punkt data
-ENV NLTK_DATA=/home/user/nltk_data
-
-# Copy app source
 COPY --chown=user . /app
 
-# Run the app
 CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
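The punkt download and the NLTK_DATA environment variable go away because app.py (below) replaces nltk's sent_tokenize with a standard-library regex splitter, so the image no longer fetches any data at build time. A minimal sketch of the new splitting behavior, on a made-up sample string:

import re

# Hypothetical sample input; the pattern is the one app.py introduces.
text = "Punkt is gone. Splitting now uses the stdlib! Does it still work?"
print(re.split(r'(?<=[.!?])\s+', text.strip()))
# -> ['Punkt is gone.', 'Splitting now uses the stdlib!', 'Does it still work?']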
app.py CHANGED

@@ -4,31 +4,26 @@ from transformers import pipeline, AutoTokenizer
 from typing import List
 import logging
 import torch
-import nltk
-from nltk.tokenize import sent_tokenize
+import re
 
-# FastAPI app init
 app = FastAPI()
 
 # Configure logging
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger("summarizer")
 
-# NLTK setup
-nltk.download("punkt")
-
-# Model config
+# Load model and tokenizer
 model_name = "sshleifer/distilbart-cnn-12-6"
 device = 0 if torch.cuda.is_available() else -1
 logger.info(f"Running summarizer on {'GPU' if device == 0 else 'CPU'}")
 summarizer = pipeline("summarization", model=model_name, device=device)
 tokenizer = AutoTokenizer.from_pretrained(model_name)
 
-# Token limits
+# Token constraints
 MAX_MODEL_TOKENS = 1024
-SAFE_CHUNK_SIZE = 700  # Conservative chunk size to stay below 1024 after re-tokenization
+SAFE_CHUNK_SIZE = 700
 
-# Input/output schemas
+# Pydantic schemas
 class SummarizationItem(BaseModel):
     content_id: str
     text: str
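Note the headroom between the two constants: the inline comment removed above explained that SAFE_CHUNK_SIZE (700) stays conservatively below MAX_MODEL_TOKENS (1024) because token counts can shift when a joined chunk is re-tokenized. A small sketch of the counting method chunk_text relies on (the fits helper is hypothetical):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("sshleifer/distilbart-cnn-12-6")

def fits(chunk: str, budget: int = 700) -> bool:
    # Count tokens exactly as chunk_text does: full encode, no truncation.
    return len(tokenizer.encode(chunk, truncation=False)) <= budget

print(fits("A short test sentence."))  # True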
 
@@ -43,47 +38,50 @@ class SummarizationResponseItem(BaseModel):
 class BatchSummarizationResponse(BaseModel):
     summaries: List[SummarizationResponseItem]
 
-# New safe chunking logic using NLTK
+# Sentence-based chunking
+def split_sentences(text: str) -> list[str]:
+    return re.split(r'(?<=[.!?])\s+', text.strip())
+
 def chunk_text(text: str, max_tokens: int = SAFE_CHUNK_SIZE) -> List[str]:
-    sentences = sent_tokenize(text)
+    sentences = split_sentences(text)
     chunks = []
-    current_chunk = ""
+    current_chunk_sentences = []
 
     for sentence in sentences:
-        temp_chunk = f"{current_chunk} {sentence}".strip()
-        token_count = len(tokenizer.encode(temp_chunk, truncation=False))
+        tentative_chunk = " ".join(current_chunk_sentences + [sentence])
+        token_count = len(tokenizer.encode(tentative_chunk, truncation=False))
 
         if token_count <= max_tokens:
-            current_chunk = temp_chunk
+            current_chunk_sentences.append(sentence)
         else:
-            if current_chunk:
-                chunks.append(current_chunk)
-            current_chunk = sentence
+            if current_chunk_sentences:
+                chunks.append(" ".join(current_chunk_sentences))
+            current_chunk_sentences = [sentence]
 
-    if current_chunk:
-        chunks.append(current_chunk)
+    if current_chunk_sentences:
+        chunks.append(" ".join(current_chunk_sentences))
 
+    # Final filter: ensure nothing slipped through
     final_chunks = []
     for chunk in chunks:
        encoded = tokenizer(chunk, return_tensors="pt", truncation=False)
-        actual_len = encoded["input_ids"].shape[1]
-        if actual_len <= MAX_MODEL_TOKENS:
+        token_len = encoded["input_ids"].shape[1]
+        if token_len <= MAX_MODEL_TOKENS:
             final_chunks.append(chunk)
         else:
-            logger.warning(f"[CHUNKING] Dropped chunk due to re-encoding overflow: {actual_len} tokens")
+            logger.warning(f"[CHUNKING] Dropped oversized chunk: {token_len} tokens")
 
     return final_chunks
 
-# Main summarization endpoint
+# Summarization endpoint
 @app.post("/summarize", response_model=BatchSummarizationResponse)
 async def summarize_batch(request: BatchSummarizationRequest):
     all_chunks = []
     chunk_map = []
 
     for item in request.inputs:
-        token_count = len(tokenizer.encode(item.text, truncation=False))
         chunks = chunk_text(item.text)
-        logger.info(f"[CHUNKING] content_id={item.content_id} token_len={token_count} num_chunks={len(chunks)}")
+        logger.info(f"[CHUNKING] content_id={item.content_id} num_chunks={len(chunks)}")
 
         for chunk in chunks:
             all_chunks.append(chunk)
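For reference, a hedged usage sketch against the running service: the field names come from the Pydantic models above and the port from the Dockerfile's CMD, while the requests dependency and the sample payload are assumptions.

import requests

payload = {
    "inputs": [
        {"content_id": "demo-1", "text": "Some long article text ..."}  # hypothetical input
    ]
}
resp = requests.post("http://localhost:7860/summarize", json=payload)
print(resp.json())  # expected shape: {"summaries": [...]}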
requirements.txt CHANGED

@@ -2,5 +2,4 @@ fastapi
 uvicorn[standard]
 transformers
 torch
-nltk
 pydantic