# csc525_retrieval_based_chatbot/conversation_summarizer.py

import logging
from dataclasses import dataclass
from typing import List

import tensorflow as tf
from transformers import AutoTokenizer, TFAutoModelForSeq2SeqLM

logger = logging.getLogger(__name__)


@dataclass
class ChatConfig:
    max_sequence_length: int = 512
    default_top_k: int = 10
    chunk_size: int = 512
    chunk_overlap: int = 256
    min_confidence_score: float = 0.7


class DeviceAwareModel:
    """Mixin to handle device placement and multi-GPU distribution.

    Mixed-precision training is stubbed out below but currently disabled.
    """

    def setup_device(self, device: str = None):
        if device is None:
            device = 'GPU' if tf.config.list_physical_devices('GPU') else 'CPU'
        self.device = device.upper()
        self.strategy = None
        if self.device == 'GPU':
            # Mixed precision is disabled for now; re-enable for better
            # throughput on GPUs with Tensor Cores:
            # policy = tf.keras.mixed_precision.Policy('mixed_float16')
            # tf.keras.mixed_precision.set_global_policy(policy)

            # Use a distribution strategy when more than one GPU is visible.
            gpus = tf.config.list_physical_devices('GPU')
            if len(gpus) > 1:
                self.strategy = tf.distribute.MirroredStrategy()
        return self.device

    def run_on_device(self, func):
        """Decorator to ensure ops run on the selected device."""
        def wrapper(*args, **kwargs):
            with tf.device(f'/{self.device}:0'):
                return func(*args, **kwargs)
        return wrapper
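
# `run_on_device` is applied at call time rather than with the usual `@`
# decorator syntax because the target device is only known after an instance
# has run setup_device(). Illustrative sketch (the `encode` helper and
# `batch` variable are hypothetical, not part of this file):
#
#     pinned_encode = self.run_on_device(encode)
#     embeddings = pinned_encode(batch)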


class Summarizer(DeviceAwareModel):
    """
    T5-based summarizer with device-aware placement and chunked,
    progressive summarization. Long conversations are split into
    overlapping chunks, summarized chunk by chunk, and re-summarized
    for at most `max_summary_rounds` passes.
    """

    def __init__(
        self,
        tokenizer: AutoTokenizer,
        model_name: str = "t5-small",
        max_summary_length: int = 128,
        device: str = None,
        max_summary_rounds: int = 2,
    ):
        self.tokenizer = tokenizer  # Injected tokenizer; must match the model.
        # Set generation limits before _setup_model, which captures them.
        self.max_summary_length = max_summary_length
        self.max_summary_rounds = max_summary_rounds
        self.setup_device(device)

        # Initialize the model inside the strategy scope when distributing.
        if self.strategy:
            with self.strategy.scope():
                self._setup_model(model_name)
        else:
            self._setup_model(model_name)
    def _setup_model(self, model_name):
        self.model = TFAutoModelForSeq2SeqLM.from_pretrained(model_name)
        # Compile generate() once for inference. Because the input_signature
        # pins the traced call to a single dict of tensors, the generation
        # settings are baked in as trace-time constants instead of per-call
        # keyword arguments (which an input_signature would reject).
        base_generate = self.model.generate

        def _generate(batch):
            return base_generate(
                input_ids=batch['input_ids'],
                attention_mask=batch['attention_mask'],
                max_length=self.max_summary_length,
                num_beams=4, length_penalty=2.0,
                early_stopping=True, no_repeat_ngram_size=3,
            )

        self.model.generate = tf.function(
            _generate,
            input_signature=[{
                'input_ids': tf.TensorSpec(shape=[None, None], dtype=tf.int32),
                'attention_mask': tf.TensorSpec(shape=[None, None], dtype=tf.int32),
            }],
        )
    def _generate_summary(self, inputs):
        # generate() was already tf.function-compiled in _setup_model with the
        # generation settings baked in, so only tensors are passed here (a
        # second @tf.function wrapper would just re-trace it).
        return self.model.generate({
            'input_ids': inputs['input_ids'],
            'attention_mask': inputs['attention_mask'],
        })
    def chunk_text(self, text: str, chunk_size: int = 512, overlap: int = 256) -> List[str]:
        """Split text into overlapping chunks for better context preservation."""
        if overlap >= chunk_size:
            raise ValueError("overlap must be smaller than chunk_size")
        tokens = self.tokenizer.encode(text)
        chunks = []
        # Advance by (chunk_size - overlap) so consecutive chunks share context.
        for i in range(0, len(tokens), chunk_size - overlap):
            chunk = tokens[i:i + chunk_size]
            chunks.append(self.tokenizer.decode(chunk, skip_special_tokens=True))
        return chunks
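
    # Stride check for chunk_text (illustrative numbers, not from the repo):
    # with chunk_size=512 and overlap=256 the loop advances 256 tokens per
    # iteration, so a 1,100-token input yields chunks starting at tokens
    # 0, 256, 512, 768, and 1024, each up to 512 tokens long and sharing
    # 256 tokens with its neighbor.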

    def summarize_text(
        self,
        text: str,
        progressive: bool = True,
        round_idx: int = 0,
    ) -> str:
        """
        Summarize text, optionally with progressive (chunked) summarization,
        capping the number of re-summarization rounds at max_summary_rounds.
        """
        @self.run_on_device
        def _summarize_chunk(chunk: str) -> str:
            input_text = "summarize: " + chunk  # T5 task prefix.
            inputs = self.tokenizer(
                input_text,
                return_tensors="tf",
                padding=True,
                truncation=True,
            )
            summary_ids = self._generate_summary(inputs)
            return self.tokenizer.decode(summary_ids[0], skip_special_tokens=True)

        # Past the allowed number of rounds: finish with a single
        # (truncated) pass.
        if round_idx >= self.max_summary_rounds:
            return _summarize_chunk(text)

        # Long input and progressive summarization enabled: chunk first.
        # The 512-word threshold is a rough proxy for the token chunk size.
        if len(text.split()) > 512 and progressive:
            chunks = self.chunk_text(text)
            chunk_summaries = [_summarize_chunk(chunk) for chunk in chunks]
            combined_summary = " ".join(chunk_summaries)

            # Still too long: recurse with an incremented round counter.
            if len(combined_summary.split()) > 512:
                return self.summarize_text(
                    combined_summary,
                    progressive=True,
                    round_idx=round_idx + 1,
                )
            return combined_summary
        else:
            # Short enough: summarize in a single pass.
            return _summarize_chunk(text)
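

if __name__ == "__main__":
    # Minimal usage sketch. "t5-small" mirrors the constructor default; the
    # sample conversation is illustrative and not taken from the project.
    tok = AutoTokenizer.from_pretrained("t5-small")
    summarizer = Summarizer(tok)
    sample = (
        "User: My laptop won't turn on. "
        "Agent: Try holding the power button for ten seconds. "
        "User: Still nothing. "
        "Agent: Then check whether the charger LED lights up."
    )
    print(summarizer.summarize_text(sample))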