spacesedan committed on
Commit
45e1223
·
1 Parent(s): 71a1190
Files changed (1) hide show
  1. app.py +15 -8
app.py CHANGED
@@ -3,6 +3,7 @@ from pydantic import BaseModel
3
  from transformers import pipeline, AutoTokenizer
4
  from typing import List
5
  import logging
 
6
 
7
  app = FastAPI()
8
 
@@ -12,7 +13,9 @@ logger = logging.getLogger("summarizer")
12
 
13
  # Faster and lighter summarization model
14
  model_name = "sshleifer/distilbart-cnn-12-6"
15
- summarizer = pipeline("summarization", model=model_name)
 
 
16
  tokenizer = AutoTokenizer.from_pretrained(model_name)
17
 
18
  class SummarizationItem(BaseModel):
@@ -39,8 +42,7 @@ def chunk_text(text: str, max_tokens: int = SAFE_CHUNK_SIZE) -> List[str]:
39
 
40
  for i in range(0, len(tokens), max_tokens):
41
  chunk_tokens = tokens[i:i + max_tokens]
42
- if len(chunk_tokens) > MAX_MODEL_TOKENS:
43
- chunk_tokens = chunk_tokens[:MAX_MODEL_TOKENS]
44
  chunk = tokenizer.decode(chunk_tokens, skip_special_tokens=True)
45
  chunks.append(chunk)
46
 
@@ -49,17 +51,23 @@ def chunk_text(text: str, max_tokens: int = SAFE_CHUNK_SIZE) -> List[str]:
49
  @app.post("/summarize", response_model=BatchSummarizationResponse)
50
  async def summarize_batch(request: BatchSummarizationRequest):
51
  all_chunks = []
52
- chunk_map = [] # maps index of chunk to content_id
53
 
54
  for item in request.inputs:
55
  token_count = len(tokenizer.encode(item.text, truncation=False))
56
  chunks = chunk_text(item.text)
57
  logger.info(f"[CHUNKING] content_id={item.content_id} token_len={token_count} num_chunks={len(chunks)}")
58
- all_chunks.extend(chunks)
59
- chunk_map.extend([item.content_id] * len(chunks))
 
 
 
 
 
 
60
 
61
  if not all_chunks:
62
- logger.error("No valid chunks after chunking. Returning empty response.")
63
  return {"summaries": []}
64
 
65
  summaries = summarizer(
@@ -71,7 +79,6 @@ async def summarize_batch(request: BatchSummarizationRequest):
71
  batch_size=4
72
  )
73
 
74
- # Aggregate summaries back per content_id
75
  summary_map = {}
76
  for content_id, result in zip(chunk_map, summaries):
77
  summary_map.setdefault(content_id, []).append(result["summary_text"])
 
3
  from transformers import pipeline, AutoTokenizer
4
  from typing import List
5
  import logging
6
+ import torch
7
 
8
  app = FastAPI()
9
 
 
13
 
14
  # Faster and lighter summarization model
15
  model_name = "sshleifer/distilbart-cnn-12-6"
16
+ device = 0 if torch.cuda.is_available() else -1
17
+ logger.info(f"Running summarizer on {'GPU' if device == 0 else 'CPU'}")
18
+ summarizer = pipeline("summarization", model=model_name, device=device)
19
  tokenizer = AutoTokenizer.from_pretrained(model_name)
20
 
21
  class SummarizationItem(BaseModel):
 
42
 
43
  for i in range(0, len(tokens), max_tokens):
44
  chunk_tokens = tokens[i:i + max_tokens]
45
+ chunk_tokens = chunk_tokens[:MAX_MODEL_TOKENS]
 
46
  chunk = tokenizer.decode(chunk_tokens, skip_special_tokens=True)
47
  chunks.append(chunk)
48
 
 
51
  @app.post("/summarize", response_model=BatchSummarizationResponse)
52
  async def summarize_batch(request: BatchSummarizationRequest):
53
  all_chunks = []
54
+ chunk_map = []
55
 
56
  for item in request.inputs:
57
  token_count = len(tokenizer.encode(item.text, truncation=False))
58
  chunks = chunk_text(item.text)
59
  logger.info(f"[CHUNKING] content_id={item.content_id} token_len={token_count} num_chunks={len(chunks)}")
60
+ for chunk in chunks:
61
+ encoded = tokenizer(chunk, return_tensors="pt", truncation=False)
62
+ final_len = encoded["input_ids"].shape[1]
63
+ if final_len > MAX_MODEL_TOKENS:
64
+ logger.warning(f"[SKIP] content_id={item.content_id} chunk still too long after decode: {final_len} tokens")
65
+ continue
66
+ all_chunks.append(chunk)
67
+ chunk_map.append(item.content_id)
68
 
69
  if not all_chunks:
70
+ logger.error("No valid chunks after filtering. Returning empty response.")
71
  return {"summaries": []}
72
 
73
  summaries = summarizer(
 
79
  batch_size=4
80
  )
81
 
 
82
  summary_map = {}
83
  for content_id, result in zip(chunk_map, summaries):
84
  summary_map.setdefault(content_id, []).append(result["summary_text"])