# csc525_retrieval_based_chatbot/conversation_summarizer.py
import functools
import logging
from dataclasses import dataclass
from typing import List

import tensorflow as tf
from transformers import TFAutoModelForSeq2SeqLM, AutoTokenizer

logger = logging.getLogger(__name__)

@dataclass
class ChatConfig:
    max_sequence_length: int = 512
    default_top_k: int = 10
    chunk_size: int = 512
    chunk_overlap: int = 256
    min_confidence_score: float = 0.7
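
# Usage sketch (illustrative): the defaults above can be overridden per
# deployment, e.g. ChatConfig(chunk_size=256, chunk_overlap=128) for shorter
# conversations. ChatConfig is not consumed in this module; it is presumably
# shared configuration for the wider chatbot pipeline.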

class DeviceAwareModel:
    """
    Mixin: handle device placement and mixed precision training.
    """
    def setup_device(self, device: str = None):
        if device is None:
            device = 'GPU' if tf.config.list_physical_devices('GPU') else 'CPU'
        self.device = device.upper()
        self.strategy = None
        # NOTE: Mixed precision needs more testing. Earlier training issues may
        # have come from other bugs found since this was last tested.
        # Reminder: test model saving/loading alongside mixed precision settings.
        if self.device == 'GPU':
            # # Enable mixed precision for better performance
            # policy = tf.keras.mixed_precision.Policy('mixed_float16')
            # tf.keras.mixed_precision.set_global_policy(policy)
            # Set up multi-GPU data parallelism when more than one GPU is visible.
            gpus = tf.config.list_physical_devices('GPU')
            if len(gpus) > 1:
                self.strategy = tf.distribute.MirroredStrategy()
        return self.device
    def run_on_device(self, func):
        """Decorator to ensure ops run on the correct device."""
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            with tf.device(f'/{self.device}:0'):
                return func(*args, **kwargs)
        return wrapper
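
# Usage sketch (illustrative): run_on_device reads self.device, so it is meant
# to wrap callables at call time, after setup_device() has run, e.g.
#
#     pinned = self.run_on_device(some_fn)   # some_fn is hypothetical
#     result = pinned(batch)
#
# Summarizer.summarize_text below applies it the same way, as a decorator on a
# locally defined helper.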

class Summarizer(DeviceAwareModel):
    """
    T5-based summarizer with device management and chunked, progressive
    summarization for long conversations.
    """
    def __init__(
        self,
        tokenizer: AutoTokenizer,
        model_name="t5-small",
        max_summary_length=128,
        device=None,
        max_summary_rounds=2
    ):
        self.tokenizer = tokenizer
        # Set generation attributes before building the model so anything
        # traced at first call (see _generate_summary) can read them.
        self.max_summary_length = max_summary_length
        self.max_summary_rounds = max_summary_rounds
        self.setup_device(device)
        # Build inside the strategy scope when distributing across GPUs.
        if self.strategy:
            with self.strategy.scope():
                self._setup_model(model_name)
        else:
            self._setup_model(model_name)
    def _setup_model(self, model_name):
        self.model = TFAutoModelForSeq2SeqLM.from_pretrained(model_name)

    # Optimize for inference: a fixed input signature lets repeated calls
    # reuse one traced graph instead of retracing for every new input shape.
    # The generation kwargs are Python constants, so they are baked into the
    # trace; only the tensor inputs vary between calls.
    @tf.function(
        input_signature=[
            {
                'input_ids': tf.TensorSpec(shape=[None, None], dtype=tf.int32),
                'attention_mask': tf.TensorSpec(shape=[None, None], dtype=tf.int32)
            }
        ]
    )
    def _generate_summary(self, inputs):
        return self.model.generate(
            input_ids=inputs['input_ids'],
            attention_mask=inputs['attention_mask'],
            max_length=self.max_summary_length,
            num_beams=4,
            length_penalty=2.0,
            early_stopping=True,
            no_repeat_ngram_size=3
        )
    def chunk_text(self, text: str, chunk_size: int = 512, overlap: int = 256) -> List[str]:
        """Split text into overlapping token chunks to preserve context."""
        if overlap >= chunk_size:
            raise ValueError("overlap must be smaller than chunk_size")
        tokens = self.tokenizer.encode(text)
        chunks = []
        # Step by (chunk_size - overlap) so consecutive chunks share `overlap` tokens.
        for i in range(0, len(tokens), chunk_size - overlap):
            chunk = tokens[i:i + chunk_size]
            chunks.append(self.tokenizer.decode(chunk, skip_special_tokens=True))
        return chunks
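
    # Worked example (illustrative): with chunk_size=512 and overlap=256 the
    # stride is 256 tokens, so a 1,200-token conversation yields chunks
    # starting at tokens 0, 256, 512, 768, and 1024. Each chunk repeats the
    # last 256 tokens of the previous one, so a turn that straddles a chunk
    # boundary is still seen whole in at least one chunk.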
    def summarize_text(
        self,
        text: str,
        progressive: bool = True,
        round_idx: int = 0
    ) -> str:
        """
        Summarize text, chunking long inputs and re-summarizing progressively
        for a bounded number of rounds.
        """
        @self.run_on_device
        def _summarize_chunk(chunk: str) -> str:
            input_text = "summarize: " + chunk  # T5 task prefix
            inputs = self.tokenizer(
                input_text,
                return_tensors="tf",
                padding=True,
                truncation=True
            )
            # Plain dict so the structure matches the tf.function input signature.
            summary_ids = self._generate_summary(dict(inputs))
            return self.tokenizer.decode(summary_ids[0], skip_special_tokens=True)
        # Fall back to a single summarization pass once max_summary_rounds is hit.
        if round_idx >= self.max_summary_rounds:
            return _summarize_chunk(text)
        # Long input: chunk, summarize each chunk, then combine.
        if progressive and len(text.split()) > 512:
            chunks = self.chunk_text(text)
            chunk_summaries = [_summarize_chunk(chunk) for chunk in chunks]
            combined_summary = " ".join(chunk_summaries)
            # Still too long: recurse on the combined chunk-level summaries.
            if len(combined_summary.split()) > 512:
                return self.summarize_text(
                    combined_summary,
                    progressive=True,
                    round_idx=round_idx + 1
                )
            return combined_summary
        else:
            # Short input: summarize once and return.
            return _summarize_chunk(text)
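

# Minimal usage sketch (illustrative, not part of the original module): builds
# a Summarizer around the default t5-small checkpoint and summarizes a short
# hypothetical transcript. Running this downloads model weights on first use.
if __name__ == "__main__":
    tokenizer = AutoTokenizer.from_pretrained("t5-small")
    summarizer = Summarizer(tokenizer=tokenizer, max_summary_length=64)
    transcript = (
        "User: My install keeps failing when the model downloads. "
        "Assistant: That is usually a network or disk-space problem; try "
        "clearing the cache and re-running. "
        "User: Clearing the cache fixed it, thank you!"
    )
    print(summarizer.summarize_text(transcript))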