import numpy as np
import torch
import torch.nn as nn
class CustomSigmoid(nn.Module):
    """
    Custom sigmoid function with alpha, beta, a and b parameters
    alpha: steepness (input scaling) factor
    beta: horizontal shift of the input
    a: scaling factor for the output
    b: offset added to the output
    """
    def __init__(self, alpha=1.0, beta=0.0, a=1, b=0):
        super(CustomSigmoid, self).__init__()
        self.alpha = alpha
        self.beta = beta
        self.a = a
        self.b = b

    def forward(self, x):
        # a * sigmoid(alpha * (x - beta)) + b
        return self.a * (1 / (1 + torch.exp(-self.alpha * (x - self.beta)))) + self.b
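
# Illustrative usage (a sketch, not from the original file): with the defaults
# alpha=1, beta=0, a=1, b=0 the layer reduces to the standard sigmoid.
#   >>> act = CustomSigmoid(alpha=2.0, beta=0.5)
#   >>> act(torch.zeros(1))   # 1 / (1 + exp(-2 * (0 - 0.5))) ~ 0.2689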
class RMSNorm(nn.Module):
    """
    Root Mean Square Normalization layer
    https://arxiv.org/abs/1910.07467
    eps: small constant for numerical stability
    """
    def __init__(self, eps=1e-15):
        super(RMSNorm, self).__init__()
        self.eps = eps
        # single learnable gain shared across all features
        self.weight = nn.Parameter(torch.rand(1))

    def forward(self, x):
        # root mean square over the last dimension
        norm = torch.sqrt(torch.mean(x.pow(2), dim=-1, keepdim=True) + self.eps)
        # normalize and scale
        x_normalized = x / norm
        return x_normalized * self.weight
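
# Illustrative shape check (a sketch; the (batch, seq, features) layout is an
# assumption): normalization runs over the last dimension and the learned gain
# is a single scalar, so the input shape is preserved.
#   >>> RMSNorm()(torch.randn(2, 24, 96)).shape
#   torch.Size([2, 24, 96])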
class Small_MLP(nn.Module):
    """
    Small MLP used for the input and output mappings
    in_dim: input dimension
    mid_dim: hidden dimension
    out_dim: output dimension
    """
    def __init__(self, in_dim, mid_dim, out_dim):
        super(Small_MLP, self).__init__()
        self.mlp = nn.Sequential(
            nn.Linear(in_dim, mid_dim),
            RMSNorm(),
            nn.GELU(),
            nn.Linear(mid_dim, mid_dim),
            RMSNorm(),
            nn.GELU(),
            nn.Linear(mid_dim, mid_dim),
            RMSNorm(),
            nn.GELU(),
            nn.Linear(mid_dim, out_dim),
        )

    def forward(self, x):
        return self.mlp(x)
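
# Illustrative usage (hypothetical sizes): the MLP acts on the last dimension
# only, so any leading batch/sequence dimensions pass through unchanged.
#   >>> Small_MLP(in_dim=24, mid_dim=96, out_dim=96)(torch.randn(8, 365, 24)).shape
#   torch.Size([8, 365, 96])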
class MSA(nn.Module):
    """
    Multi-head self-attention layer
    d: hidden dimension (must be divisible by n_heads)
    n_heads: number of attention heads
    """
    def __init__(self, d, n_heads):
        super(MSA, self).__init__()
        self.d = d
        self.n_heads = n_heads
        self.d_head = d // n_heads
        # separate q/k/v projections per head, each acting on a d_head-wide slice
        self.q_map = nn.ModuleList([nn.Linear(self.d_head, self.d_head) for _ in range(n_heads)])
        self.k_map = nn.ModuleList([nn.Linear(self.d_head, self.d_head) for _ in range(n_heads)])
        self.v_map = nn.ModuleList([nn.Linear(self.d_head, self.d_head) for _ in range(n_heads)])
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, sequences):
        # split the sequences into n_heads slices and project each to q, k, v
        q = [q_map(sequences[:, :, i*self.d_head:(i+1)*self.d_head]) for i, q_map in enumerate(self.q_map)]
        k = [k_map(sequences[:, :, i*self.d_head:(i+1)*self.d_head]) for i, k_map in enumerate(self.k_map)]
        v = [v_map(sequences[:, :, i*self.d_head:(i+1)*self.d_head]) for i, v_map in enumerate(self.v_map)]
        results = []
        for i in range(self.n_heads):
            # scaled dot-product attention score for head i
            attn_score = torch.bmm(q[i], k[i].transpose(1, 2)) / np.sqrt(self.d_head)
            attn_score = self.softmax(attn_score)
            # weighted sum of the values
            output = torch.bmm(attn_score, v[i])
            results.append(output)
        # concatenate the heads back along the feature dimension
        return torch.cat(results, dim=2)
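
# Illustrative usage (hypothetical sizes; d must be divisible by n_heads):
# self-attention preserves the (batch, seq_len, d) shape of the token sequence.
#   >>> MSA(d=96, n_heads=6)(torch.randn(4, 24, 96)).shape
#   torch.Size([4, 24, 96])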
class Vit_block(nn.Module):
    """
    A transformer encoder block with multi-head self-attention and an MLP
    hidden_d: hidden dimension
    n_heads: number of attention heads
    mlp_ratio: expansion ratio for the MLP hidden dimension
    """
    def __init__(self, hidden_d, n_heads, mlp_ratio=4.0):
        super(Vit_block, self).__init__()
        self.hidden_d = hidden_d
        self.n_heads = n_heads
        self.msa = MSA(hidden_d, n_heads)
        self.norm1 = RMSNorm()  # alternative: nn.LayerNorm([length, hidden_d])
        self.norm2 = RMSNorm()  # alternative: nn.LayerNorm([length, hidden_d])
        self.mlp = nn.Sequential(
            nn.Linear(hidden_d, int(hidden_d * mlp_ratio)),
            RMSNorm(),
            nn.GELU(),
            nn.Linear(int(hidden_d * mlp_ratio), int(hidden_d * mlp_ratio)),
            RMSNorm(),
            nn.GELU(),
            nn.Linear(int(hidden_d * mlp_ratio), hidden_d),
        )

    def forward(self, x):
        # pre-norm residual connections around attention and MLP
        x = x + self.msa(self.norm1(x))
        x = x + self.mlp(self.norm2(x))
        return x
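
# Illustrative usage (hypothetical sizes): the pre-norm residual block keeps the
# token shape unchanged, so blocks can be stacked freely.
#   >>> Vit_block(hidden_d=96, n_heads=6)(torch.randn(4, 24, 96)).shape
#   torch.Size([4, 24, 96])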
class ViT_encodernopara(nn.Module):
    """
    A transformer encoder
    chw: input data shape (1, num_days, num_time_steps+1); the channel is always 1 and the
         last dimension is num_time_steps+1 because the date embedding adds one extra dimension
    hidden_d: hidden dimension
    out_d: output dimension (number of time steps)
    n_heads: number of attention heads
    mlp_ratio: expansion ratio for the MLP hidden dimension
    n_blocks: number of transformer blocks
    alpha: scaling factor for the custom sigmoid
    beta: shifting factor for the custom sigmoid
    """
    def __init__(self,
                 chw=(1, 24, 24),
                 hidden_d=96,
                 out_d=2,
                 n_heads=6,
                 mlp_ratio=4.0,
                 n_blocks=3,
                 alpha=1,
                 beta=0.5
                 ):
        # Super constructor
        super(ViT_encodernopara, self).__init__()
        # input data shape (N, num_days, num_time_steps), e.g. (N, 365, 24)
        self.chw = chw  # channel, height, width, e.g. 1, 365, 24
        self.hidden_d = hidden_d
        self.out_d = out_d
        # input and output mappings
        self.linear_map_in = Small_MLP(self.chw[2], self.hidden_d, self.hidden_d)  # nn.Linear(self.chw[2], self.hidden_d)
        self.linear_map_out2 = Small_MLP(self.hidden_d, self.hidden_d, self.out_d)  # nn.Linear(self.hidden_d, self.out_d)
        # ViT blocks
        self.n_heads = n_heads
        self.mlp_ratio = mlp_ratio
        self.n_blocks = n_blocks
        self.vit_blocks = nn.ModuleList([Vit_block(self.hidden_d, self.n_heads,
                                                   self.mlp_ratio) for _ in range(self.n_blocks)])
        # output adding layer
        self.sig = CustomSigmoid(alpha, beta)  # alternative: nn.Sigmoid()
        self.bias = 0.000001

    def forward(self, images):
        _images = images
        # map each input profile into the hidden dimension
        tokens = self.linear_map_in(_images)
        # pass the tokens through the stack of transformer blocks
        for block in self.vit_blocks:
            tokens = block(tokens)
        # map the tokens to the output dimension
        tokens = self.linear_map_out2(tokens)
        return tokens

    def output_adding_layer(self, _new_para, _param):
        # add the predicted update to the existing parameters, then squash to a
        # positive range with the custom sigmoid plus a small bias
        b, _, _ = _new_para.shape
        _new_para = _new_para.view(b, -1) + _param.view(b, -1)
        _new_para = self.sig(_new_para) + self.bias
        return _new_para
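
# Minimal smoke test (an illustrative sketch, not part of the original module).
# Shapes follow the defaults above: chw = (1, 24, 24) and out_d = 2, so the
# encoder maps (batch, 24, 24) inputs to (batch, 24, 2) outputs; the batch size
# of 8 below is arbitrary.
if __name__ == "__main__":
    model = ViT_encodernopara()
    x = torch.randn(8, 24, 24)
    y = model(x)
    print(y.shape)  # torch.Size([8, 24, 2])
    # output_adding_layer expects an existing parameter tensor with the same
    # number of elements per sample; it returns a flattened (batch, 48) tensor.
    params = torch.randn(8, 24, 2)
    print(model.output_adding_layer(y, params).shape)  # torch.Size([8, 48])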