import functools
import logging
from dataclasses import dataclass
from typing import List, Optional

import tensorflow as tf
from transformers import TFAutoModelForSeq2SeqLM, AutoTokenizer

logger = logging.getLogger(__name__)


@dataclass
class ChatConfig:
    max_sequence_length: int = 512
    default_top_k: int = 10
    chunk_size: int = 512
    chunk_overlap: int = 256
    min_confidence_score: float = 0.7

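# A hypothetical wiring of the config into the chunking API defined below
# (values are illustrative; nothing in this module consumes ChatConfig itself):
#
#     cfg = ChatConfig(chunk_size=256, chunk_overlap=128)
#     chunks = summarizer.chunk_text(text, cfg.chunk_size, cfg.chunk_overlap)
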

class DeviceAwareModel:
    """
    Mixin: handle device placement and multi-GPU distribution.
    """

    def setup_device(self, device: Optional[str] = None):
        # Prefer a GPU whenever TensorFlow can see one.
        if device is None:
            device = 'GPU' if tf.config.list_physical_devices('GPU') else 'CPU'

        self.device = device.upper()
        self.strategy = None

        if self.device == 'GPU':
            # With more than one GPU, mirror variables across devices so
            # work can be distributed.
            gpus = tf.config.list_physical_devices('GPU')
            if len(gpus) > 1:
                self.strategy = tf.distribute.MirroredStrategy()

        return self.device

    def run_on_device(self, func):
        """Decorator that pins the wrapped function's ops to the selected device."""
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            with tf.device(f'/{self.device}:0'):
                return func(*args, **kwargs)
        return wrapper

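# A minimal sketch of how the mixin is meant to be consumed (illustrative
# only; `TinyModel` is a hypothetical class, not part of this module):
#
#     class TinyModel(DeviceAwareModel):
#         def __init__(self):
#             self.setup_device()  # 'GPU' when one is visible, else 'CPU'
#
#     m = TinyModel()
#     double = m.run_on_device(lambda x: x * 2)  # pinned to /GPU:0 or /CPU:0
#     print(double(tf.constant([1.0, 2.0])))
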
class Summarizer(DeviceAwareModel):
    """
    T5-based summarizer with device management, chunking, and progressive
    summarization for long conversations.
    """

    def __init__(
        self,
        tokenizer: AutoTokenizer,
        model_name: str = "t5-small",
        max_summary_length: int = 128,
        device: Optional[str] = None,
        max_summary_rounds: int = 2
    ):
        self.tokenizer = tokenizer
        # Set generation limits before building the model: the traced
        # generate function below closes over max_summary_length.
        self.max_summary_length = max_summary_length
        self.max_summary_rounds = max_summary_rounds
        self.setup_device(device)

        # Build the model inside the strategy scope so its variables are
        # mirrored when more than one GPU is available.
        if self.strategy:
            with self.strategy.scope():
                self._setup_model(model_name)
        else:
            self._setup_model(model_name)

    def _setup_model(self, model_name):
        self.model = TFAutoModelForSeq2SeqLM.from_pretrained(model_name)

        # Trace generation once with a fixed input signature to avoid
        # retracing for every new batch shape. A tf.function built with an
        # input_signature accepts only the tensors named in that signature,
        # so the decoding hyperparameters must be bound here rather than
        # passed at call time.
        self._traced_generate = tf.function(
            lambda input_ids, attention_mask: self.model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_length=self.max_summary_length,
                num_beams=4,
                length_penalty=2.0,
                early_stopping=True,
                no_repeat_ngram_size=3
            ),
            input_signature=[
                tf.TensorSpec(shape=[None, None], dtype=tf.int32),
                tf.TensorSpec(shape=[None, None], dtype=tf.int32)
            ]
        )

    def _generate_summary(self, inputs):
        # The traced function already runs in graph mode, so no extra
        # tf.function wrapper is needed here.
        return self._traced_generate(inputs['input_ids'], inputs['attention_mask'])

    def chunk_text(self, text: str, chunk_size: int = 512, overlap: int = 256) -> List[str]:
        """Split text into overlapping chunks so context spans chunk boundaries."""
        tokens = self.tokenizer.encode(text)
        chunks = []

        # Step by (chunk_size - overlap) tokens so consecutive chunks share
        # `overlap` tokens of context.
        for i in range(0, len(tokens), chunk_size - overlap):
            chunk = tokens[i:i + chunk_size]
            chunks.append(self.tokenizer.decode(chunk, skip_special_tokens=True))

        return chunks
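
    # Worked example for chunk_text (illustrative numbers): with
    # chunk_size=512 and overlap=256 the stride is 256 tokens, so a
    # 1,000-token input yields chunks starting at offsets 0, 256, 512, 768.
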
    def summarize_text(
        self,
        text: str,
        progressive: bool = True,
        round_idx: int = 0
    ) -> str:
        """
        Summarize text, chunking long inputs and re-summarizing the stitched
        result for at most max_summary_rounds rounds.
        """
        @self.run_on_device
        def _summarize_chunk(chunk: str) -> str:
            input_text = "summarize: " + chunk
            inputs = self.tokenizer(
                input_text,
                return_tensors="tf",
                padding=True,
                truncation=True
            )
            summary_ids = self._generate_summary(inputs)
            return self.tokenizer.decode(summary_ids[0], skip_special_tokens=True)

        # Recursion guard: once the round budget is spent, produce a single
        # (possibly truncated) summary of whatever text remains.
        if round_idx >= self.max_summary_rounds:
            return _summarize_chunk(text)

        # Long input: summarize overlapping chunks and stitch the pieces.
        # The 512-word threshold is a cheap proxy for token length.
        if len(text.split()) > 512 and progressive:
            chunks = self.chunk_text(text)
            chunk_summaries = [_summarize_chunk(chunk) for chunk in chunks]
            combined_summary = " ".join(chunk_summaries)

            # If the stitched summary is still long, summarize it again.
            if len(combined_summary.split()) > 512:
                return self.summarize_text(
                    combined_summary,
                    progressive=True,
                    round_idx=round_idx + 1
                )

            return combined_summary
        else:
            return _summarize_chunk(text)
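
# A minimal usage sketch, assuming the public `t5-small` checkpoint can be
# downloaded; the sample text is an illustrative stand-in for a real
# conversation log.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    tokenizer = AutoTokenizer.from_pretrained("t5-small")
    summarizer = Summarizer(tokenizer=tokenizer, model_name="t5-small")
    sample = (
        "User: Can you explain how transformers handle long inputs? "
        "Assistant: They attend over a fixed window, so long conversations "
        "are usually chunked and summarized progressively."
    )
    print(summarizer.summarize_text(sample))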