chshan committed on
Commit
a96a115
·
verified ·
1 Parent(s): 6756f3b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +29 -12
app.py CHANGED
@@ -29,23 +29,29 @@ token2id["<EOS>"] = 1
29
  id2token = {i: t for t, i in token2id.items()}
30
  VOCAB_SIZE = len(token2id)
31
 
32
- # --- Predictor Model Architecture (from antioxidant_predictor_5.py) ---
33
  class AntioxidantPredictor(nn.Module):
34
- # This class definition should be an exact copy from your project
35
  def __init__(self, input_dim, transformer_layers=3, transformer_heads=4, transformer_dropout=0.1):
36
  super(AntioxidantPredictor, self).__init__()
37
- self.input_dim = input_dim
 
38
  self.t5_dim = 1024
39
- self.hand_crafted_dim = self.input_dim - self.t5_dim
40
 
 
41
  encoder_layer = nn.TransformerEncoderLayer(
42
- d_model=self.t5_dim, nhead=transformer_heads,
43
- dropout=transformer_dropout, batch_first=True
 
 
44
  )
45
  self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=transformer_layers)
46
 
 
 
 
47
  self.mlp = nn.Sequential(
48
- nn.Linear(self.input_dim, 512),
49
  nn.ReLU(),
50
  nn.Dropout(0.5),
51
  nn.Linear(512, 256),
@@ -56,18 +62,29 @@ class AntioxidantPredictor(nn.Module):
56
  self.temperature = nn.Parameter(torch.ones(1))
57
 
58
  def forward(self, fused_features):
59
- tr_features = fused_features[:, :self.t5_dim]
60
- hand_features = fused_features[:, self.t5_dim:]
61
- tr_features_unsqueezed = tr_features.unsqueeze(1)
62
- transformer_output = self.transformer_encoder(tr_features_unsqueezed)
 
 
 
63
  transformer_output_pooled = transformer_output.mean(dim=1)
64
- combined_features = torch.cat((transformer_output_pooled, hand_features), dim=1)
 
 
 
 
65
  logits = self.mlp(combined_features)
 
66
  return logits / self.temperature
67
 
68
  def get_temperature(self):
69
  return self.temperature.item()
70
 
 
 
 
71
  # --- Generator Model Architecture (from generator.py) ---
72
  class ProtT5Generator(nn.Module):
73
  # This class definition should be an exact copy from your project
 
29
  id2token = {i: t for t, i in token2id.items()}
30
  VOCAB_SIZE = len(token2id)
31
 
32
+ # --- Predictor Model Architecture (Corrected to match saved weights) ---
33
  class AntioxidantPredictor(nn.Module):
 
34
  def __init__(self, input_dim, transformer_layers=3, transformer_heads=4, transformer_dropout=0.1):
35
  super(AntioxidantPredictor, self).__init__()
36
+ # 根据错误日志和您的训练脚本,我们知道输入维度是固定的
37
+ # 并且模型内部处理 ProtT5 和传统特征的分离
38
  self.t5_dim = 1024
39
+ self.hand_crafted_dim = input_dim - self.t5_dim
40
 
41
+ # 定义 Transformer Encoder
42
  encoder_layer = nn.TransformerEncoderLayer(
43
+ d_model=self.t5_dim,
44
+ nhead=transformer_heads,
45
+ dropout=transformer_dropout,
46
+ batch_first=True
47
  )
48
  self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=transformer_layers)
49
 
50
+ # 定义 MLP
51
+ # 错误日志表明权重文件没有 fusion_fc 和 classifier,只有一个 mlp
52
+ # 我们根据 predictor_train.py 的原始结构来重建
53
  self.mlp = nn.Sequential(
54
+ nn.Linear(input_dim, 512),
55
  nn.ReLU(),
56
  nn.Dropout(0.5),
57
  nn.Linear(512, 256),
 
62
  self.temperature = nn.Parameter(torch.ones(1))
63
 
64
  def forward(self, fused_features):
65
+ # 这个前向传播逻辑与您的训练脚本 predictor_train.py 更为匹配
66
+ prot_t5_features = fused_features[:, :self.t5_dim]
67
+ hand_crafted_features = fused_features[:, self.t5_dim:]
68
+
69
+ # Transformer 只处理 ProtT5 特征
70
+ prot_t5_features_unsqueezed = prot_t5_features.unsqueeze(1)
71
+ transformer_output = self.transformer_encoder(prot_t5_features_unsqueezed)
72
  transformer_output_pooled = transformer_output.mean(dim=1)
73
+
74
+ # 将处理后的 ProtT5 特征与传统特征拼接
75
+ combined_features = torch.cat((transformer_output_pooled, hand_crafted_features), dim=1)
76
+
77
+ # 将最终拼接的特征送入 MLP
78
  logits = self.mlp(combined_features)
79
+
80
  return logits / self.temperature
81
 
82
  def get_temperature(self):
83
  return self.temperature.item()
84
 
85
+ def set_temperature(self, temp_value, device):
86
+ self.temperature = nn.Parameter(torch.tensor([temp_value], device=device), requires_grad=False)
87
+
88
  # --- Generator Model Architecture (from generator.py) ---
89
  class ProtT5Generator(nn.Module):
90
  # This class definition should be an exact copy from your project