# --- START OF FILE architecture.py ---
"""Vector-memory augmentation for Phi-3 causal language models.

Three building blocks:

* ``VectorMemoryHead`` compresses a short vector sequence into learned L1
  memory slots, optionally reads from and writes to an L2 long-term memory
  tensor, and reconstructs the input sequence from the memory.
* ``GCVectorMemoryLayer`` wraps an existing ``nn.Linear`` and gates/shifts its
  output using a per-dataset ``VectorMemoryHead``.
* ``Phi3WithVectorMemoryForCausalLM`` is a ``Phi3ForCausalLM`` that re-installs
  ``GCVectorMemoryLayer`` when its config lists ``dataset_keys``.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import Phi3Config, Phi3ForCausalLM
from typing import Optional, Dict, List


# --- BUILDING BLOCK 1: Hierarchical VectorMemoryHead ---
# Hierarchical memory: an L1 working memory built from learned query slots plus
# an optional L2 long-term memory with a gated update.
class VectorMemoryHead(nn.Module):
    """Compress a vector sequence into memory slots and reconstruct it.

    Args:
        hidden_dim: Feature size of every input vector and memory slot.
        num_memory_slots: Number of learned L1 (working-memory) query slots.
        num_heads: Attention heads used by every attention module.
        ff_dim: Hidden size of the encoder and decoder feed-forward networks.
        num_long_term_memory_slots: Number of L2 slots; ``0`` disables L2.
        device: Forwarded to every sub-module / parameter.
        dtype: Forwarded to every sub-module / parameter.
    """

    def __init__(self, hidden_dim: int, num_memory_slots: int, num_heads: int, ff_dim: int,
                 num_long_term_memory_slots: int = 0,  # Size of the L2 memory cache
                 device=None, dtype=None):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_memory_slots = num_memory_slots  # L1 cache size
        self.num_long_term_memory_slots = num_long_term_memory_slots  # L2 cache size

        # --- L1 working-memory components ---
        # NOTE: registration order below is kept stable because it fixes the
        # order of ``parameters()`` (relevant for saved optimizer states).
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_dim, nhead=num_heads, dim_feedforward=ff_dim, dropout=0.1,
            batch_first=True, device=device, dtype=dtype,
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=1)
        self.memory_queries = nn.Parameter(
            torch.randn(1, num_memory_slots, hidden_dim, device=device, dtype=dtype)
        )
        self.memory_attention = nn.MultiheadAttention(
            embed_dim=hidden_dim, num_heads=num_heads, dropout=0.1, batch_first=True,
            device=device, dtype=dtype,
        )
        self.memory_layernorm = nn.LayerNorm(hidden_dim, device=device, dtype=dtype)
        self.decoder_attention = nn.MultiheadAttention(
            embed_dim=hidden_dim, num_heads=num_heads, dropout=0.1, batch_first=True,
            device=device, dtype=dtype,
        )
        self.decoder_layernorm = nn.LayerNorm(hidden_dim, device=device, dtype=dtype)
        self.decoder_ffn = nn.Sequential(
            nn.Linear(hidden_dim, ff_dim, device=device, dtype=dtype),
            nn.ReLU(),
            nn.Linear(ff_dim, hidden_dim, device=device, dtype=dtype),
        )

        # --- L2 long-term-memory components ---
        self.use_long_term_memory = self.num_long_term_memory_slots > 0
        if self.use_long_term_memory:
            # Kept as a Parameter (so it lives in the state dict); its ``.data`` is
            # also overwritten by the gated update in ``forward``.
            self.long_term_memory = nn.Parameter(
                torch.zeros(1, self.num_long_term_memory_slots, hidden_dim, device=device, dtype=dtype)
            )
            # Gate for updating long-term memory (similar to GRU/LSTM gates).
            self.memory_update_gate = nn.Sequential(
                nn.Linear(2 * hidden_dim, hidden_dim, device=device, dtype=dtype),
                nn.Sigmoid(),
            )
            # Attention used to read from L2 memory.
            self.ltm_retrieval_attention = nn.MultiheadAttention(
                embed_dim=hidden_dim, num_heads=num_heads, dropout=0.1, batch_first=True,
                device=device, dtype=dtype,
            )

    def forward(self, memory_input_sequence: torch.Tensor):
        """Encode, compress into L1, optionally consult L2, and reconstruct.

        Args:
            memory_input_sequence: Batch-first tensor ``(batch, seq, hidden_dim)``.

        Returns:
            ``(compressed_memory, reconstructed_vectors)``:
            ``compressed_memory`` is the L1 memory ``(batch, num_memory_slots,
            hidden_dim)`` (without any L2 contribution); ``reconstructed_vectors``
            has the input's shape.

        Side effects:
            When L2 is enabled and ``long_term_memory.shape[0] == batch``, the
            ``long_term_memory`` parameter's data is replaced by a gated blend.
        """
        batch_size = memory_input_sequence.shape[0]

        # 1. Encode the input sequence.
        encoded_vectors = self.encoder(memory_input_sequence)

        # 2. Compress into the L1 working memory via learned queries.
        queries = self.memory_queries.expand(batch_size, -1, -1)
        compressed_memory, _ = self.memory_attention(
            query=queries, key=encoded_vectors, value=encoded_vectors
        )
        compressed_memory = self.memory_layernorm(compressed_memory + queries)  # (B, slots, D)
        final_memory_context = compressed_memory

        # 3. Interact with L2 only when its batch dimension matches this call.
        # NOTE(review): GCVectorMemoryLayer calls this head with batch B*S but sizes
        # L2 to batch B, so L2 is only consulted there when S == 1 -- confirm intent.
        if self.use_long_term_memory and self.long_term_memory.shape[0] == batch_size:
            # 3a. Read from L2 using the L1 slots as queries (pre-update L2).
            retrieved_ltm, _ = self.ltm_retrieval_attention(
                query=compressed_memory, key=self.long_term_memory, value=self.long_term_memory
            )

            # 3b. Gated write of an L1 summary into every L2 slot. The result is
            # assigned to ``.data`` (outside autograd), so no graph is built for it.
            # NOTE(review): consequently ``memory_update_gate`` never receives
            # gradients, and since the same blend is written to every slot, slots
            # that start identical (the zero init) stay identical -- confirm intent.
            # NOTE(review): swapping ``.data`` after 3a has used the parameter may
            # alter what autograd saved for backward; verify if L2 is hit in training.
            with torch.no_grad():
                l1_summary = compressed_memory.mean(dim=1)
                ltm_summary = self.long_term_memory.mean(dim=1)
                gate_input = torch.cat([l1_summary, ltm_summary], dim=-1)
                update_gate = self.memory_update_gate(gate_input).unsqueeze(1)  # (B, 1, D)
                self.long_term_memory.data = (
                    update_gate * l1_summary.unsqueeze(1)
                    + (1 - update_gate) * self.long_term_memory.data
                )

            # Combine L1 and retrieved L2 context for the decoder.
            final_memory_context = final_memory_context + retrieved_ltm

        # 4. Decode from the memory context to reconstruct the input sequence.
        reconstructed, _ = self.decoder_attention(
            query=encoded_vectors, key=final_memory_context, value=final_memory_context
        )
        reconstructed_vectors = self.decoder_layernorm(reconstructed + encoded_vectors)
        reconstructed_vectors = self.decoder_ffn(reconstructed_vectors)
        return compressed_memory, reconstructed_vectors


# --- BUILDING BLOCK 2: The Custom Layer (with per-dataset parameters and refinement) ---
class GCVectorMemoryLayer(nn.Module):
    """Wrap an ``nn.Linear`` and correct its output with per-dataset memory heads.

    The active parameter set is chosen by ``global_state_storage['dataset_key']``;
    ``global_state_storage['embeds']`` supplies the global (embedding) state. If
    either is missing, or ``refinement_passes < 1``, the wrapped linear output is
    returned unchanged.

    Args:
        original_layer: The linear layer being wrapped (shared across datasets).
        global_input_dim: Feature size of ``global_state_storage['embeds']``.
        memory_dim: Hidden size of every memory head.
        num_memory_slots: L1 slots per memory head.
        memory_num_heads: Attention heads per memory head.
        global_state_storage: Shared dict read on every forward call.
        dataset_keys: One specialised parameter set is created per key.
        num_long_term_memory_slots: L2 slots per memory head (``0`` disables L2).
    """

    def __init__(self, original_layer: nn.Linear, global_input_dim: int, memory_dim: int,
                 num_memory_slots: int, memory_num_heads: int, global_state_storage: Dict,
                 dataset_keys: List[str], num_long_term_memory_slots: int = 32):
        super().__init__()
        self.input_dim = original_layer.in_features
        self.output_dim = original_layer.out_features
        self.memory_dim = memory_dim
        self.global_state_storage = global_state_storage
        self.dataset_keys = dataset_keys
        self.linear = original_layer  # Shared linear layer
        device, dtype = self.linear.weight.device, self.linear.weight.dtype

        # Per-dataset specialised parameters.
        self.local_state_projs = nn.ModuleDict()
        self.global_state_projs = nn.ModuleDict()
        self.memory_heads = nn.ModuleDict()
        self.correction_heads = nn.ModuleDict()
        for key in self.dataset_keys:
            self.local_state_projs[key] = nn.Linear(self.input_dim, memory_dim, device=device, dtype=dtype)
            self.global_state_projs[key] = nn.Linear(global_input_dim, memory_dim, device=device, dtype=dtype)
            self.memory_heads[key] = VectorMemoryHead(
                hidden_dim=memory_dim,
                num_memory_slots=num_memory_slots,
                num_heads=memory_num_heads,
                ff_dim=memory_dim * 2,
                num_long_term_memory_slots=num_long_term_memory_slots,  # L2 cache
                device=device,
                dtype=dtype,
            )
            # Produces a multiplicative gate and an additive value, hence 2 * output_dim.
            self.correction_heads[key] = nn.Linear(memory_dim, 2 * self.output_dim, device=device, dtype=dtype)

        self.refinement_passes: int = 2  # Inference-time correction passes
        # Tensors from the most recent training-mode forward (for auxiliary losses).
        self.last_corrected_activation: Optional[torch.Tensor] = None
        self.last_additive_correction: Optional[torch.Tensor] = None
        self.last_memory_input: Optional[torch.Tensor] = None
        self.last_reconstructed_from_memory: Optional[torch.Tensor] = None

    def forward(self, x: torch.Tensor):
        """Apply the wrapped linear layer plus the memory-based correction.

        ``x`` is unpacked as ``(batch, seq, in_features)``. When the stored
        embeddings cover more positions than ``x``, only their last
        ``x.shape[1]`` positions are used.

        NOTE(review): training applies a single correction pass while inference
        applies ``refinement_passes`` passes, so the two modes compute different
        functions -- confirm intent.
        """
        base_output = self.linear(x)

        dataset_key = self.global_state_storage.get('dataset_key')
        if not dataset_key or 'embeds' not in self.global_state_storage or self.refinement_passes < 1:
            return base_output

        global_embeds = self.global_state_storage['embeds']
        if global_embeds.shape[1] != x.shape[1]:
            global_embeds = global_embeds[:, -x.shape[1]:, :]

        self._match_long_term_memory_batch(self.memory_heads[dataset_key], x.shape[0])

        if self.training:
            return self._training_forward(x, base_output, global_embeds, dataset_key)
        return self._refinement_forward(x, base_output, global_embeds, dataset_key)

    @staticmethod
    def _match_long_term_memory_batch(memory_head: VectorMemoryHead, batch_size: int) -> None:
        """Give ``memory_head``'s L2 memory a leading dimension of ``batch_size``."""
        if not memory_head.use_long_term_memory:
            return
        ltm = memory_head.long_term_memory
        if ltm.shape[0] != batch_size:
            # ``expand`` only works on a size-1 dimension; expanding a stale
            # (old_batch, N, D) tensor directly raised whenever the batch size
            # changed (e.g. a short final batch), so broadcast from row 0.
            ltm.data = ltm.data[:1].expand(batch_size, -1, -1)

    def _memory_input(self, x: torch.Tensor, global_embeds: torch.Tensor, dataset_key: str) -> torch.Tensor:
        """Project global/local states and pair them per position: ``(B*S, 2, memory_dim)``."""
        B, S, _ = x.shape
        proj_local = self.local_state_projs[dataset_key](x)
        proj_global = self.global_state_projs[dataset_key](global_embeds)
        memory_input = torch.stack([proj_global, proj_local], dim=2)  # (B, S, 2, memory_dim)
        return memory_input.view(B * S, 2, self.memory_dim)

    def _training_forward(self, x: torch.Tensor, base_output: torch.Tensor,
                          global_embeds: torch.Tensor, dataset_key: str) -> torch.Tensor:
        """Differentiable single-pass correction; records tensors for auxiliary losses."""
        B, S, _ = x.shape
        memory_head = self.memory_heads[dataset_key]
        correction_head = self.correction_heads[dataset_key]
        with torch.enable_grad():
            memory_input_flat = self._memory_input(x, global_embeds, dataset_key)
            compressed_mem_flat, recon_flat = memory_head(memory_input_flat)
            aggregated_thought = compressed_mem_flat.mean(dim=1).view(B, S, self.memory_dim)
            gate, value = torch.chunk(correction_head(aggregated_thought), 2, dim=-1)
            final_activation = base_output * torch.sigmoid(gate.to(x.dtype)) + value.to(x.dtype)

        self.last_corrected_activation = final_activation
        self.last_additive_correction = value
        self.last_memory_input = memory_input_flat
        self.last_reconstructed_from_memory = recon_flat
        return final_activation

    def _refinement_forward(self, x: torch.Tensor, base_output: torch.Tensor,
                            global_embeds: torch.Tensor, dataset_key: str) -> torch.Tensor:
        """Gradient-free multi-pass correction, refining the "thought" between passes."""
        B, S, _ = x.shape
        memory_head = self.memory_heads[dataset_key]
        correction_head = self.correction_heads[dataset_key]
        with torch.no_grad():  # Inference-like process; no gradients needed.
            memory_input_flat = self._memory_input(x, global_embeds, dataset_key)
            compressed_mem_flat, _ = memory_head(memory_input_flat)
            current_thought = compressed_mem_flat.mean(dim=1).view(B, S, self.memory_dim)

            corrected_activation = base_output
            for pass_idx in range(self.refinement_passes):
                if pass_idx > 0:
                    # Refine the thought by attending back to the compressed memory.
                    # (Skipped before the first pass, so no refinement is computed
                    # after the last pass, where it would go unused.)
                    thought_flat = current_thought.view(B * S, self.memory_dim)
                    refined, _ = memory_head.decoder_attention(
                        query=thought_flat.unsqueeze(1), key=compressed_mem_flat, value=compressed_mem_flat
                    )
                    refined = memory_head.decoder_layernorm(refined.squeeze(1) + thought_flat)
                    current_thought = refined.view(B, S, self.memory_dim)

                gate, value = torch.chunk(correction_head(current_thought), 2, dim=-1)
                corrected_activation = corrected_activation * torch.sigmoid(gate) + value

        return corrected_activation.to(x.dtype)


# --- BUILDING BLOCK 3: The Full Custom Model Wrapper (for saving/loading) ---
class Phi3WithVectorMemoryForCausalLM(Phi3ForCausalLM):
    """``Phi3ForCausalLM`` that swaps ``target_layer_path`` for a GCVectorMemoryLayer.

    The swap happens only when ``config.dataset_keys`` is non-empty (i.e. when
    loading a model trained with the custom layer; the training script performs
    the initial creation). Memory hyper-parameters are read from the config when
    present (``memory_dim``, ``num_memory_slots``, ``memory_num_heads``,
    ``num_long_term_memory_slots``) and otherwise default to 64 / 8 / 4 / 32.
    """

    def __init__(self, config):
        super().__init__(config)
        self.global_state_storage = {}
        # Target a central layer in the network.
        self.target_layer_path = "model.layers.15.mlp.gate_up_proj"
        self.model.embed_tokens.register_forward_hook(self._store_embeds)

        dataset_keys = getattr(config, "dataset_keys", None)
        if not dataset_keys:
            print("No 'dataset_keys' found in config. The custom layer will not be initialized.")
            return

        print(f"Re-initializing GCVectorMemoryLayer with dataset keys: {dataset_keys}")
        # Only the lookup is guarded: a construction error is a real bug and
        # must not be reported as a missing layer.
        try:
            original_layer = self.get_submodule(self.target_layer_path)
        except AttributeError:
            print(f"Could not find target layer '{self.target_layer_path}' during reload. Model remains unmodified.")
            return

        custom_layer = GCVectorMemoryLayer(
            original_layer=original_layer,
            global_input_dim=config.hidden_size,
            memory_dim=getattr(config, "memory_dim", 64),
            num_memory_slots=getattr(config, "num_memory_slots", 8),
            memory_num_heads=getattr(config, "memory_num_heads", 4),
            global_state_storage=self.global_state_storage,
            dataset_keys=dataset_keys,  # Use keys from config
            num_long_term_memory_slots=getattr(config, "num_long_term_memory_slots", 32),
        )
        parent_path, _, child_name = self.target_layer_path.rpartition('.')
        setattr(self.get_submodule(parent_path), child_name, custom_layer)
        print(f"Successfully reloaded and replaced '{self.target_layer_path}' with specialized GCVectorMemoryLayer.")

    def _store_embeds(self, module, inputs, output):
        """Forward hook: publish the (detached) token embeddings for the custom layer."""
        self.global_state_storage.update({'embeds': output.detach()})