spacesedan committed
Commit a93595a · Parent(s): eb54abc

ai on ai crime

Files changed (1): app.py (+9 -6)
app.py CHANGED
@@ -3,6 +3,7 @@ from pydantic import BaseModel
 from transformers import pipeline, AutoTokenizer
 from typing import List
 import logging
+import torch
 
 app = FastAPI()
 
@@ -55,14 +56,16 @@ async def summarize_batch(request: BatchSummarizationRequest):
         all_chunks.extend(chunks)
         chunk_map.extend([item.content_id] * len(chunks))
 
-    # Final safety pass to enforce 1024 token limit after decoding
+    # Enforce token limit using tensor shape
     safe_chunks = []
     for chunk in all_chunks:
-        encoded = tokenizer.encode(chunk, truncation=False)
-        if len(encoded) > MAX_MODEL_TOKENS:
-            logger.warning(f"[TRUNCATING] Chunk exceeded max tokens ({len(encoded)}), trimming to {MAX_MODEL_TOKENS} tokens")
-            encoded = encoded[:MAX_MODEL_TOKENS]
-        safe_chunks.append(tokenizer.decode(encoded, skip_special_tokens=True))
+        inputs = tokenizer(chunk, return_tensors="pt", truncation=False)
+        token_length = inputs["input_ids"].shape[1]
+        if token_length > MAX_MODEL_TOKENS:
+            logger.warning(f"[TRUNCATING] Chunk token length {token_length} > {MAX_MODEL_TOKENS}, truncating.")
+            inputs = tokenizer(chunk, return_tensors="pt", truncation=True, max_length=MAX_MODEL_TOKENS)
+            chunk = tokenizer.decode(inputs["input_ids"][0], skip_special_tokens=True)
+        safe_chunks.append(chunk)
 
     summaries = summarizer(
         safe_chunks,
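
For reference, here is a minimal standalone sketch of the new safety pass, runnable outside the FastAPI app. MAX_MODEL_TOKENS = 1024 follows the old comment's "1024 token limit"; the checkpoint name (facebook/bart-large-cnn), the enforce_token_limit helper, and the logger setup are illustrative assumptions, not taken from app.py.

import logging

from transformers import AutoTokenizer

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Assumed values: the diff's old comment mentions a 1024-token limit;
# the checkpoint is a placeholder for whatever app.py actually loads.
MAX_MODEL_TOKENS = 1024
tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-cnn")

def enforce_token_limit(chunks):
    """Measure each chunk in tokens and hard-truncate any that exceed the limit."""
    safe_chunks = []
    for chunk in chunks:
        # First pass: tokenize without truncation, purely to measure length.
        # return_tensors="pt" requires the torch package, which is why the
        # commit adds `import torch` to app.py.
        inputs = tokenizer(chunk, return_tensors="pt", truncation=False)
        token_length = inputs["input_ids"].shape[1]
        if token_length > MAX_MODEL_TOKENS:
            logger.warning(
                f"[TRUNCATING] Chunk token length {token_length} > {MAX_MODEL_TOKENS}, truncating."
            )
            # Second pass: let the tokenizer truncate, then decode back to
            # text so the summarizer receives a string that re-encodes
            # within the model's window.
            inputs = tokenizer(
                chunk,
                return_tensors="pt",
                truncation=True,
                max_length=MAX_MODEL_TOKENS,
            )
            chunk = tokenizer.decode(inputs["input_ids"][0], skip_special_tokens=True)
        safe_chunks.append(chunk)
    return safe_chunks

print(enforce_token_limit(["some very long article text ..."]))

Compared to the old slice (encoded[:MAX_MODEL_TOKENS]), re-tokenizing with truncation=True and max_length leaves special-token handling to the tokenizer instead of cutting raw token ids, at the cost of a second tokenizer call for oversized chunks.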