Update feature_extract.py
feature_extract.py  CHANGED  (+2 -32)
@@ -6,19 +6,11 @@ import random
 import pandas as pd
 from Bio.SeqUtils.ProtParam import ProteinAnalysis
 from sklearn.model_selection import train_test_split
-# from sklearn.preprocessing import StandardScaler  # StandardScaler is no longer used
 from sklearn.preprocessing import RobustScaler  # import RobustScaler
 import torch
 from transformers import T5EncoderModel, T5Tokenizer
 
-
-# compute_amino_acid_composition, compute_reducing_aa_ratio,
-# compute_physicochemical_properties, compute_electronic_features,
-# compute_dimer_frequency, positional_encoding, perturb_sequence,
-# generate_adversarial_samples and extract_features are identical to the versions provided earlier.
-# Their code is omitted here for brevity; make sure they are complete in your file.
-# You can copy these functions from an earlier log or from your local file.
-# Below is the modified prepare_features function, plus placeholders for the other functions.
+
 
 class ProtT5Model:
     """
@@ -76,10 +68,6 @@ class ProtT5Model:
             return np.zeros((1, 1024), dtype=np.float32)
         return emb
 
-# --- (all other feature-extraction helper functions from your previous version belong here) ---
-# load_fasta, load_fasta_with_labels, compute_amino_acid_composition, ... extract_features
-# For completeness, copy these functions here from your local feature_extract.py file.
-# Below is a simplified placeholder for each; replace them with the actual functions.
 
 def load_fasta(fasta_file):
     # (your load_fasta implementation)
@@ -99,7 +87,6 @@ def load_fasta(fasta_file):
     return sequences
 
 def load_fasta_with_labels(fasta_file):
-    # (your load_fasta_with_labels implementation)
     sequences, labels = [], []
     try:
         with open(fasta_file, 'r') as f:
@@ -123,7 +110,6 @@ def load_fasta_with_labels(fasta_file):
 
 def compute_amino_acid_composition(seq):
     if not seq: return {aa: 0.0 for aa in "ACDEFGHIKLMNPQRSTVWY"}
-    # (your compute_amino_acid_composition implementation)
     amino_acids = "ACDEFGHIKLMNPQRSTVWY"
     seq_len = len(seq)
     return {aa: seq.upper().count(aa) / seq_len for aa in amino_acids}
@@ -131,7 +117,6 @@ def compute_amino_acid_composition(seq):
 
 def compute_reducing_aa_ratio(seq):
     if not seq: return 0.0
-    # (your compute_reducing_aa_ratio implementation)
     reducing = ['C', 'M', 'W']
     return sum(seq.upper().count(aa) for aa in reducing) / len(seq) if len(seq) > 0 else 0.0
 
@@ -146,7 +131,6 @@ def compute_physicochemical_properties(seq):
 
 def compute_electronic_features(seq):
     if not seq: return 0.0, 0.0
-    # (your compute_electronic_features implementation)
     electronegativity = {'A':1.8,'C':2.5,'D':3.0,'E':3.2,'F':2.8,'G':1.6,'H':2.4,'I':4.5,'K':3.0,'L':4.2,'M':4.5,'N':2.0,'P':3.5,'Q':3.5,'R':2.5,'S':1.8,'T':2.5,'V':4.0,'W':5.0,'Y':4.0}
     values = [electronegativity.get(aa.upper(), 2.5) for aa in seq]
     avg_val = sum(values) / len(values) if values else 2.5
@@ -155,7 +139,6 @@ def compute_electronic_features(seq):
 
 def compute_dimer_frequency(seq):
     if len(seq) < 2: return np.zeros(400)  # 20*20
-    # (your compute_dimer_frequency implementation)
     amino_acids = "ACDEFGHIKLMNPQRSTVWY"
     dimer_counts = {aa1+aa2: 0 for aa1 in amino_acids for aa2 in amino_acids}
     for i in range(len(seq) - 1):
@@ -166,12 +149,7 @@ def compute_dimer_frequency(seq):
     return np.array([dimer_counts[d] for d in sorted(dimer_counts.keys())])
 
 
-def positional_encoding(seq_len_actual, L_fixed=29, d_model=16):
-    # (your positional_encoding implementation)
-    # This PE is fixed length, not dependent on actual seq len if L_fixed is used.
-    # For random short sequences, this fixed PE might be an issue.
-    # A more dynamic PE or no PE for very short sequences might be better.
-    # However, to match current model input, we keep it.
+def positional_encoding(seq_len_actual, L_fixed=29, d_model=16):
     pos_enc = np.zeros((L_fixed, d_model))
     for pos in range(L_fixed):
         for i in range(d_model):
@@ -194,9 +172,6 @@ def perturb_sequence(seq, perturb_rate=0.1, critical=['C', 'M', 'W']):
 def extract_features(seq, prott5_model_instance, L_fixed=29, d_model_pe=16):  # Renamed d_model to d_model_pe
     if not seq or not isinstance(seq, str) or len(seq) == 0:
         print(f"Warning: extract_features received an empty or invalid sequence. Returning zero features.")
-        # Return a zero vector matching the expected feature dimension:
-        # 1024 (protT5) + 20 (aac) + 1 (red_ratio) + 3 (phys) + 2 (elec) + 400 (dimer) + L_fixed*d_model_pe (pos_enc)
-        # Example: 1024 + 20 + 1 + 3 + 2 + 400 + 29*16 = 1024 + 20 + 1 + 3 + 2 + 400 + 464 = 1914
         return np.zeros(1024 + 20 + 1 + 3 + 2 + 400 + (L_fixed * d_model_pe))
 
 
@@ -280,7 +255,6 @@ def prepare_features(neg_fasta, pos_fasta, prott5_model_path, additional_params=
 
 
     # --- **Key change: use RobustScaler** ---
-    # scaler = StandardScaler()  # the original StandardScaler
     scaler = RobustScaler()
     print("Using RobustScaler for feature normalization.")
 
@@ -303,10 +277,6 @@ if __name__ == "__main__":
     if not os.path.exists(pos_fasta_test):
         with open(pos_fasta_test, "w") as f: f.write(">pos1\nAOPPEPTIDE\n>pos2\nTRYTRYTRY\n")
 
-    # To let ProtT5Model load, a minimal transformers model directory structure would have to be mocked,
-    # usually containing at least config.json, pytorch_model.bin (or tf_model.h5), tokenizer_config.json, spiece.model.
-    # Here we only create the directory, so actual loading may fail unless the transformers library
-    # can download the model by name, or you provide a real local ProtT5 model path.
    if not os.listdir(prott5_path_test):  # if the directory is empty
        print(f"Warning: {prott5_path_test} is empty. ProtT5Model may try to download the model from the HuggingFace Hub.")
        print(f"Make sure you have downloaded Rostlab/ProstT5-XL-UniRef50 or a similar model to this path, or use its HuggingFace name.")
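The only substantive change in this commit is the switch from StandardScaler to RobustScaler inside prepare_features. A minimal sketch of that normalization step follows, assuming stacked feature matrices of the kind extract_features produces (1914 columns per the comment removed above); the array shapes and variable names here are illustrative placeholders, not code from the file.

import numpy as np
from sklearn.preprocessing import RobustScaler

# Hypothetical stacked feature matrices, one row per peptide, 1914 features per row.
X_train = np.random.rand(200, 1914)
X_test = np.random.rand(50, 1914)

# RobustScaler centers each feature on its median and scales by the interquartile range,
# so a few extreme embedding values or dimer counts distort the scaling far less than
# the mean/std statistics used by StandardScaler.
scaler = RobustScaler()
X_train_scaled = scaler.fit_transform(X_train)  # fit statistics on the training split only
X_test_scaled = scaler.transform(X_test)        # reuse the same median/IQR on held-out data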
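The warning in the last hunk concerns loading a ProtT5 checkpoint from a local directory. The sketch below shows the usual from_pretrained pattern with the T5EncoderModel and T5Tokenizer classes already imported in this file; the local path and the device handling are assumptions for illustration, not code from this repository.

import torch
from transformers import T5EncoderModel, T5Tokenizer

# Hypothetical local checkpoint directory; it should contain config.json, the model weights,
# tokenizer_config.json and spiece.model. A HuggingFace model name can be passed instead,
# in which case the files are downloaded from the Hub.
model_path = "./prott5_model"

tokenizer = T5Tokenizer.from_pretrained(model_path, do_lower_case=False)
model = T5EncoderModel.from_pretrained(model_path)
model.eval()  # inference only; feature extraction needs no gradients

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)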