# --- START OF FILE architecture.py ---
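"""Custom Phi-3 architecture with a hierarchical vector memory.

Three building blocks are defined below: VectorMemoryHead (a compressive L1 working memory
plus an optional gated L2 long-term memory), GCVectorMemoryLayer (a per-dataset,
memory-corrected replacement for a single linear layer), and
Phi3WithVectorMemoryForCausalLM (the model wrapper that wires everything together).
"""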
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import Phi3Config, Phi3ForCausalLM
from typing import Optional, Dict, List
# --- BUILDING BLOCK 1: Hierarchical VectorMemoryHead ---
# This version is improved with a hierarchical memory system (L1/L2 cache)
# to handle much longer contexts and a gated update mechanism for stability.
class VectorMemoryHead(nn.Module):
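    """Hierarchical memory module with a compressive L1 working memory and an optional,
    gated L2 long-term memory.

    The input sequence is encoded and compressed into num_memory_slots L1 slots via
    cross-attention. When num_long_term_memory_slots > 0, the L1 summary is blended into
    a persistent L2 buffer through a sigmoid update gate, and context retrieved from L2
    is added back before a reconstruction of the input is decoded. The L2 path only runs
    once the caller has expanded long_term_memory to the current batch size.
    """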
    def __init__(self, hidden_dim: int, num_memory_slots: int, num_heads: int, ff_dim: int,
                 num_long_term_memory_slots: int = 0,  # <-- NEW: Size of the L2 memory cache
                 device=None, dtype=None):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_memory_slots = num_memory_slots  # L1 cache size
        self.num_long_term_memory_slots = num_long_term_memory_slots  # L2 cache size

        # --- L1 Working Memory Components (same as before) ---
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_dim, nhead=num_heads, dim_feedforward=ff_dim, dropout=0.1, batch_first=True,
            device=device, dtype=dtype
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=1)
        self.memory_queries = nn.Parameter(torch.randn(1, num_memory_slots, hidden_dim, device=device, dtype=dtype))
        self.memory_attention = nn.MultiheadAttention(
            embed_dim=hidden_dim, num_heads=num_heads, dropout=0.1, batch_first=True,
            device=device, dtype=dtype
        )
        self.memory_layernorm = nn.LayerNorm(hidden_dim, device=device, dtype=dtype)
        self.decoder_attention = nn.MultiheadAttention(
            embed_dim=hidden_dim, num_heads=num_heads, dropout=0.1, batch_first=True,
            device=device, dtype=dtype
        )
        self.decoder_layernorm = nn.LayerNorm(hidden_dim, device=device, dtype=dtype)
        self.decoder_ffn = nn.Sequential(
            nn.Linear(hidden_dim, ff_dim, device=device, dtype=dtype),
            nn.ReLU(),
            nn.Linear(ff_dim, hidden_dim, device=device, dtype=dtype)
        )

        # --- NEW: L2 Long-Term Memory Components ---
        self.use_long_term_memory = self.num_long_term_memory_slots > 0
        if self.use_long_term_memory:
            self.long_term_memory = nn.Parameter(
                torch.zeros(1, self.num_long_term_memory_slots, hidden_dim, device=device, dtype=dtype)
            )
            # Gate for updating long-term memory (similar to GRU/LSTM gates)
            self.memory_update_gate = nn.Sequential(
                nn.Linear(2 * hidden_dim, hidden_dim, device=device, dtype=dtype),
                nn.Sigmoid()
            )
            # Attention used to read from L2 memory
            self.ltm_retrieval_attention = nn.MultiheadAttention(
                embed_dim=hidden_dim, num_heads=num_heads, dropout=0.1, batch_first=True,
                device=device, dtype=dtype
            )
    def forward(self, memory_input_sequence: torch.Tensor):
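        """Compress memory_input_sequence of shape (B, T, hidden_dim) into L1 memory and
        decode a reconstruction of it.

        Returns (compressed_memory, reconstructed_vectors) with shapes
        (B, num_memory_slots, hidden_dim) and (B, T, hidden_dim) respectively.
        """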
        batch_size = memory_input_sequence.shape[0]
        # 1. Encode input sequence
        encoded_vectors = self.encoder(memory_input_sequence)
        # 2. Compress into L1 working memory
        queries = self.memory_queries.expand(batch_size, -1, -1)
        compressed_memory, _ = self.memory_attention(query=queries, key=encoded_vectors, value=encoded_vectors)
        compressed_memory = self.memory_layernorm(compressed_memory + queries)  # (B, num_memory_slots, D)
        final_memory_context = compressed_memory

        # 3. Interact with L2 long-term memory. The L2 path is active only once the caller
        # has expanded `long_term_memory` to the current batch size.
        if self.use_long_term_memory and self.long_term_memory.shape[0] == batch_size:
            # 3a. Retrieve relevant context from L2 memory using L1 as the query
            retrieved_ltm, _ = self.ltm_retrieval_attention(
                query=compressed_memory,
                key=self.long_term_memory,
                value=self.long_term_memory
            )
            # 3b. Gated update of the long-term memory.
            # Average the L1 memory to get a summary vector for the update.
            l1_summary = compressed_memory.mean(dim=1)
            ltm_summary = self.long_term_memory.mean(dim=1)
            gate_input = torch.cat([l1_summary, ltm_summary], dim=-1)
            update_gate = self.memory_update_gate(gate_input).unsqueeze(1)  # (B, 1, D)
            # Blend new information from L1 into the LTM. Writing to `.data` keeps the update
            # out of the autograd graph, so the LTM persists across forward passes without
            # receiving gradients.
            self.long_term_memory.data = (update_gate * l1_summary.unsqueeze(1)) + ((1 - update_gate) * self.long_term_memory.data)
            # Combine L1 and retrieved L2 context for the final output
            final_memory_context = final_memory_context + retrieved_ltm

        # 4. Decode from the final memory context to reconstruct the original sequence
        reconstructed, _ = self.decoder_attention(query=encoded_vectors, key=final_memory_context, value=final_memory_context)
        reconstructed_vectors = self.decoder_layernorm(reconstructed + encoded_vectors)
        reconstructed_vectors = self.decoder_ffn(reconstructed_vectors)
        return compressed_memory, reconstructed_vectors
# --- BUILDING BLOCK 2: The Custom Layer (With Per-Dataset Parameters and Refinement) ---
class GCVectorMemoryLayer(nn.Module):
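    """A drop-in replacement for an nn.Linear that adds per-dataset, memory-guided corrections.

    The wrapped linear layer stays shared, while every dataset key gets its own projection
    layers, VectorMemoryHead, and correction head. The active key and the cached token
    embeddings are read from global_state_storage, which the surrounding model and training
    code populate; if either is missing, the layer behaves exactly like the original
    linear layer.
    """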
    def __init__(self, original_layer: nn.Linear, global_input_dim: int,
                 memory_dim: int, num_memory_slots: int, memory_num_heads: int,
                 global_state_storage: Dict, dataset_keys: List[str]):
        super().__init__()
        self.input_dim = original_layer.in_features
        self.output_dim = original_layer.out_features
        self.memory_dim = memory_dim
        self.global_state_storage = global_state_storage
        self.dataset_keys = dataset_keys
        self.linear = original_layer  # Shared linear layer
        device, dtype = self.linear.weight.device, self.linear.weight.dtype

        # --- NEW: Per-dataset specialized parameters ---
        self.local_state_projs = nn.ModuleDict()
        self.global_state_projs = nn.ModuleDict()
        self.memory_heads = nn.ModuleDict()
        self.correction_heads = nn.ModuleDict()
        for key in self.dataset_keys:
            self.local_state_projs[key] = nn.Linear(self.input_dim, memory_dim, device=device, dtype=dtype)
            self.global_state_projs[key] = nn.Linear(global_input_dim, memory_dim, device=device, dtype=dtype)
            self.memory_heads[key] = VectorMemoryHead(
                hidden_dim=memory_dim, num_memory_slots=num_memory_slots,
                num_heads=memory_num_heads, ff_dim=memory_dim * 2,
                num_long_term_memory_slots=32,  # Enable L2 cache
                device=device, dtype=dtype
            )
            self.correction_heads[key] = nn.Linear(memory_dim, 2 * self.output_dim, device=device, dtype=dtype)

        self.refinement_passes: int = 2  # Default to 2 passes for deeper refinement
        self.last_corrected_activation: Optional[torch.Tensor] = None
        self.last_additive_correction: Optional[torch.Tensor] = None
        self.last_memory_input: Optional[torch.Tensor] = None
        self.last_reconstructed_from_memory: Optional[torch.Tensor] = None
    def forward(self, x: torch.Tensor):
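        """Apply the shared linear layer, then refine its output with the per-dataset memory.

        At inference the gated correction is computed under torch.no_grad() and refined over
        refinement_passes iterations. During training a single gated correction is recomputed
        with gradients enabled and intermediate tensors are cached on the module, presumably
        for auxiliary losses in the training script.
        """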
        base_output = self.linear(x)
        # Determine which set of specialized parameters to use
        dataset_key = self.global_state_storage.get('dataset_key')
        if not dataset_key or 'embeds' not in self.global_state_storage or self.refinement_passes < 1:
            return base_output

        # Select the correct modules for the current context
        local_state_proj = self.local_state_projs[dataset_key]
        global_state_proj = self.global_state_projs[dataset_key]
        memory_head = self.memory_heads[dataset_key]
        correction_head = self.correction_heads[dataset_key]

        global_embeds = self.global_state_storage['embeds']
        if global_embeds.shape[1] != x.shape[1]:
            global_embeds = global_embeds[:, -x.shape[1]:, :]
        B, S, _ = x.shape
        # Ensure the LTM of this specific memory head is initialized with the correct batch size.
        if memory_head.use_long_term_memory and memory_head.long_term_memory.shape[0] != B:
            # Collapse any previously expanded batch dimension to a single prototype before
            # expanding, so a change in batch size (e.g. a final partial batch) cannot break
            # `expand`; `.contiguous()` gives the LTM its own storage.
            memory_head.long_term_memory.data = (
                memory_head.long_term_memory.data.mean(dim=0, keepdim=True).expand(B, -1, -1).contiguous()
            )
        # Use no_grad for the refinement loop, as it is an inference-like process.
        with torch.no_grad():
            proj_local = local_state_proj(x.detach())
            proj_global = global_state_proj(global_embeds.detach())
            memory_input = torch.stack([proj_global, proj_local], dim=2)
            memory_input_flat = memory_input.view(B * S, 2, self.memory_dim)
            compressed_mem_flat, _ = memory_head(memory_input_flat)
            aggregated_thought = compressed_mem_flat.mean(dim=1).view(B, S, self.memory_dim)

            # Iteratively refine the output using state feedback
            corrected_activation = base_output
            current_thought = aggregated_thought
            for _ in range(self.refinement_passes):
                raw_correction = correction_head(current_thought)
                gate, value = torch.chunk(raw_correction, 2, dim=-1)
                corrected_activation = corrected_activation * torch.sigmoid(gate) + value
                current_thought_flat = current_thought.view(B * S, self.memory_dim)
                refined_thought, _ = memory_head.decoder_attention(
                    query=current_thought_flat.unsqueeze(1), key=compressed_mem_flat, value=compressed_mem_flat
                )
                refined_thought = memory_head.decoder_layernorm(refined_thought.squeeze(1) + current_thought_flat)
                current_thought = refined_thought.view(B, S, self.memory_dim)
        if self.training:
            # Recompute a single gated correction with gradients enabled so the projection,
            # memory, and correction heads receive a learning signal, and cache intermediate
            # tensors for the training script (e.g. an auxiliary reconstruction loss).
            with torch.enable_grad():
                proj_local_grad = local_state_proj(x)
                proj_global_grad = global_state_proj(global_embeds)
                memory_input_grad = torch.stack([proj_global_grad, proj_local_grad], dim=2)
                memory_input_flat_grad = memory_input_grad.view(B * S, 2, self.memory_dim)
                compressed_mem_flat_grad, recon_flat_grad = memory_head(memory_input_flat_grad)
                aggregated_thought_grad = compressed_mem_flat_grad.mean(dim=1).view(B, S, self.memory_dim)
                raw_correction_grad = correction_head(aggregated_thought_grad)
                gate_grad, value_grad = torch.chunk(raw_correction_grad, 2, dim=-1)
                final_activation = base_output * torch.sigmoid(gate_grad.to(x.dtype)) + value_grad.to(x.dtype)
                self.last_corrected_activation = final_activation
                self.last_additive_correction = value_grad
                self.last_memory_input = memory_input_flat_grad
                self.last_reconstructed_from_memory = recon_flat_grad
            return final_activation
        else:
            return corrected_activation.to(x.dtype)
# --- BUILDING BLOCK 3: The Full Custom Model Wrapper (for saving/loading) ---
class Phi3WithVectorMemoryForCausalLM(Phi3ForCausalLM):
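    """Phi-3 causal LM whose gate_up_proj in one MLP block is replaced by a GCVectorMemoryLayer.

    A forward hook on embed_tokens caches the token embeddings in global_state_storage so the
    custom layer can condition on them as global context. The replacement performed here only
    happens when dataset_keys is present in the config (i.e. when reloading a saved model);
    otherwise the training script is expected to install the custom layer.
    """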
    def __init__(self, config):
        super().__init__(config)
        self.global_state_storage = {}
        # Target a central layer in the network for maximum impact
        self.target_layer_path = "model.layers.15.mlp.gate_up_proj"
        # Cache the token embeddings on every forward pass so the custom layer can use them
        # as a global context signal.
        self.model.embed_tokens.register_forward_hook(
            lambda module, input, output: self.global_state_storage.update({'embeds': output.detach()})
        )
        # This logic is primarily for loading a pre-trained model;
        # the training script handles the initial creation.
        if hasattr(config, "dataset_keys") and config.dataset_keys:
            try:
                print(f"Re-initializing GCVectorMemoryLayer with dataset keys: {config.dataset_keys}")
                original_layer = self.get_submodule(self.target_layer_path)
                custom_layer = GCVectorMemoryLayer(
                    original_layer=original_layer, global_input_dim=config.hidden_size,
                    memory_dim=64,
                    num_memory_slots=8,
                    memory_num_heads=4,
                    global_state_storage=self.global_state_storage,
                    dataset_keys=config.dataset_keys  # Use keys from config
                )
                parent_path = ".".join(self.target_layer_path.split('.')[:-1])
                child_name = self.target_layer_path.split('.')[-1]
                setattr(self.get_submodule(parent_path), child_name, custom_layer)
                print(f"Successfully reloaded and replaced '{self.target_layer_path}' with a specialized GCVectorMemoryLayer.")
            except AttributeError:
                print(f"Could not find target layer '{self.target_layer_path}' during reload. Model remains unmodified.")
        else:
            print("No 'dataset_keys' found in config. The custom layer will not be initialized.")