token-attention-visualizer/core/model_handler.py
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from typing import Tuple, Optional, List, Dict, Any
import warnings
warnings.filterwarnings("ignore", category=UserWarning, module='transformers.generation')


class ModelHandler:
    """Loads a causal LM with eager attention and runs generation while capturing attention weights."""

    def __init__(self, model_name: Optional[str] = None, config=None):
        self.model = None
        self.tokenizer = None
        self.device = None
        self.model_name = model_name
        self.config = config

    def load_model(self, model_name: Optional[str] = None) -> Tuple[bool, str]:
        """Load model with optimized settings"""
        if model_name:
            self.model_name = model_name
        if not self.model_name:
            return False, "No model name provided"
        try:
            print(f"Loading model: {self.model_name}...")

            # Load tokenizer
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)

            # Determine device and dtype
            if self.config and hasattr(self.config, 'DEVICE'):
                self.device = self.config.DEVICE
                # If config specifies CPU, force it even if CUDA is available
                if self.device == "cpu":
                    print("Forcing CPU usage as specified in config")
                elif self.device == "cuda" and not torch.cuda.is_available():
                    print("CUDA requested but not available, falling back to CPU")
                    self.device = "cpu"
            else:
                # Fall back to auto-detection if no config provided
                self.device = "cuda" if torch.cuda.is_available() else "cpu"

            # Use bfloat16 for Ampere GPUs (compute capability >= 8.0), otherwise float32
            if self.device == "cuda" and torch.cuda.is_available():
                capability = torch.cuda.get_device_capability()
                if capability[0] >= 8:
                    dtype = torch.bfloat16
                else:
                    dtype = torch.float32
            else:
                dtype = torch.float32

            # Load model
            try:
                self.model = AutoModelForCausalLM.from_pretrained(
                    self.model_name,
                    torch_dtype=dtype,
                    attn_implementation="eager"  # Force eager attention for attention extraction
                ).to(self.device)
                print(f"Model loaded on {self.device} with dtype {dtype} (eager attention)")
            except Exception as e:
                print(f"Error loading model with specific dtype: {e}")
                print("Attempting to load without specific dtype...")
                try:
                    self.model = AutoModelForCausalLM.from_pretrained(
                        self.model_name,
                        attn_implementation="eager"
                    ).to(self.device)
                    print(f"Model loaded on {self.device} (default dtype, eager attention)")
                except Exception as e2:
                    print(f"Error with eager attention: {e2}")
                    print("Loading with default settings...")
                    self.model = AutoModelForCausalLM.from_pretrained(self.model_name).to(self.device)
                    print(f"Model loaded on {self.device} (default settings)")

            # Handle pad token
            if self.tokenizer.pad_token is None:
                if self.tokenizer.eos_token:
                    print("Setting pad_token to eos_token")
                    self.tokenizer.pad_token = self.tokenizer.eos_token
                    if hasattr(self.model.config, 'pad_token_id') and self.model.config.pad_token_id is None:
                        self.model.config.pad_token_id = self.tokenizer.eos_token_id
                else:
                    print("Warning: No eos_token found to set as pad_token.")

            return True, f"Model loaded successfully on {self.device}"
        except Exception as e:
            return False, f"Error loading model: {str(e)}"

    def generate_with_attention(
        self,
        prompt: str,
        max_tokens: int = 30,
        temperature: float = 0.7,
        top_p: float = 0.95
    ) -> Tuple[Optional[Dict[str, Any]], List[str], List[str], str]:
        """
        Generate text and capture attention weights.

        Returns: (attention_data, output_tokens, input_tokens, generated_text),
        where attention_data is a dict with keys 'attentions' (the raw attention
        tuples returned by generate), 'input_len_for_attention', and 'output_len',
        or None on failure.
        """
        if not self.model or not self.tokenizer:
            return None, [], [], "Model not loaded"

        # Encode input
        input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(self.device)
        input_len_raw = input_ids.shape[1]
        print(f"Generating with input length: {input_len_raw}, max_new_tokens: {max_tokens}")

        # Generate with attention
        with torch.no_grad():
            attention_mask = torch.ones_like(input_ids)
            gen_kwargs = {
                "attention_mask": attention_mask,
                "max_new_tokens": max_tokens,
                "output_attentions": True,
                "return_dict_in_generate": True,
                "temperature": temperature,
                "top_p": top_p,
                "do_sample": temperature > 0
            }
            if self.tokenizer.pad_token_id is not None:
                gen_kwargs["pad_token_id"] = self.tokenizer.pad_token_id
            try:
                output = self.model.generate(input_ids, **gen_kwargs)
            except Exception as e:
                print(f"Error during generation: {e}")
                return None, [], [], f"Error during generation: {str(e)}"

        # Extract generated tokens
        full_sequence = output.sequences[0]
        if full_sequence.shape[0] > input_len_raw:
            generated_ids = full_sequence[input_len_raw:]
        else:
            generated_ids = torch.tensor([], dtype=torch.long, device=self.device)

        # Convert to tokens
        output_tokens = self.tokenizer.convert_ids_to_tokens(generated_ids, skip_special_tokens=False)
        input_tokens_raw = self.tokenizer.convert_ids_to_tokens(input_ids[0], skip_special_tokens=False)

        # Handle BOS token removal from visualization
        input_tokens = input_tokens_raw
        input_len_for_attention = input_len_raw
        bos_token = self.tokenizer.bos_token or '<|begin_of_text|>'
        if input_tokens_raw and input_tokens_raw[0] == bos_token:
            input_tokens = input_tokens_raw[1:]
            input_len_for_attention = input_len_raw - 1

        # Handle EOS token removal
        eos_token = self.tokenizer.eos_token or '<|end_of_text|>'
        if output_tokens and output_tokens[-1] == eos_token:
            output_tokens = output_tokens[:-1]
            generated_ids = generated_ids[:-1]

        # Decode generated text
        generated_text = self.tokenizer.decode(generated_ids, skip_special_tokens=True)

        # Extract attention weights
        attentions = getattr(output, 'attentions', None)
        if attentions is None:
            print("Warning: 'attentions' not found in model output. Cannot visualize attention.")
            return None, output_tokens, input_tokens, generated_text

        # Return raw attention, tokens, and metadata
        return {
            'attentions': attentions,
            'input_len_for_attention': input_len_for_attention,
            'output_len': len(output_tokens)
        }, output_tokens, input_tokens, generated_text

    def get_model_info(self) -> Dict[str, Any]:
        """Get information about the loaded model"""
        if not self.model:
            return {"loaded": False}
        return {
            "loaded": True,
            "model_name": self.model_name,
            "device": str(self.device),
            "num_parameters": sum(p.numel() for p in self.model.parameters()),
            "dtype": str(next(self.model.parameters()).dtype),
            "vocab_size": self.tokenizer.vocab_size if self.tokenizer else 0
        }
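

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the original module).
# It shows how ModelHandler is intended to be driven: load a model, generate
# with attention capture, and inspect the returned structures. "gpt2" and the
# prompt/generation settings below are arbitrary placeholder choices, not
# values prescribed by this project.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    handler = ModelHandler(model_name="gpt2")  # placeholder model name
    ok, message = handler.load_model()
    print(message)

    if ok:
        attention_data, output_tokens, input_tokens, text = handler.generate_with_attention(
            "The quick brown fox", max_tokens=10, temperature=0.7, top_p=0.95
        )
        print("Input tokens:", input_tokens)
        print("Output tokens:", output_tokens)
        print("Generated text:", text)

        if attention_data is not None:
            attentions = attention_data["attentions"]
            # With output_attentions=True and return_dict_in_generate=True,
            # `attentions` is typically a tuple over generation steps; each step
            # is a tuple over layers of tensors shaped
            # (batch, heads, query_len, key_len).
            print("Generation steps:", len(attentions))
            print("Layers per step:", len(attentions[0]))
            print("First-step, first-layer attention shape:", tuple(attentions[0][0].shape))

        print("Model info:", handler.get_model_info())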