# --- START OF FILE architecture.py ---

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import Phi3Config, Phi3ForCausalLM
from typing import Optional, Dict, List

# --- BUILDING BLOCK 1: Hierarchical VectorMemoryHead ---
# This version adds a hierarchical memory system (an L1 working memory plus an L2
# long-term cache) to handle much longer contexts, and a gated update mechanism for stability.
class VectorMemoryHead(nn.Module):
    def __init__(self, hidden_dim: int, num_memory_slots: int, num_heads: int, ff_dim: int,
                 num_long_term_memory_slots: int = 0, # <-- NEW: Size of the L2 memory cache
                 device=None, dtype=None):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_memory_slots = num_memory_slots # L1 cache size
        self.num_long_term_memory_slots = num_long_term_memory_slots # L2 cache size

        # --- L1 Working Memory Components (same as before) ---
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_dim, nhead=num_heads, dim_feedforward=ff_dim, dropout=0.1, batch_first=True,
            device=device, dtype=dtype
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=1)
        self.memory_queries = nn.Parameter(torch.randn(1, num_memory_slots, hidden_dim, device=device, dtype=dtype))
        self.memory_attention = nn.MultiheadAttention(
            embed_dim=hidden_dim, num_heads=num_heads, dropout=0.1, batch_first=True,
            device=device, dtype=dtype
        )
        self.memory_layernorm = nn.LayerNorm(hidden_dim, device=device, dtype=dtype)
        self.decoder_attention = nn.MultiheadAttention(
            embed_dim=hidden_dim, num_heads=num_heads, dropout=0.1, batch_first=True,
            device=device, dtype=dtype
        )
        self.decoder_layernorm = nn.LayerNorm(hidden_dim, device=device, dtype=dtype)
        self.decoder_ffn = nn.Sequential(
            nn.Linear(hidden_dim, ff_dim, device=device, dtype=dtype),
            nn.ReLU(),
            nn.Linear(ff_dim, hidden_dim, device=device, dtype=dtype)
        )

        # --- NEW: L2 Long-Term Memory Components ---
        self.use_long_term_memory = self.num_long_term_memory_slots > 0
        if self.use_long_term_memory:
            self.long_term_memory = nn.Parameter(
                torch.zeros(1, self.num_long_term_memory_slots, hidden_dim, device=device, dtype=dtype)
            )
            # Gate for updating long-term memory (similar to GRU/LSTM gates)
            self.memory_update_gate = nn.Sequential(
                nn.Linear(2 * hidden_dim, hidden_dim, device=device, dtype=dtype),
                nn.Sigmoid()
            )
            # Attention to read from L2 memory
            self.ltm_retrieval_attention = nn.MultiheadAttention(
                embed_dim=hidden_dim, num_heads=num_heads, dropout=0.1, batch_first=True,
                device=device, dtype=dtype
            )

    def forward(self, memory_input_sequence: torch.Tensor):
        batch_size = memory_input_sequence.shape[0]

        # 1. Encode input sequence
        encoded_vectors = self.encoder(memory_input_sequence)

        # 2. Compress into L1 working memory
        queries = self.memory_queries.expand(batch_size, -1, -1)
        compressed_memory, _ = self.memory_attention(query=queries, key=encoded_vectors, value=encoded_vectors)
        compressed_memory = self.memory_layernorm(compressed_memory + queries) # (B, num_memory_slots, D)

        final_memory_context = compressed_memory

        # --- NEW: Interact with L2 Long-Term Memory ---
        # The L2 path only runs when the stored long-term memory has already been
        # expanded to the current batch size; otherwise it is silently skipped.
        if self.use_long_term_memory and self.long_term_memory.shape[0] == batch_size:
            # 3a. Retrieve relevant context from L2 memory using L1 as query
            retrieved_ltm, _ = self.ltm_retrieval_attention(
                query=compressed_memory,
                key=self.long_term_memory,
                value=self.long_term_memory
            )

            # 3b. Gated update of the Long-Term Memory
            # Average the L1 memory to get a summary vector for the update
            l1_summary = compressed_memory.mean(dim=1)
            ltm_summary = self.long_term_memory.mean(dim=1)
            gate_input = torch.cat([l1_summary, ltm_summary], dim=-1)
            update_gate = self.memory_update_gate(gate_input).unsqueeze(1) # (B, 1, D)

            # Update LTM by blending new info from L1
            self.long_term_memory.data = (update_gate * l1_summary.unsqueeze(1)) + ((1 - update_gate) * self.long_term_memory.data)

            # Combine L1 and retrieved L2 context for the final output
            final_memory_context = final_memory_context + retrieved_ltm

        # 4. Decode from the final memory context to reconstruct original sequence
        reconstructed, _ = self.decoder_attention(query=encoded_vectors, key=final_memory_context, value=final_memory_context)
        reconstructed_vectors = self.decoder_layernorm(reconstructed + encoded_vectors)
        reconstructed_vectors = self.decoder_ffn(reconstructed_vectors)

        return compressed_memory, reconstructed_vectors
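
# Hypothetical usage sketch (not part of the original module): a minimal shape check of the
# memory head on random data. All dimensions below are illustrative. With batch size 1 the
# L2 long-term-memory path also runs, because the stored memory batch matches the input batch.
def _demo_vector_memory_head():
    head = VectorMemoryHead(hidden_dim=64, num_memory_slots=8, num_heads=4, ff_dim=128,
                            num_long_term_memory_slots=32)
    seq = torch.randn(1, 16, 64)  # (batch, sequence_length, hidden_dim)
    with torch.no_grad():  # mirrors the inference-style refinement in GCVectorMemoryLayer
        compressed, reconstructed = head(seq)
    assert compressed.shape == (1, 8, 64)       # L1 working-memory slots
    assert reconstructed.shape == (1, 16, 64)   # reconstruction of the encoded input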

# --- BUILDING BLOCK 2: The Custom Layer (With Per-Dataset Parameters and Refinement) ---
class GCVectorMemoryLayer(nn.Module):
    def __init__(self, original_layer: nn.Linear, global_input_dim: int,
                 memory_dim: int, num_memory_slots: int, memory_num_heads: int,
                 global_state_storage: Dict, dataset_keys: List[str]):
        super().__init__()
        self.input_dim = original_layer.in_features
        self.output_dim = original_layer.out_features
        self.memory_dim = memory_dim
        self.global_state_storage = global_state_storage
        self.dataset_keys = dataset_keys
        self.linear = original_layer # Shared linear layer

        device, dtype = self.linear.weight.device, self.linear.weight.dtype

        # --- NEW: Per-dataset specialized parameters ---
        self.local_state_projs = nn.ModuleDict()
        self.global_state_projs = nn.ModuleDict()
        self.memory_heads = nn.ModuleDict()
        self.correction_heads = nn.ModuleDict()

        for key in self.dataset_keys:
            self.local_state_projs[key] = nn.Linear(self.input_dim, memory_dim, device=device, dtype=dtype)
            self.global_state_projs[key] = nn.Linear(global_input_dim, memory_dim, device=device, dtype=dtype)
            self.memory_heads[key] = VectorMemoryHead(
                hidden_dim=memory_dim, num_memory_slots=num_memory_slots,
                num_heads=memory_num_heads, ff_dim=memory_dim * 2,
                num_long_term_memory_slots=32, # Enable L2 Cache
                device=device, dtype=dtype
            )
            self.correction_heads[key] = nn.Linear(memory_dim, 2 * self.output_dim, device=device, dtype=dtype)

        self.refinement_passes: int = 2 # Default to 2 passes for deeper refinement

        self.last_corrected_activation: Optional[torch.Tensor] = None
        self.last_additive_correction: Optional[torch.Tensor] = None
        self.last_memory_input: Optional[torch.Tensor] = None
        self.last_reconstructed_from_memory: Optional[torch.Tensor] = None

    def forward(self, x: torch.Tensor):
        base_output = self.linear(x)

        # Determine which set of specialized parameters to use
        dataset_key = self.global_state_storage.get('dataset_key')
        if not dataset_key or 'embeds' not in self.global_state_storage or self.refinement_passes < 1:
            return base_output

        # Select the correct modules for the current context
        local_state_proj = self.local_state_projs[dataset_key]
        global_state_proj = self.global_state_projs[dataset_key]
        memory_head = self.memory_heads[dataset_key]
        correction_head = self.correction_heads[dataset_key]

        global_embeds = self.global_state_storage['embeds']
        # Align the cached global embeddings with the current sequence length.
        if global_embeds.shape[1] != x.shape[1]:
            global_embeds = global_embeds[:, -x.shape[1]:, :]
        B, S, _ = x.shape

        # Ensure the LTM for this specific memory head is expanded to the current batch size
        if memory_head.use_long_term_memory and memory_head.long_term_memory.shape[0] != B:
            memory_head.long_term_memory.data = memory_head.long_term_memory.data.expand(B, -1, -1)

        with torch.no_grad(): # Use no_grad for the refinement loop as it's an inference-like process
            proj_local = local_state_proj(x.detach())
            proj_global = global_state_proj(global_embeds.detach())
            memory_input = torch.stack([proj_global, proj_local], dim=2)
            memory_input_flat = memory_input.view(B * S, 2, self.memory_dim)
            compressed_mem_flat, _ = memory_head(memory_input_flat)
            aggregated_thought = compressed_mem_flat.mean(dim=1).view(B, S, self.memory_dim)

            # Iteratively refine the output using state-feedback
            corrected_activation = base_output
            current_thought = aggregated_thought
            for _ in range(self.refinement_passes):
                raw_correction = correction_head(current_thought)
                gate, value = torch.chunk(raw_correction, 2, dim=-1)
                corrected_activation = corrected_activation * torch.sigmoid(gate) + value
                
                current_thought_flat = current_thought.view(B * S, self.memory_dim)
                refined_thought, _ = memory_head.decoder_attention(
                    query=current_thought_flat.unsqueeze(1), key=compressed_mem_flat, value=compressed_mem_flat
                )
                refined_thought = memory_head.decoder_layernorm(refined_thought.squeeze(1) + current_thought_flat)
                current_thought = refined_thought.view(B, S, self.memory_dim)

        if self.training:
            with torch.enable_grad():
                proj_local_grad = local_state_proj(x)
                proj_global_grad = global_state_proj(global_embeds)
                memory_input_grad = torch.stack([proj_global_grad, proj_local_grad], dim=2)
                memory_input_flat_grad = memory_input_grad.view(B * S, 2, self.memory_dim)
                compressed_mem_flat_grad, recon_flat_grad = memory_head(memory_input_flat_grad)
                aggregated_thought_grad = compressed_mem_flat_grad.mean(dim=1).view(B, S, self.memory_dim)

                raw_correction_grad = correction_head(aggregated_thought_grad)
                gate_grad, value_grad = torch.chunk(raw_correction_grad, 2, dim=-1)
                final_activation = base_output * torch.sigmoid(gate_grad.to(x.dtype)) + value_grad.to(x.dtype)

                self.last_corrected_activation = final_activation
                self.last_additive_correction = value_grad
                self.last_memory_input = memory_input_flat_grad
                self.last_reconstructed_from_memory = recon_flat_grad
                return final_activation
        else:
            return corrected_activation.to(x.dtype)
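
# Hypothetical wiring sketch (not part of the original module): the layer wraps an existing
# nn.Linear and reads the shared global_state_storage dict. In the real model the wrapper
# below fills 'embeds' via an embedding hook and the training loop sets 'dataset_key';
# here both are filled by hand, with illustrative shapes and the placeholder key "demo".
def _demo_gc_vector_memory_layer():
    storage = {}
    base = nn.Linear(128, 256)
    layer = GCVectorMemoryLayer(
        original_layer=base, global_input_dim=96,
        memory_dim=64, num_memory_slots=8, memory_num_heads=4,
        global_state_storage=storage, dataset_keys=["demo"],
    )
    x = torch.randn(2, 10, 128)                  # (batch, sequence, in_features)
    storage['embeds'] = torch.randn(2, 10, 96)   # normally set by the embed_tokens hook
    storage['dataset_key'] = "demo"              # normally set by the training/eval loop
    layer.eval()
    out = layer(x)                               # refined output from the eval path
    assert out.shape == (2, 10, 256)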

# --- BUILDING BLOCK 3: The Full Custom Model Wrapper (for saving/loading) ---
class Phi3WithVectorMemoryForCausalLM(Phi3ForCausalLM):
    def __init__(self, config):
        super().__init__(config)
        self.global_state_storage = {}
        # Target a central layer in the network for maximum impact
        self.target_layer_path = "model.layers.15.mlp.gate_up_proj"

        self.model.embed_tokens.register_forward_hook(
            lambda module, input, output: self.global_state_storage.update({'embeds': output.detach()})
        )
        
        # This logic is primarily for loading a pre-trained model.
        # The training script handles the initial creation.
        if hasattr(config, "dataset_keys") and config.dataset_keys:
            try:
                print(f"Re-initializing GCVectorMemoryLayer with dataset keys: {config.dataset_keys}")
                original_layer = self.get_submodule(self.target_layer_path)
                custom_layer = GCVectorMemoryLayer(
                    original_layer=original_layer, global_input_dim=config.hidden_size,
                    memory_dim=64,
                    num_memory_slots=8,
                    memory_num_heads=4,
                    global_state_storage=self.global_state_storage,
                    dataset_keys=config.dataset_keys # Use keys from config
                )
                parent_path = ".".join(self.target_layer_path.split('.')[:-1])
                child_name = self.target_layer_path.split('.')[-1]
                setattr(self.get_submodule(parent_path), child_name, custom_layer)
                print(f"Successfully reloaded and replaced '{self.target_layer_path}' with specialized GCVectorMemoryLayer.")
            except AttributeError:
                print(f"Could not find target layer '{self.target_layer_path}' during reload. Model remains unmodified.")
        else:
            print("No 'dataset_keys' found in config. The custom layer will not be initialized.")