Update app.py
app.py CHANGED
@@ -29,23 +29,29 @@ token2id["<EOS>"] = 1
 id2token = {i: t for t, i in token2id.items()}
 VOCAB_SIZE = len(token2id)
 
-# --- Predictor Model Architecture (
+# --- Predictor Model Architecture (Corrected to match saved weights) ---
 class AntioxidantPredictor(nn.Module):
-    # This class definition should be an exact copy from your project
     def __init__(self, input_dim, transformer_layers=3, transformer_heads=4, transformer_dropout=0.1):
         super(AntioxidantPredictor, self).__init__()
-
+        # From the error log and your training script we know the input dimension is fixed,
+        # and the model itself handles separating the ProtT5 and hand-crafted features
         self.t5_dim = 1024
-        self.hand_crafted_dim =
+        self.hand_crafted_dim = input_dim - self.t5_dim
 
+        # Define the Transformer encoder
         encoder_layer = nn.TransformerEncoderLayer(
-            d_model=self.t5_dim,
-
+            d_model=self.t5_dim,
+            nhead=transformer_heads,
+            dropout=transformer_dropout,
+            batch_first=True
         )
         self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=transformer_layers)
 
+        # Define the MLP
+        # The error log shows the weight file has no fusion_fc or classifier, only a single mlp,
+        # so it is rebuilt to follow the original structure in predictor_train.py
         self.mlp = nn.Sequential(
-            nn.Linear(
+            nn.Linear(input_dim, 512),
             nn.ReLU(),
             nn.Dropout(0.5),
             nn.Linear(512, 256),
@@ -56,18 +62,29 @@ class AntioxidantPredictor(nn.Module):
         self.temperature = nn.Parameter(torch.ones(1))
 
     def forward(self, fused_features):
-
-
-
-
+        # This forward pass matches your training script predictor_train.py more closely
+        prot_t5_features = fused_features[:, :self.t5_dim]
+        hand_crafted_features = fused_features[:, self.t5_dim:]
+
+        # The Transformer processes only the ProtT5 features
+        prot_t5_features_unsqueezed = prot_t5_features.unsqueeze(1)
+        transformer_output = self.transformer_encoder(prot_t5_features_unsqueezed)
         transformer_output_pooled = transformer_output.mean(dim=1)
-
+
+        # Concatenate the processed ProtT5 features with the hand-crafted features
+        combined_features = torch.cat((transformer_output_pooled, hand_crafted_features), dim=1)
+
+        # Feed the final concatenated features into the MLP
         logits = self.mlp(combined_features)
+
         return logits / self.temperature
 
     def get_temperature(self):
         return self.temperature.item()
 
+    def set_temperature(self, temp_value, device):
+        self.temperature = nn.Parameter(torch.tensor([temp_value], device=device), requires_grad=False)
+
 # --- Generator Model Architecture (from generator.py) ---
 class ProtT5Generator(nn.Module):
     # This class definition should be an exact copy from your project
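As a quick sanity check of the corrected class, a minimal smoke test can instantiate the model and push a dummy batch of fused features through it. This is a sketch only: the INPUT_DIM of 1115 (1024 ProtT5 dims plus a hypothetical 91 hand-crafted dims) is an assumption for illustration, and the output width depends on the final MLP layers not shown in this hunk. It assumes the AntioxidantPredictor class above is in scope.

import torch

# Assumed for illustration only: 1024 ProtT5 dims + 91 hand-crafted dims.
# Replace 1115 with the input_dim your feature pipeline actually produces.
INPUT_DIM = 1115

model = AntioxidantPredictor(input_dim=INPUT_DIM)
model.eval()

dummy_batch = torch.randn(4, INPUT_DIM)  # batch of 4 fused feature vectors
with torch.no_grad():
    logits = model(dummy_batch)

# Expected: torch.Size([4, k]), where k is the width of the MLP's final
# (unshown) Linear layer, e.g. 1 for a single-logit binary head.
print(logits.shape)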
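Since the commit exists to make the class line up with an existing checkpoint, it can also help to diff the key names before loading. A sketch continuing from the snippet above; the filename "predictor.pth" is hypothetical, as the actual checkpoint path used by the Space is not shown here.

state = torch.load("predictor.pth", map_location="cpu")  # hypothetical filename

model_keys = set(model.state_dict().keys())
ckpt_keys = set(state.keys())

# Both sets should be empty if the rebuilt architecture matches the saved weights
print("missing from checkpoint:", sorted(model_keys - ckpt_keys))
print("unexpected in checkpoint:", sorted(ckpt_keys - model_keys))

model.load_state_dict(state)  # should now succeed with the default strict=True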
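The temperature parameter implements temperature scaling for calibration: logits are divided by a scalar fitted on held-out data, which softens or sharpens the predicted probabilities without changing which class scores highest. A usage sketch, with the fitted value 1.7 purely illustrative and the sigmoid assuming a single-logit binary head:

# Apply a temperature fitted elsewhere (e.g. on a validation set); 1.7 is illustrative.
model.set_temperature(1.7, device="cpu")
print(model.get_temperature())  # -> 1.7

with torch.no_grad():
    calibrated_probs = torch.sigmoid(model(dummy_batch))  # assumes a single-logit binary head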