chshan committed on
Commit
8a4f49a
·
verified ·
1 Parent(s): 6f96910

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -24
app.py CHANGED
@@ -16,10 +16,7 @@
16
  # │ └── final_generator_model.pth
17
  # ├── prott5/
18
  # │ └── model/
19
- # │ ├── config.json
20
- # │ ├── pytorch_model.bin (The base ProtT5 model from Rostlab)
21
- # │ ├── finetuned_prott5.bin (Your fine-tuned feature extractor weights)
22
- # │ └── ... (other tokenizer files)
23
  # └── requirements.txt
24
 
25
  import os
@@ -50,25 +47,27 @@ VOCAB_SIZE = len(token2id)
50
 
51
 
52
  # --- Feature Extractor Model Class (For ProtT5) ---
53
- # This class robustly loads the base ProtT5 model and applies your fine-tuned weights.
 
54
  class FeatureProtT5Model:
55
- def __init__(self, model_dir_path, finetuned_weights_path=None):
56
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
57
  print(f"Initializing ProtT5 for feature extraction on device: {self.device}")
58
-
59
- # Load the base model architecture and tokenizer from the specified directory.
60
- self.tokenizer = transformers.T5Tokenizer.from_pretrained(model_dir_path, do_lower_case=False)
61
- self.model = transformers.T5EncoderModel.from_pretrained(model_dir_path)
62
-
 
63
  # If a path to a fine-tuned weights file is provided, load and apply those weights.
64
  if finetuned_weights_path and os.path.exists(finetuned_weights_path):
65
- print(f"Applying fine-tuned weights from: {finetuned_weights_path}")
66
  state_dict = torch.load(finetuned_weights_path, map_location=self.device)
67
  self.model.load_state_dict(state_dict, strict=False)
68
  print("Successfully applied fine-tuned weights.")
69
  else:
70
  print("Warning: Fine-tuned weights not found or not provided. Using base ProtT5 weights.")
71
-
72
  self.model.to(self.device)
73
  self.model.eval()
74
 
@@ -82,10 +81,10 @@ class AntioxidantPredictor(nn.Module):
82
  self.handcrafted_dim = input_dim - self.prott5_dim
83
  self.seq_len = 16
84
  self.prott5_feature_dim = 64 # 16 * 64 = 1024
85
-
86
  encoder_layer = nn.TransformerEncoderLayer(d_model=self.prott5_feature_dim, nhead=transformer_heads, dropout=transformer_dropout, batch_first=True)
87
  self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=transformer_layers)
88
-
89
  fused_dim = self.prott5_feature_dim + self.handcrafted_dim
90
  self.fusion_fc = nn.Sequential(nn.Linear(fused_dim, 1024), nn.ReLU(), nn.Dropout(0.3), nn.Linear(1024, 512), nn.ReLU(), nn.Dropout(0.3))
91
  self.classifier = nn.Sequential(nn.Linear(512, 256), nn.ReLU(), nn.Dropout(0.3), nn.Linear(256, 1))
@@ -96,17 +95,17 @@ class AntioxidantPredictor(nn.Module):
96
  # The input 'x' is a flat 1914-dim vector from extract_features()
97
  prot_t5_features = x[:, :self.prott5_dim]
98
  handcrafted_features = x[:, self.prott5_dim:]
99
-
100
  # Reshape the first 1024 features back into a sequence representation
101
  prot_t5_seq = prot_t5_features.view(batch_size, self.seq_len, self.prott5_feature_dim)
102
-
103
  encoded_seq = self.transformer_encoder(prot_t5_seq)
104
  refined_prott5 = encoded_seq.mean(dim=1)
105
-
106
  fused_features = torch.cat([refined_prott5, handcrafted_features], dim=1)
107
  fused_output = self.fusion_fc(fused_features)
108
  logits = self.classifier(fused_output)
109
-
110
  return logits / self.temperature
111
 
112
  def get_temperature(self):
@@ -123,13 +122,13 @@ class ProtT5Generator(nn.Module):
123
  self.vocab_size = vocab_size
124
  self.eos_token_id = token2id["<EOS>"]
125
  self.pad_token_id = token2id["<PAD>"]
126
-
127
  def forward(self, input_ids):
128
  embeddings = self.embed_tokens(input_ids)
129
  encoder_output = self.encoder(embeddings)
130
  logits = self.lm_head(encoder_output)
131
  return logits
132
-
133
  def sample(self, batch_size, max_length=20, device="cpu", temperature=2.5, min_decoded_length=3):
134
  start_token = torch.randint(2, self.vocab_size, (batch_size, 1), device=device)
135
  generated = start_token
@@ -138,7 +137,7 @@ class ProtT5Generator(nn.Module):
138
  next_logits = logits[:, -1, :] / temperature
139
  if generated.size(1) < min_decoded_length:
140
  next_logits[:, self.eos_token_id] = -float("inf")
141
-
142
  probs = torch.softmax(next_logits, dim=-1)
143
  next_token = torch.multinomial(probs, num_samples=1)
144
  generated = torch.cat((generated, next_token), dim=1)
@@ -208,7 +207,9 @@ try:
208
  PREDICTOR_CHECKPOINT_PATH = "checkpoints/final_rl_model_logitp0.1_calibrated_FINETUNED_PROTT5.pth"
209
  SCALER_PATH = "checkpoints/scaler_FINETUNED_PROTT5.pkl"
210
  GENERATOR_CHECKPOINT_PATH = "generator_checkpoints_v3.6/final_generator_model.pth"
211
- PROTT5_BASE_MODEL_PATH = "prott5/model/"
 
 
212
  FINETUNED_PROTT5_FOR_FEATURES_PATH = "prott5/model/finetuned_prott5.bin"
213
 
214
  # --- Load Predictor Model ---
@@ -223,8 +224,9 @@ try:
223
  print(f"Loading Scaler from: {SCALER_PATH}")
224
  SCALER = joblib.load(SCALER_PATH)
225
  print("Loading ProtT5 Feature Extractor...")
 
226
  PROTT5_EXTRACTOR = FeatureProtT5Model(
227
- model_dir_path=PROTT5_BASE_MODEL_PATH,
228
  finetuned_weights_path=FINETUNED_PROTT5_FOR_FEATURES_PATH
229
  )
230
  print("βœ… Scaler and Feature Extractor loaded.")
 
16
  # │ └── final_generator_model.pth
17
  # ├── prott5/
18
  # │ └── model/
19
+ # │ └── finetuned_prott5.bin (Your fine-tuned feature extractor weights)
 
 
 
20
  # └── requirements.txt
21
 
22
  import os
 
47
 
48
 
49
  # --- Feature Extractor Model Class (For ProtT5) ---
50
+ # MODIFIED: This class now loads the base model from the Hugging Face Hub ID
51
+ # and then applies your local fine-tuned weights.
52
  class FeatureProtT5Model:
53
+ def __init__(self, base_model_id, finetuned_weights_path=None):
54
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
55
  print(f"Initializing ProtT5 for feature extraction on device: {self.device}")
56
+
57
+ # Load the base model architecture and tokenizer directly from the Hub ID.
58
+ print(f"Loading base model and tokenizer from '{base_model_id}'...")
59
+ self.tokenizer = transformers.T5Tokenizer.from_pretrained(base_model_id, do_lower_case=False)
60
+ self.model = transformers.T5EncoderModel.from_pretrained(base_model_id)
61
+
62
  # If a path to a fine-tuned weights file is provided, load and apply those weights.
63
  if finetuned_weights_path and os.path.exists(finetuned_weights_path):
64
+ print(f"Applying local fine-tuned weights from: {finetuned_weights_path}")
65
  state_dict = torch.load(finetuned_weights_path, map_location=self.device)
66
  self.model.load_state_dict(state_dict, strict=False)
67
  print("Successfully applied fine-tuned weights.")
68
  else:
69
  print("Warning: Fine-tuned weights not found or not provided. Using base ProtT5 weights.")
70
+
71
  self.model.to(self.device)
72
  self.model.eval()
73
 
 
81
  self.handcrafted_dim = input_dim - self.prott5_dim
82
  self.seq_len = 16
83
  self.prott5_feature_dim = 64 # 16 * 64 = 1024
84
+
85
  encoder_layer = nn.TransformerEncoderLayer(d_model=self.prott5_feature_dim, nhead=transformer_heads, dropout=transformer_dropout, batch_first=True)
86
  self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=transformer_layers)
87
+
88
  fused_dim = self.prott5_feature_dim + self.handcrafted_dim
89
  self.fusion_fc = nn.Sequential(nn.Linear(fused_dim, 1024), nn.ReLU(), nn.Dropout(0.3), nn.Linear(1024, 512), nn.ReLU(), nn.Dropout(0.3))
90
  self.classifier = nn.Sequential(nn.Linear(512, 256), nn.ReLU(), nn.Dropout(0.3), nn.Linear(256, 1))
 
95
  # The input 'x' is a flat 1914-dim vector from extract_features()
96
  prot_t5_features = x[:, :self.prott5_dim]
97
  handcrafted_features = x[:, self.prott5_dim:]
98
+
99
  # Reshape the first 1024 features back into a sequence representation
100
  prot_t5_seq = prot_t5_features.view(batch_size, self.seq_len, self.prott5_feature_dim)
101
+
102
  encoded_seq = self.transformer_encoder(prot_t5_seq)
103
  refined_prott5 = encoded_seq.mean(dim=1)
104
+
105
  fused_features = torch.cat([refined_prott5, handcrafted_features], dim=1)
106
  fused_output = self.fusion_fc(fused_features)
107
  logits = self.classifier(fused_output)
108
+
109
  return logits / self.temperature
110
 
111
  def get_temperature(self):
 
122
  self.vocab_size = vocab_size
123
  self.eos_token_id = token2id["<EOS>"]
124
  self.pad_token_id = token2id["<PAD>"]
125
+
126
  def forward(self, input_ids):
127
  embeddings = self.embed_tokens(input_ids)
128
  encoder_output = self.encoder(embeddings)
129
  logits = self.lm_head(encoder_output)
130
  return logits
131
+
132
  def sample(self, batch_size, max_length=20, device="cpu", temperature=2.5, min_decoded_length=3):
133
  start_token = torch.randint(2, self.vocab_size, (batch_size, 1), device=device)
134
  generated = start_token
 
137
  next_logits = logits[:, -1, :] / temperature
138
  if generated.size(1) < min_decoded_length:
139
  next_logits[:, self.eos_token_id] = -float("inf")
140
+
141
  probs = torch.softmax(next_logits, dim=-1)
142
  next_token = torch.multinomial(probs, num_samples=1)
143
  generated = torch.cat((generated, next_token), dim=1)
 
207
  PREDICTOR_CHECKPOINT_PATH = "checkpoints/final_rl_model_logitp0.1_calibrated_FINETUNED_PROTT5.pth"
208
  SCALER_PATH = "checkpoints/scaler_FINETUNED_PROTT5.pkl"
209
  GENERATOR_CHECKPOINT_PATH = "generator_checkpoints_v3.6/final_generator_model.pth"
210
+
211
+ # Define the base model ID from the Hub and the path to your local fine-tuned weights.
212
+ PROTT5_BASE_MODEL_ID = "Rostlab/prot_t5_xl_uniref50"
213
  FINETUNED_PROTT5_FOR_FEATURES_PATH = "prott5/model/finetuned_prott5.bin"
214
 
215
  # --- Load Predictor Model ---
 
224
  print(f"Loading Scaler from: {SCALER_PATH}")
225
  SCALER = joblib.load(SCALER_PATH)
226
  print("Loading ProtT5 Feature Extractor...")
227
+ # Pass the Hub ID to the updated class to load the base model.
228
  PROTT5_EXTRACTOR = FeatureProtT5Model(
229
+ base_model_id=PROTT5_BASE_MODEL_ID,
230
  finetuned_weights_path=FINETUNED_PROTT5_FOR_FEATURES_PATH
231
  )
232
  print("βœ… Scaler and Feature Extractor loaded.")