# --- START OF FILE architecture.py ---

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import Phi3Config, Phi3ForCausalLM
from typing import Optional, Dict

# --- BUILDING BLOCK 1: VectorMemoryHead ---
# Compresses a sequence of state vectors into a fixed number of memory slots and
# reconstructs the original sequence from that compressed memory. All submodules
# are built with the caller-supplied device and dtype.
class VectorMemoryHead(nn.Module):
    def __init__(self, hidden_dim: int, num_memory_slots: int, num_heads: int, ff_dim: int, device=None, dtype=None):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_memory_slots = num_memory_slots
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_dim, nhead=num_heads, dim_feedforward=ff_dim, dropout=0.1, batch_first=True,
            device=device, dtype=dtype
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=1)
        self.memory_queries = nn.Parameter(torch.randn(1, num_memory_slots, hidden_dim, device=device, dtype=dtype))
        self.memory_attention = nn.MultiheadAttention(
            embed_dim=hidden_dim, num_heads=num_heads, dropout=0.1, batch_first=True,
            device=device, dtype=dtype
        )
        self.memory_layernorm = nn.LayerNorm(hidden_dim, device=device, dtype=dtype)
        self.decoder_attention = nn.MultiheadAttention(
            embed_dim=hidden_dim, num_heads=num_heads, dropout=0.1, batch_first=True,
            device=device, dtype=dtype
        )
        self.decoder_layernorm = nn.LayerNorm(hidden_dim, device=device, dtype=dtype)
        self.decoder_ffn = nn.Sequential(
            nn.Linear(hidden_dim, ff_dim, device=device, dtype=dtype),
            nn.ReLU(),
            nn.Linear(ff_dim, hidden_dim, device=device, dtype=dtype)
        )

    def forward(self, memory_input_sequence: torch.Tensor):
        batch_size = memory_input_sequence.shape[0]
        encoded_vectors = self.encoder(memory_input_sequence)
        queries = self.memory_queries.expand(batch_size, -1, -1)
        compressed_memory, _ = self.memory_attention(query=queries, key=encoded_vectors, value=encoded_vectors)
        compressed_memory = self.memory_layernorm(compressed_memory + queries)
        reconstructed, _ = self.decoder_attention(query=encoded_vectors, key=compressed_memory, value=compressed_memory)
        reconstructed_vectors = self.decoder_layernorm(reconstructed + encoded_vectors)
        reconstructed_vectors = self.decoder_ffn(reconstructed_vectors)
        return compressed_memory, reconstructed_vectors
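
# Shape flow (illustrative): an input of shape (N, L, hidden_dim) is compressed to
# (N, num_memory_slots, hidden_dim) and reconstructed back to (N, L, hidden_dim).
# GCVectorMemoryLayer below calls this with N = batch_size * seq_len and L = 2
# (one global and one local projection per token).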

# --- BUILDING BLOCK 2: The Custom Layer (With Iterative Self-Correction) ---
class GCVectorMemoryLayer(nn.Module):
    def __init__(self, original_layer: nn.Linear, global_input_dim: int,
                 memory_dim: int, num_memory_slots: int, memory_num_heads: int,
                 global_state_storage: Dict):
        super().__init__()
        self.input_dim = original_layer.in_features
        self.output_dim = original_layer.out_features
        self.memory_dim = memory_dim
        self.global_state_storage = global_state_storage
        self.linear = original_layer

        device, dtype = self.linear.weight.device, self.linear.weight.dtype

        # Project local activations and global embeddings into the shared memory
        # space, matching the wrapped layer's device and dtype.
        self.local_state_proj = nn.Linear(self.input_dim, memory_dim, device=device, dtype=dtype)
        self.global_state_proj = nn.Linear(global_input_dim, memory_dim, device=device, dtype=dtype)
        self.memory_head = VectorMemoryHead(
            hidden_dim=memory_dim, num_memory_slots=num_memory_slots,
            num_heads=memory_num_heads, ff_dim=memory_dim * 2, device=device, dtype=dtype
        )
        self.correction_head = nn.Linear(memory_dim, 2 * self.output_dim, device=device, dtype=dtype)

        # Number of times the gated correction is applied in forward(). It can be
        # raised at inference time; the default of 1 matches training behavior.
        self.num_correction_passes: int = 1

        self.last_corrected_activation: Optional[torch.Tensor] = None
        self.last_additive_correction: Optional[torch.Tensor] = None
        self.last_memory_input: Optional[torch.Tensor] = None
        self.last_reconstructed_from_memory: Optional[torch.Tensor] = None

    def forward(self, x: torch.Tensor):
        base_output = self.linear(x)

        # If no global state is available or correction is disabled, return base output.
        if 'embeds' not in self.global_state_storage or self.num_correction_passes < 1:
            return base_output

        global_embeds = self.global_state_storage['embeds']
        # Align the cached global embeddings with the current input length by
        # keeping the most recent positions.
        if global_embeds.shape[1] != x.shape[1]:
            global_embeds = global_embeds[:, -x.shape[1]:, :]
        B, S, _ = x.shape

        # torch.enable_grad() keeps this block differentiable even when the caller
        # runs under torch.no_grad() (e.g. during generation).
        with torch.enable_grad():
            # --- 1. Calculate the correction signal ONCE ---
            proj_local = self.local_state_proj(x)
            proj_global = self.global_state_proj(global_embeds)

            memory_input = torch.stack([proj_global, proj_local], dim=2)
            memory_input_flat = memory_input.view(B * S, 2, self.memory_dim)
            compressed_mem_flat, recon_flat = self.memory_head(memory_input_flat)
            aggregated_thought_flat = compressed_mem_flat.mean(dim=1)
            aggregated_thought = aggregated_thought_flat.view(B, S, self.memory_dim)
            raw_correction = self.correction_head(aggregated_thought)
            gate, value = torch.chunk(raw_correction, 2, dim=-1)

            # --- 2. Iteratively apply the correction signal ---
            corrected_activation = base_output
            for _ in range(self.num_correction_passes):
                corrected_activation = corrected_activation * torch.sigmoid(gate.to(x.dtype)) + value.to(x.dtype)

        # During training, store the final activation and the original correction signal
        # for loss calculation.
        if self.training:
            self.last_corrected_activation = corrected_activation
            self.last_additive_correction = value # The 'value' is the core additive signal
            self.last_memory_input = memory_input_flat
            self.last_reconstructed_from_memory = recon_flat

        return corrected_activation
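
# A minimal sketch (not part of the original training code) of how the tensors
# cached above could feed an auxiliary reconstruction loss. The function name and
# the default weight are assumptions for illustration.
def memory_reconstruction_loss(layer: GCVectorMemoryLayer, weight: float = 0.1) -> torch.Tensor:
    """MSE between the memory head's input pairs and their reconstruction."""
    if layer.last_memory_input is None or layer.last_reconstructed_from_memory is None:
        return torch.tensor(0.0)
    return weight * F.mse_loss(
        layer.last_reconstructed_from_memory,
        layer.last_memory_input.detach(),
    )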

# --- BUILDING BLOCK 3: The Full Custom Model ---
class Phi3WithVectorMemoryForCausalLM(Phi3ForCausalLM):
    def __init__(self, config):
        super().__init__(config)
        self.global_state_storage = {}
        self.target_layer_path = "model.layers.15.mlp.gate_up_proj"

        # Cache the detached token embeddings on every forward pass so the custom
        # layer can condition its correction on the sequence-level global state.
        self.model.embed_tokens.register_forward_hook(
            lambda module, input, output: self.global_state_storage.update({'embeds': output.detach()})
        )

        try:
            original_layer = self.get_submodule(self.target_layer_path)
            custom_layer = GCVectorMemoryLayer(
                original_layer=original_layer, global_input_dim=config.hidden_size,
                memory_dim=64, num_memory_slots=8, memory_num_heads=4,
                global_state_storage=self.global_state_storage
            )
            parent_path = ".".join(self.target_layer_path.split('.')[:-1])
            child_name = self.target_layer_path.split('.')[-1]
            setattr(self.get_submodule(parent_path), child_name, custom_layer)
            print(f"Successfully replaced '{self.target_layer_path}' with GCVectorMemoryLayer.")
        except AttributeError:
            print(f"Could not find target layer '{self.target_layer_path}'. Model remains unmodified.")
# --- END OF FILE architecture.py ---