spacesedan committed
Commit ed4c020 · 1 Parent(s): a93595a
Files changed (1)
  1. app.py +6 -9
app.py CHANGED
@@ -3,7 +3,6 @@ from pydantic import BaseModel
 from transformers import pipeline, AutoTokenizer
 from typing import List
 import logging
-import torch

 app = FastAPI()

@@ -56,16 +55,14 @@ async def summarize_batch(request: BatchSummarizationRequest):
         all_chunks.extend(chunks)
         chunk_map.extend([item.content_id] * len(chunks))

-    # Enforce token limit using tensor shape
+    # Hard-truncate chunks during encoding and decode safely
     safe_chunks = []
     for chunk in all_chunks:
-        inputs = tokenizer(chunk, return_tensors="pt", truncation=False)
-        token_length = inputs["input_ids"].shape[1]
-        if token_length > MAX_MODEL_TOKENS:
-            logger.warning(f"[TRUNCATING] Chunk token length {token_length} > {MAX_MODEL_TOKENS}, truncating.")
-            inputs = tokenizer(chunk, return_tensors="pt", truncation=True, max_length=MAX_MODEL_TOKENS)
-            chunk = tokenizer.decode(inputs["input_ids"][0], skip_special_tokens=True)
-        safe_chunks.append(chunk)
+        encoded = tokenizer.encode(chunk, truncation=True, max_length=MAX_MODEL_TOKENS)
+        if len(encoded) >= MAX_MODEL_TOKENS:
+            logger.warning(f"[TRUNCATING] Chunk encoded to {len(encoded)} tokens, trimming to {MAX_MODEL_TOKENS}.")
+        decoded = tokenizer.decode(encoded, skip_special_tokens=True)
+        safe_chunks.append(decoded)

     summaries = summarizer(
         safe_chunks,
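
Review note: the sketch below restates the new truncation strategy as a standalone script. It is a minimal illustration, not the full app.py: the checkpoint name "facebook/bart-large-cnn", the 1024-token limit, and the make_safe helper name are assumptions made for the example, not values taken from this commit. What it demonstrates is that a single tokenizer.encode(..., truncation=True, max_length=...) call already guarantees the token budget, so the old pattern of tokenizing once without truncation to measure the tensor shape and then a second time to trim (the only reason app.py needed import torch) collapses into one pass.

    # Standalone sketch of the committed truncation pattern.
    # Assumptions (not from the commit): checkpoint name, MAX_MODEL_TOKENS value, helper name.
    import logging

    from transformers import AutoTokenizer

    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger(__name__)

    MAX_MODEL_TOKENS = 1024  # assumed; should match the summarizer's max input length

    tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-cnn")  # illustrative checkpoint


    def make_safe(chunks):
        """Hard-truncate each chunk to the token budget, then decode back to text."""
        safe_chunks = []
        for chunk in chunks:
            # encode() with truncation=True can never return more than max_length ids,
            # so no torch tensors and no second tokenization pass are needed.
            encoded = tokenizer.encode(chunk, truncation=True, max_length=MAX_MODEL_TOKENS)
            if len(encoded) >= MAX_MODEL_TOKENS:
                # Hitting the cap exactly is the tell that the input was trimmed;
                # the truncation itself already happened inside encode().
                logger.warning("[TRUNCATING] Chunk hit the %d-token limit.", MAX_MODEL_TOKENS)
            safe_chunks.append(tokenizer.decode(encoded, skip_special_tokens=True))
        return safe_chunks


    if __name__ == "__main__":
        oversized = "token soup " * 3000  # far beyond the limit
        safe = make_safe([oversized])
        # Re-encoding the decoded text stays at or under the budget
        # (BPE round-tripping may shift the count slightly, never far past it).
        print(len(tokenizer.encode(safe[0])))

Two side effects of this design worth noting: the >= check fires only when a chunk exactly fills the cap, so the log is informational rather than a trigger for further work, and decoding the truncated ids back to text lets the downstream summarizer pipeline re-tokenize the input without exceeding the model window.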