Update app.py
app.py
CHANGED
@@ -16,10 +16,7 @@
 # │   └── final_generator_model.pth
 # ├── prott5/
 # │   └── model/
-# │
-# │       ├── pytorch_model.bin (The base ProtT5 model from Rostlab)
-# │       ├── finetuned_prott5.bin (Your fine-tuned feature extractor weights)
-# │       └── ... (other tokenizer files)
+# │       └── finetuned_prott5.bin (Your fine-tuned feature extractor weights)
 # └── requirements.txt
 
 import os
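The layout comment above pins the Space to fixed relative paths. As a purely illustrative sanity check (not part of this commit; the paths are taken from the layout comment and the checkpoint constants further down in the diff), one could verify the expected artifacts before loading anything:

import os

# Hypothetical startup check: paths come from the layout comment and the
# constants defined later in app.py; adjust if the repo structure differs.
EXPECTED_FILES = [
    "generator_checkpoints_v3.6/final_generator_model.pth",
    "prott5/model/finetuned_prott5.bin",
    "requirements.txt",
]

for path in EXPECTED_FILES:
    status = "found" if os.path.exists(path) else "MISSING"
    print(f"{path}: {status}")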
@@ -50,25 +47,27 @@ VOCAB_SIZE = len(token2id)
 
 
 # --- Feature Extractor Model Class (For ProtT5) ---
-# This class
+# MODIFIED: This class now loads the base model from the Hugging Face Hub ID
+# and then applies your local fine-tuned weights.
 class FeatureProtT5Model:
-    def __init__(self,
+    def __init__(self, base_model_id, finetuned_weights_path=None):
         self.device = "cuda" if torch.cuda.is_available() else "cpu"
         print(f"Initializing ProtT5 for feature extraction on device: {self.device}")
-
-        # Load the base model architecture and tokenizer from the
-
-        self.
-
+
+        # Load the base model architecture and tokenizer directly from the Hub ID.
+        print(f"Loading base model and tokenizer from '{base_model_id}'...")
+        self.tokenizer = transformers.T5Tokenizer.from_pretrained(base_model_id, do_lower_case=False)
+        self.model = transformers.T5EncoderModel.from_pretrained(base_model_id)
+
         # If a path to a fine-tuned weights file is provided, load and apply those weights.
         if finetuned_weights_path and os.path.exists(finetuned_weights_path):
-            print(f"Applying fine-tuned weights from: {finetuned_weights_path}")
+            print(f"Applying local fine-tuned weights from: {finetuned_weights_path}")
             state_dict = torch.load(finetuned_weights_path, map_location=self.device)
             self.model.load_state_dict(state_dict, strict=False)
             print("Successfully applied fine-tuned weights.")
         else:
             print("Warning: Fine-tuned weights not found or not provided. Using base ProtT5 weights.")
-
+
         self.model.to(self.device)
         self.model.eval()
 
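With the new constructor, the base checkpoint is pulled from the Hub and the local fine-tuned weights are layered on top. A minimal usage sketch, assuming it runs inside app.py where FeatureProtT5Model is defined, and following the standard ProtT5 embedding recipe (rare residues mapped to X, space-separated tokens, mean-pooled encoder states); the peptide string is just an example:

import re
import torch

extractor = FeatureProtT5Model(
    base_model_id="Rostlab/prot_t5_xl_uniref50",
    finetuned_weights_path="prott5/model/finetuned_prott5.bin",
)

seq = "MKWVTFISLLFLFSSAYS"                                    # example peptide
spaced = " ".join(re.sub(r"[UZOB]", "X", seq))                # ProtT5 expects space-separated residues
enc = extractor.tokenizer(spaced, return_tensors="pt").to(extractor.device)
with torch.no_grad():
    hidden = extractor.model(**enc).last_hidden_state         # (1, tokens, 1024)
embedding = hidden.mean(dim=1)                                 # (1, 1024) pooled embedding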
@@ -82,10 +81,10 @@ class AntioxidantPredictor(nn.Module):
         self.handcrafted_dim = input_dim - self.prott5_dim
         self.seq_len = 16
         self.prott5_feature_dim = 64 # 16 * 64 = 1024
-
+
         encoder_layer = nn.TransformerEncoderLayer(d_model=self.prott5_feature_dim, nhead=transformer_heads, dropout=transformer_dropout, batch_first=True)
         self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=transformer_layers)
-
+
         fused_dim = self.prott5_feature_dim + self.handcrafted_dim
         self.fusion_fc = nn.Sequential(nn.Linear(fused_dim, 1024), nn.ReLU(), nn.Dropout(0.3), nn.Linear(1024, 512), nn.ReLU(), nn.Dropout(0.3))
         self.classifier = nn.Sequential(nn.Linear(512, 256), nn.ReLU(), nn.Dropout(0.3), nn.Linear(256, 1))
@@ -96,17 +95,17 @@ class AntioxidantPredictor(nn.Module):
         # The input 'x' is a flat 1914-dim vector from extract_features()
         prot_t5_features = x[:, :self.prott5_dim]
         handcrafted_features = x[:, self.prott5_dim:]
-
+
         # Reshape the first 1024 features back into a sequence representation
         prot_t5_seq = prot_t5_features.view(batch_size, self.seq_len, self.prott5_feature_dim)
-
+
         encoded_seq = self.transformer_encoder(prot_t5_seq)
         refined_prott5 = encoded_seq.mean(dim=1)
-
+
         fused_features = torch.cat([refined_prott5, handcrafted_features], dim=1)
         fused_output = self.fusion_fc(fused_features)
         logits = self.classifier(fused_output)
-
+
         return logits / self.temperature
 
     def get_temperature(self):
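The forward pass splits the flat 1914-dim input into a 1024-dim ProtT5 block, reshaped to a 16 x 64 sequence for the Transformer encoder, and an 890-dim handcrafted block that is concatenated back in after pooling. A standalone shape check with dummy tensors (nhead and num_layers below are placeholder values, not the app's configuration):

import torch
import torch.nn as nn

batch_size, prott5_dim = 4, 1024
handcrafted_dim = 1914 - prott5_dim                           # 890 handcrafted features
seq_len, feat_dim = 16, 64                                    # 16 * 64 = 1024

x = torch.randn(batch_size, prott5_dim + handcrafted_dim)     # flat 1914-dim input
prot_t5_seq = x[:, :prott5_dim].view(batch_size, seq_len, feat_dim)   # (4, 16, 64)

encoder_layer = nn.TransformerEncoderLayer(d_model=feat_dim, nhead=4, batch_first=True)
encoder = nn.TransformerEncoder(encoder_layer, num_layers=2)
refined = encoder(prot_t5_seq).mean(dim=1)                    # (4, 64) pooled ProtT5 block

fused = torch.cat([refined, x[:, prott5_dim:]], dim=1)        # (4, 64 + 890) = (4, 954)
print(prot_t5_seq.shape, refined.shape, fused.shape)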
@@ -123,13 +122,13 @@ class ProtT5Generator(nn.Module):
         self.vocab_size = vocab_size
         self.eos_token_id = token2id["<EOS>"]
         self.pad_token_id = token2id["<PAD>"]
-
+
     def forward(self, input_ids):
         embeddings = self.embed_tokens(input_ids)
         encoder_output = self.encoder(embeddings)
         logits = self.lm_head(encoder_output)
         return logits
-
+
     def sample(self, batch_size, max_length=20, device="cpu", temperature=2.5, min_decoded_length=3):
         start_token = torch.randint(2, self.vocab_size, (batch_size, 1), device=device)
         generated = start_token
@@ -138,7 +137,7 @@ class ProtT5Generator(nn.Module):
             next_logits = logits[:, -1, :] / temperature
             if generated.size(1) < min_decoded_length:
                 next_logits[:, self.eos_token_id] = -float("inf")
-
+
             probs = torch.softmax(next_logits, dim=-1)
             next_token = torch.multinomial(probs, num_samples=1)
             generated = torch.cat((generated, next_token), dim=1)
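The sampling loop scales the last-step logits by the temperature, blocks <EOS> until a minimum length is reached, and then draws the next token from the softmax distribution. The same step in isolation with toy values (the vocabulary size and <EOS> id are placeholders, not the app's token2id mapping):

import torch

vocab_size, eos_token_id = 25, 1                          # placeholder vocabulary
temperature, min_decoded_length = 2.5, 3
generated = torch.randint(2, vocab_size, (2, 1))          # batch of 2, one start token each

logits = torch.randn(2, generated.size(1), vocab_size)    # stand-in for the model's output
next_logits = logits[:, -1, :] / temperature              # temperature-scaled last step
if generated.size(1) < min_decoded_length:
    next_logits[:, eos_token_id] = -float("inf")          # forbid an early <EOS>

probs = torch.softmax(next_logits, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)      # one sampled token per sequence
generated = torch.cat((generated, next_token), dim=1)     # append and repeat until <EOS>/max_length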
@@ -208,7 +207,9 @@ try:
     PREDICTOR_CHECKPOINT_PATH = "checkpoints/final_rl_model_logitp0.1_calibrated_FINETUNED_PROTT5.pth"
     SCALER_PATH = "checkpoints/scaler_FINETUNED_PROTT5.pkl"
     GENERATOR_CHECKPOINT_PATH = "generator_checkpoints_v3.6/final_generator_model.pth"
-
+
+    # Define the base model ID from the Hub and the path to your local fine-tuned weights.
+    PROTT5_BASE_MODEL_ID = "Rostlab/prot_t5_xl_uniref50"
     FINETUNED_PROTT5_FOR_FEATURES_PATH = "prott5/model/finetuned_prott5.bin"
 
     # --- Load Predictor Model ---
@@ -223,8 +224,9 @@ try:
     print(f"Loading Scaler from: {SCALER_PATH}")
     SCALER = joblib.load(SCALER_PATH)
     print("Loading ProtT5 Feature Extractor...")
+    # Pass the Hub ID to the updated class to load the base model.
     PROTT5_EXTRACTOR = FeatureProtT5Model(
-
+        base_model_id=PROTT5_BASE_MODEL_ID,
         finetuned_weights_path=FINETUNED_PROTT5_FOR_FEATURES_PATH
     )
     print("✅ Scaler and Feature Extractor loaded.")
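Because the fine-tuned weights are applied with strict=False, any keys that do not match the base model are silently skipped. When a Space mixes a Hub checkpoint with local fine-tuned weights like this, it can be worth logging how much actually matched; a small illustrative check, standalone and outside the class, using the same model ID and path as the diff:

import torch
import transformers

model = transformers.T5EncoderModel.from_pretrained("Rostlab/prot_t5_xl_uniref50")
state_dict = torch.load("prott5/model/finetuned_prott5.bin", map_location="cpu")

result = model.load_state_dict(state_dict, strict=False)   # returns missing/unexpected keys
print(f"missing keys: {len(result.missing_keys)}, unexpected keys: {len(result.unexpected_keys)}")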