import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from typing import Tuple, Optional, List, Dict, Any
import warnings

# Silence benign UserWarnings from transformers.generation (e.g. sampling
# parameters passed while do_sample=False).
warnings.filterwarnings("ignore", category=UserWarning, module='transformers.generation')

class ModelHandler:
    def __init__(self, model_name: Optional[str] = None, config=None):
        self.model = None
        self.tokenizer = None
        self.device = None
        self.model_name = model_name
        self.config = config
        
    def load_model(self, model_name: Optional[str] = None) -> Tuple[bool, str]:
        """Load model with optimized settings"""
        if model_name:
            self.model_name = model_name
        
        if not self.model_name:
            return False, "No model name provided"
        
        try:
            print(f"Loading model: {self.model_name}...")
            
            # Load tokenizer
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
            
            # Determine device and dtype
            if self.config and hasattr(self.config, 'DEVICE'):
                self.device = self.config.DEVICE
                # If config specifies CPU, force it even if CUDA is available
                if self.device == "cpu":
                    print("Forcing CPU usage as specified in config")
                elif self.device == "cuda" and not torch.cuda.is_available():
                    print("CUDA requested but not available, falling back to CPU")
                    self.device = "cpu"
            else:
                # Fallback to auto-detection if no config provided
                self.device = "cuda" if torch.cuda.is_available() else "cpu"
            
            # Use bfloat16 on GPUs with compute capability >= 8.0 (Ampere or newer), otherwise float32
            if self.device == "cuda" and torch.cuda.is_available():
                capability = torch.cuda.get_device_capability()
                if capability[0] >= 8:
                    dtype = torch.bfloat16
                else:
                    dtype = torch.float32
            else:
                dtype = torch.float32
            
            # Load model
            try:
                self.model = AutoModelForCausalLM.from_pretrained(
                    self.model_name,
                    torch_dtype=dtype,
                    attn_implementation="eager"  # Force eager attention for attention extraction
                ).to(self.device)
                print(f"Model loaded on {self.device} with dtype {dtype} (eager attention)")
            except Exception as e:
                print(f"Error loading model with specific dtype: {e}")
                print("Attempting to load without specific dtype...")
                try:
                    self.model = AutoModelForCausalLM.from_pretrained(
                        self.model_name,
                        attn_implementation="eager"
                    ).to(self.device)
                    print(f"Model loaded on {self.device} (default dtype, eager attention)")
                except Exception as e2:
                    print(f"Error with eager attention: {e2}")
                    print("Loading with default settings...")
                    self.model = AutoModelForCausalLM.from_pretrained(self.model_name).to(self.device)
                    print(f"Model loaded on {self.device} (default settings)")
            
            # Handle pad token
            if self.tokenizer.pad_token is None:
                if self.tokenizer.eos_token:
                    print("Setting pad_token to eos_token")
                    self.tokenizer.pad_token = self.tokenizer.eos_token
                    if hasattr(self.model.config, 'pad_token_id') and self.model.config.pad_token_id is None:
                        self.model.config.pad_token_id = self.tokenizer.eos_token_id
                else:
                    print("Warning: No eos_token found to set as pad_token.")
            
            return True, f"Model loaded successfully on {self.device}"
            
        except Exception as e:
            return False, f"Error loading model: {str(e)}"
    
    def generate_with_attention(
        self,
        prompt: str,
        max_tokens: int = 30,
        temperature: float = 0.7,
        top_p: float = 0.95
    ) -> Tuple[Optional[Dict[str, Any]], List[str], List[str], str]:
        """
        Generate text and capture attention weights.
        Returns: (attention_data, output_tokens, input_tokens, generated_text),
        where attention_data is a dict with keys 'attentions',
        'input_len_for_attention' and 'output_len', or None on failure.
        """
        if not self.model or not self.tokenizer:
            return None, [], [], "Model not loaded"
        
        # Encode input
        input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(self.device)
        input_len_raw = input_ids.shape[1]
        
        print(f"Generating with input length: {input_len_raw}, max_new_tokens: {max_tokens}")
        
        # Generate with attention
        with torch.no_grad():
            attention_mask = torch.ones_like(input_ids)
            gen_kwargs = {
                "attention_mask": attention_mask,
                "max_new_tokens": max_tokens,
                "output_attentions": True,
                "return_dict_in_generate": True,
                "temperature": temperature,
                "top_p": top_p,
                "do_sample": temperature > 0
            }
            
            if self.tokenizer.pad_token_id is not None:
                gen_kwargs["pad_token_id"] = self.tokenizer.pad_token_id
            
            try:
                output = self.model.generate(input_ids, **gen_kwargs)
            except Exception as e:
                print(f"Error during generation: {e}")
                return None, [], [], f"Error during generation: {str(e)}"
        
        # Extract generated tokens
        full_sequence = output.sequences[0]
        if full_sequence.shape[0] > input_len_raw:
            generated_ids = full_sequence[input_len_raw:]
        else:
            generated_ids = torch.tensor([], dtype=torch.long, device=self.device)
        
        # Convert to tokens
        output_tokens = self.tokenizer.convert_ids_to_tokens(generated_ids, skip_special_tokens=False)
        input_tokens_raw = self.tokenizer.convert_ids_to_tokens(input_ids[0], skip_special_tokens=False)
        
        # Handle BOS token removal from visualization
        input_tokens = input_tokens_raw
        input_len_for_attention = input_len_raw
        bos_token = self.tokenizer.bos_token or '<|begin_of_text|>'
        
        if input_tokens_raw and input_tokens_raw[0] == bos_token:
            input_tokens = input_tokens_raw[1:]
            input_len_for_attention = input_len_raw - 1
        
        # Handle EOS token removal
        eos_token = self.tokenizer.eos_token or '<|end_of_text|>'
        if output_tokens and output_tokens[-1] == eos_token:
            output_tokens = output_tokens[:-1]
            generated_ids = generated_ids[:-1]
        
        # Decode generated text
        generated_text = self.tokenizer.decode(generated_ids, skip_special_tokens=True)
        
        # Extract attention weights
        attentions = getattr(output, 'attentions', None)
        if attentions is None:
            print("Warning: 'attentions' not found in model output. Cannot visualize attention.")
            return None, output_tokens, input_tokens, generated_text
        
        # Return raw attention plus metadata. `attentions` is the structure returned by
        # generate(): one entry per generation step, each a tuple over layers of
        # tensors shaped (batch, num_heads, query_len, key_len).
        return {
            'attentions': attentions,
            'input_len_for_attention': input_len_for_attention,
            'output_len': len(output_tokens)
        }, output_tokens, input_tokens, generated_text
    
    def get_model_info(self) -> Dict[str, Any]:
        """Get information about the loaded model"""
        if not self.model:
            return {"loaded": False}
        
        return {
            "loaded": True,
            "model_name": self.model_name,
            "device": str(self.device),
            "num_parameters": sum(p.numel() for p in self.model.parameters()),
            "dtype": str(next(self.model.parameters()).dtype),
            "vocab_size": self.tokenizer.vocab_size if self.tokenizer else 0
        }
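

# --- Example usage (illustrative sketch, not part of the handler itself) ---
# A minimal demonstration of the ModelHandler API above: load a model, generate
# a short continuation with attention capture, and inspect the returned data.
# The model name "gpt2" and the prompt are assumptions chosen only for size and
# availability; any causal LM on the Hub should work. The script filename in
# the comment below is likewise assumed.
if __name__ == "__main__":
    handler = ModelHandler(model_name="gpt2")  # assumed model; substitute your own
    ok, msg = handler.load_model()
    print(msg)

    if ok:
        attn_data, out_tokens, in_tokens, text = handler.generate_with_attention(
            "The quick brown fox", max_tokens=10, temperature=0.7
        )
        print("Generated:", text)
        print("Input tokens:", in_tokens)
        print("Output tokens:", out_tokens)

        if attn_data is not None:
            # One entry per generation step; each step is a tuple over layers of
            # (batch, heads, query_len, key_len) attention tensors.
            steps = attn_data["attentions"]
            print("Generation steps:", len(steps), "| layers per step:", len(steps[0]))
            print("First-step layer-0 attention shape:", tuple(steps[0][0].shape))

        print(handler.get_model_info())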