import tensorflow as tf
from typing import List, Dict
from transformers import TFAutoModelForSeq2SeqLM, AutoTokenizer
import logging
from dataclasses import dataclass
from functools import wraps
logger = logging.getLogger(__name__)
@dataclass
class ChatConfig:
    max_sequence_length: int = 512
    default_top_k: int = 10
    chunk_size: int = 512
    chunk_overlap: int = 256
    min_confidence_score: float = 0.7
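    # Note: chunk_size / chunk_overlap mirror the default arguments of
    # Summarizer.chunk_text below.
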
class DeviceAwareModel:
"""
Mixin: Handle device placement and mixed precision training.
"""
def setup_device(self, device: str = None):
if device is None:
device = 'GPU' if tf.config.list_physical_devices('GPU') else 'CPU'
self.device = device.upper()
self.strategy = None
# NOTE: Needs more testing. Training issues may have been from other bugs I found since this was tested.
# Reminder: Test model saving/loading alongside mixed precision settings
if self.device == 'GPU':
# # Enable mixed precision for better performance
# policy = tf.keras.mixed_precision.Policy('mixed_float16')
# tf.keras.mixed_precision.set_global_policy(policy)
# Setup multi-GPU if available
gpus = tf.config.list_physical_devices('GPU')
if len(gpus) > 1:
self.strategy = tf.distribute.MirroredStrategy()
return self.device
def run_on_device(self, func):
"""Decorator to ensure ops run on the correct device."""
def wrapper(*args, **kwargs):
with tf.device(f'/{self.device}:0'):
return func(*args, **kwargs)
return wrapper
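
# Usage sketch for the mixin (illustrative; `Encoder` is a hypothetical
# subclass). A subclass calls setup_device() first, then wraps hot methods at
# instance level, since run_on_device reads self.device, which only exists
# after setup_device() has run:
#
#     class Encoder(DeviceAwareModel):
#         def __init__(self):
#             self.setup_device()  # picks 'GPU' when one is visible
#             self.encode = self.run_on_device(self.encode)
#
#         def encode(self, ids):
#             return tf.one_hot(ids, depth=32)
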
class Summarizer(DeviceAwareModel):
"""
T5-based summarizer with chunking and device management.
Chunking and progressive summarization for long conversations.
"""
    def __init__(
        self,
        tokenizer: AutoTokenizer,
        model_name="t5-small",
        max_summary_length=128,
        device=None,
        max_summary_rounds=2
    ):
        self.tokenizer = tokenizer
        self.setup_device(device)
        # Build the model inside the strategy scope if using distribution
        if self.strategy:
            with self.strategy.scope():
                self._setup_model(model_name)
        else:
            self._setup_model(model_name)
        self.max_summary_length = max_summary_length
        self.max_summary_rounds = max_summary_rounds
    def _setup_model(self, model_name):
        self.model = TFAutoModelForSeq2SeqLM.from_pretrained(model_name)
        # Optimize for inference: compile generate into a tf.function.
        # No input_signature is pinned here, because _generate_summary also
        # passes generation kwargs (max_length, num_beams, ...) that a fixed
        # signature of only input_ids/attention_mask would reject.
        self.model.generate = tf.function(self.model.generate)
    def _generate_summary(self, inputs):
        # `generate` is already tf.function-wrapped in _setup_model, so no
        # extra @tf.function here; tracing a BatchEncoding argument is also
        # unreliable, since it is not a plain dict of tensors.
        return self.model.generate(
            inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_length=self.max_summary_length,
            num_beams=4,
            length_penalty=2.0,
            early_stopping=True,
            no_repeat_ngram_size=3
        )
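
    # Generation settings above keep the module's original choices: num_beams=4
    # tracks four beam hypotheses, length_penalty=2.0 favors longer outputs
    # under beam search, and no_repeat_ngram_size=3 blocks repeated trigrams.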
    def chunk_text(self, text: str, chunk_size: int = 512, overlap: int = 256) -> List[str]:
        """Split text into overlapping chunks for context preservation."""
        tokens = self.tokenizer.encode(text)
        chunks = []
        for i in range(0, len(tokens), chunk_size - overlap):
            chunk = tokens[i:i + chunk_size]
            chunks.append(self.tokenizer.decode(chunk, skip_special_tokens=True))
        return chunks
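
    # Worked example (hypothetical numbers): with chunk_size=512 and overlap=256
    # the window start advances by chunk_size - overlap = 256 tokens, so a
    # 1,000-token text yields chunks over tokens [0:512], [256:768], [512:1000],
    # and [768:1000]; every chunk boundary appears in two chunks.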
    def summarize_text(
        self,
        text: str,
        progressive: bool = True,
        round_idx: int = 0
    ) -> str:
        """
        Progressive summarization with a limited number of resummarization rounds.
        """
        @self.run_on_device
        def _summarize_chunk(chunk: str) -> str:
            input_text = "summarize: " + chunk
            inputs = self.tokenizer(
                input_text,
                return_tensors="tf",
                padding=True,
                truncation=True
            )
            summary_ids = self._generate_summary(inputs)
            return self.tokenizer.decode(summary_ids[0], skip_special_tokens=True)

        # Do a single final pass once max_summary_rounds is reached
        if round_idx >= self.max_summary_rounds:
            return _summarize_chunk(text)

        # Rough heuristic: more than ~512 words means the input needs chunking
        if len(text.split()) > 512 and progressive:
            chunks = self.chunk_text(text)
            chunk_summaries = [_summarize_chunk(chunk) for chunk in chunks]
            # Combine chunk-level summaries; recurse if the result is still long
            combined_summary = " ".join(chunk_summaries)
            if len(combined_summary.split()) > 512:
                return self.summarize_text(
                    combined_summary,
                    progressive=True,
                    round_idx=round_idx + 1
                )
            return combined_summary
        else:
            # Short input: summarize once and return
            return _summarize_chunk(text)
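
# Minimal usage sketch (illustrative; assumes network access to download the
# t5-small weights and a tokenizer that matches the model):
if __name__ == "__main__":
    tok = AutoTokenizer.from_pretrained("t5-small")
    summarizer = Summarizer(tok, model_name="t5-small")
    transcript = "User asked about refunds. Agent explained the policy. " * 100
    print(summarizer.summarize_text(transcript))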