import functools
import logging
from dataclasses import dataclass
from typing import List, Optional

import tensorflow as tf
from transformers import TFAutoModelForSeq2SeqLM, AutoTokenizer

logger = logging.getLogger(__name__)


@dataclass
class ChatConfig:
    max_sequence_length: int = 512
    default_top_k: int = 10
    chunk_size: int = 512
    chunk_overlap: int = 256
    min_confidence_score: float = 0.7


class DeviceAwareModel:
    """Mixin: handle device placement and an optional distribution strategy."""

    def setup_device(self, device: Optional[str] = None) -> str:
        if device is None:
            device = 'GPU' if tf.config.list_physical_devices('GPU') else 'CPU'
        self.device = device.upper()
        self.strategy = None

        # NOTE: Needs more testing. Earlier training issues may have been caused
        # by other bugs found since this was last tested.
        # Reminder: test model saving/loading alongside mixed precision settings.
        if self.device == 'GPU':
            # # Enable mixed precision for better performance
            # policy = tf.keras.mixed_precision.Policy('mixed_float16')
            # tf.keras.mixed_precision.set_global_policy(policy)

            # Set up multi-GPU if more than one device is available.
            gpus = tf.config.list_physical_devices('GPU')
            if len(gpus) > 1:
                self.strategy = tf.distribute.MirroredStrategy()

        return self.device

    def run_on_device(self, func):
        """Decorator to ensure ops run on the configured device."""
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            with tf.device(f'/{self.device}:0'):
                return func(*args, **kwargs)
        return wrapper


class Summarizer(DeviceAwareModel):
    """
    T5-based summarizer with chunking and device management.
    Uses chunking and progressive summarization for long conversations.
    """

    def __init__(
        self,
        tokenizer: AutoTokenizer,
        model_name: str = "t5-small",
        max_summary_length: int = 128,
        device: Optional[str] = None,
        max_summary_rounds: int = 2,
    ):
        self.tokenizer = tokenizer
        self.setup_device(device)

        # Generation settings must be assigned before _setup_model, which
        # bakes them into the compiled generation function.
        self.max_summary_length = max_summary_length
        self.max_summary_rounds = max_summary_rounds

        # Build the model under the strategy scope if distribution is enabled.
        if self.strategy:
            with self.strategy.scope():
                self._setup_model(model_name)
        else:
            self._setup_model(model_name)

    def _setup_model(self, model_name: str):
        self.model = TFAutoModelForSeq2SeqLM.from_pretrained(model_name)

        # Compile a single generation graph for inference. The generation
        # kwargs are closed over at trace time: a tf.function built with an
        # input_signature cannot accept extra Python kwargs at call time,
        # so they cannot be passed per call.
        self._generate_fn = tf.function(
            lambda input_ids, attention_mask: self.model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_length=self.max_summary_length,
                num_beams=4,
                length_penalty=2.0,
                early_stopping=True,
                no_repeat_ngram_size=3,
            ),
            input_signature=[
                tf.TensorSpec(shape=[None, None], dtype=tf.int32),
                tf.TensorSpec(shape=[None, None], dtype=tf.int32),
            ],
        )

    def _generate_summary(self, inputs):
        # `inputs` is the tokenizer's BatchEncoding; unpack the tensors the
        # compiled generation function expects.
        return self._generate_fn(inputs['input_ids'], inputs['attention_mask'])

    def chunk_text(self, text: str, chunk_size: int = 512, overlap: int = 256) -> List[str]:
        """Split text into overlapping chunks for context preservation."""
        if overlap >= chunk_size:
            raise ValueError("overlap must be smaller than chunk_size")
        tokens = self.tokenizer.encode(text)
        chunks = []
        for i in range(0, len(tokens), chunk_size - overlap):
            chunk = tokens[i:i + chunk_size]
            chunks.append(self.tokenizer.decode(chunk, skip_special_tokens=True))
        return chunks

    def summarize_text(
        self,
        text: str,
        progressive: bool = True,
        round_idx: int = 0,
    ) -> str:
        """
        Summarize text, chunking long inputs and progressively re-summarizing
        for up to max_summary_rounds rounds.
        """
        @self.run_on_device
        def _summarize_chunk(chunk: str) -> str:
            input_text = "summarize: " + chunk
            inputs = self.tokenizer(
                input_text,
                return_tensors="tf",
                padding=True,
                truncation=True,
            )
            summary_ids = self._generate_summary(inputs)
            return self.tokenizer.decode(summary_ids[0], skip_special_tokens=True)

        # If max_summary_rounds is reached, do a single final pass (the input
        # is truncated to the model's maximum length if still too long).
        if round_idx >= self.max_summary_rounds:
            return _summarize_chunk(text)

        # Chunk long inputs and summarize each chunk. The 512-word threshold
        # is a rough proxy for the model's 512-token limit.
        if len(text.split()) > 512 and progressive:
            chunks = self.chunk_text(text)
            chunk_summaries = [_summarize_chunk(chunk) for chunk in chunks]

            # Combine the chunk-level summaries.
            combined_summary = " ".join(chunk_summaries)

            # If the combined summary is still too long, re-summarize it.
            if len(combined_summary.split()) > 512:
                return self.summarize_text(
                    combined_summary,
                    progressive=True,
                    round_idx=round_idx + 1,
                )
            return combined_summary
        else:
            # Short input: summarize once and return.
            return _summarize_chunk(text)