chshan committed · verified
Commit 02b6e86 · Parent(s): ed68bd1

Update feature_extract.py

Files changed (1): feature_extract.py (+2 -32)
feature_extract.py CHANGED
@@ -6,19 +6,11 @@ import random
 import pandas as pd
 from Bio.SeqUtils.ProtParam import ProteinAnalysis
 from sklearn.model_selection import train_test_split
-# from sklearn.preprocessing import StandardScaler  # StandardScaler is no longer used
 from sklearn.preprocessing import RobustScaler  # import RobustScaler
 import torch
 from transformers import T5EncoderModel, T5Tokenizer

-# ProtT5Model, load_fasta, load_fasta_with_labels,
-# compute_amino_acid_composition, compute_reducing_aa_ratio,
-# compute_physicochemical_properties, compute_electronic_features,
-# compute_dimer_frequency, positional_encoding, perturb_sequence,
-# generate_adversarial_samples and extract_features are identical to the versions provided earlier.
-# Their code is omitted here for brevity; make sure they are complete in your file.
-# You can copy these functions from an earlier log or from your local file.
-# Below is the modified version of prepare_features, plus placeholders for the other functions.
+

 class ProtT5Model:
     """
@@ -76,10 +68,6 @@ class ProtT5Model:
             return np.zeros((1, 1024), dtype=np.float32)
         return emb

-# --- (All other feature-extraction helper functions from your previous version belong here) ---
-# load_fasta, load_fasta_with_labels, compute_amino_acid_composition, ... extract_features
-# For completeness, copy these functions here from your local feature_extract.py file.
-# Below is a simplified placeholder for these functions; replace them with the real implementations.

 def load_fasta(fasta_file):
     # (your load_fasta implementation)
@@ -99,7 +87,6 @@ def load_fasta(fasta_file):
     return sequences

 def load_fasta_with_labels(fasta_file):
-    # (your load_fasta_with_labels implementation)
     sequences, labels = [], []
     try:
         with open(fasta_file, 'r') as f:
@@ -123,7 +110,6 @@ def load_fasta_with_labels(fasta_file):

 def compute_amino_acid_composition(seq):
     if not seq: return {aa: 0.0 for aa in "ACDEFGHIKLMNPQRSTVWY"}
-    # (your compute_amino_acid_composition implementation)
     amino_acids = "ACDEFGHIKLMNPQRSTVWY"
     seq_len = len(seq)
     return {aa: seq.upper().count(aa) / seq_len for aa in amino_acids}
@@ -131,7 +117,6 @@ def compute_amino_acid_composition(seq):

 def compute_reducing_aa_ratio(seq):
     if not seq: return 0.0
-    # (your compute_reducing_aa_ratio implementation)
     reducing = ['C', 'M', 'W']
     return sum(seq.upper().count(aa) for aa in reducing) / len(seq) if len(seq) > 0 else 0.0

@@ -146,7 +131,6 @@ def compute_physicochemical_properties(seq):

 def compute_electronic_features(seq):
     if not seq: return 0.0, 0.0
-    # (your compute_electronic_features implementation)
     electronegativity = {'A':1.8,'C':2.5,'D':3.0,'E':3.2,'F':2.8,'G':1.6,'H':2.4,'I':4.5,'K':3.0,'L':4.2,'M':4.5,'N':2.0,'P':3.5,'Q':3.5,'R':2.5,'S':1.8,'T':2.5,'V':4.0,'W':5.0,'Y':4.0}
     values = [electronegativity.get(aa.upper(), 2.5) for aa in seq]
     avg_val = sum(values) / len(values) if values else 2.5
@@ -155,7 +139,6 @@ def compute_electronic_features(seq):

 def compute_dimer_frequency(seq):
     if len(seq) < 2: return np.zeros(400)  # 20*20
-    # (your compute_dimer_frequency implementation)
     amino_acids = "ACDEFGHIKLMNPQRSTVWY"
     dimer_counts = {aa1+aa2: 0 for aa1 in amino_acids for aa2 in amino_acids}
     for i in range(len(seq) - 1):
@@ -166,12 +149,7 @@ def compute_dimer_frequency(seq):
     return np.array([dimer_counts[d] for d in sorted(dimer_counts.keys())])


-def positional_encoding(seq_len_actual, L_fixed=29, d_model=16):  # Pass actual sequence length or use L_fixed
-    # (your positional_encoding implementation)
-    # This PE is fixed length, not dependent on the actual seq len if L_fixed is used.
-    # For random short sequences, this fixed PE might be an issue.
-    # A more dynamic PE, or no PE for very short sequences, might be better.
-    # However, to match the current model input, we keep it.
+def positional_encoding(seq_len_actual, L_fixed=29, d_model=16):
     pos_enc = np.zeros((L_fixed, d_model))
     for pos in range(L_fixed):
         for i in range(d_model):
@@ -194,9 +172,6 @@ def perturb_sequence(seq, perturb_rate=0.1, critical=['C', 'M', 'W']):
 def extract_features(seq, prott5_model_instance, L_fixed=29, d_model_pe=16):  # Renamed d_model to d_model_pe
     if not seq or not isinstance(seq, str) or len(seq) == 0:
         print(f"Warning: extract_features received an empty or invalid sequence. Returning zero features.")
-        # Return a zero vector matching the expected feature dimensionality:
-        # 1024 (protT5) + 20 (aac) + 1 (red_ratio) + 3 (phys) + 2 (elec) + 400 (dimer) + L_fixed*d_model_pe (pos_enc)
-        # Example: 1024 + 20 + 1 + 3 + 2 + 400 + 29*16 = 1024 + 20 + 1 + 3 + 2 + 400 + 464 = 1914
         return np.zeros(1024 + 20 + 1 + 3 + 2 + 400 + (L_fixed * d_model_pe))

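# Aside: a minimal sketch (illustrative, not taken from feature_extract.py) of the
# zero-feature fallback dimensionality referenced by the removed comment above.
# The concatenated feature vector is 1024 (ProtT5) + 20 (AAC) + 1 (reducing-AA ratio)
# + 3 (physicochemical) + 2 (electronic) + 400 (dimer) + L_fixed * d_model_pe (positional encoding).
def expected_feature_dim(L_fixed=29, d_model_pe=16):
    # 1024 + 20 + 1 + 3 + 2 + 400 = 1450; 29 * 16 = 464; total = 1914
    return 1024 + 20 + 1 + 3 + 2 + 400 + (L_fixed * d_model_pe)

assert expected_feature_dim() == 1914  # must match the length of the np.zeros fallback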
@@ -280,7 +255,6 @@ def prepare_features(neg_fasta, pos_fasta, prott5_model_path, additional_params=


     # --- **Key change: use RobustScaler** ---
-    # scaler = StandardScaler()  # the original StandardScaler
     scaler = RobustScaler()
     print("Using RobustScaler for feature normalization.")

@@ -303,10 +277,6 @@ if __name__ == "__main__":
     if not os.path.exists(pos_fasta_test):
         with open(pos_fasta_test, "w") as f: f.write(">pos1\nAOPPEPTIDE\n>pos2\nTRYTRYTRY\n")

-    # To let ProtT5Model load, a minimal transformers model directory structure has to be mocked.
-    # It usually needs at least config.json, pytorch_model.bin (or tf_model.h5), tokenizer_config.json, spiece.model.
-    # Here we only create the directory, so actual loading may fail unless the transformers library can download the model by name,
-    # or you provide a path to a real local ProtT5 model.
     if not os.listdir(prott5_path_test):  # if the directory is empty
         print(f"Warning: {prott5_path_test} is empty. ProtT5Model may try to download the model from the HuggingFace Hub.")
         print(f"Please make sure you have downloaded Rostlab/ProstT5-XL-UniRef50 or a similar model to this path, or use its HuggingFace name.")
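The substantive change in this commit is the scaler swap inside prepare_features: StandardScaler is dropped in favor of RobustScaler. A rough sketch of how that scaling step is typically wired up with scikit-learn follows; the matrices X_train and X_test are placeholders, not names from the original file.

import numpy as np
from sklearn.preprocessing import RobustScaler

# Toy feature matrices standing in for the concatenated peptide features (1914-dim).
X_train = np.random.rand(100, 1914)
X_test = np.random.rand(20, 1914)

# RobustScaler centers each feature on its median and scales by the interquartile
# range, so a handful of extreme values distorts the scaling far less than
# StandardScaler's mean/standard-deviation normalization would.
scaler = RobustScaler()
X_train_scaled = scaler.fit_transform(X_train)  # fit the scaler on training data only
X_test_scaled = scaler.transform(X_test)        # reuse the same statistics on the test split

Fitting on the training split and reusing the same statistics on held-out data avoids leaking test-set information into the scaler.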
 
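The commit also trims the comments around positional_encoding and re-adds the bare signature positional_encoding(seq_len_actual, L_fixed=29, d_model=16). Only the first lines of its body are visible in this diff; a minimal sketch of a fixed-length sinusoidal encoding with that shape, assuming the usual sin/cos formulation, might look like the following (illustrative only, not necessarily the exact implementation in feature_extract.py).

import numpy as np

def positional_encoding(seq_len_actual, L_fixed=29, d_model=16):
    # Fixed-size sinusoidal positional encoding (L_fixed x d_model); it does not
    # depend on seq_len_actual, which matches the fixed model input discussed
    # in the removed comments.
    pos_enc = np.zeros((L_fixed, d_model))
    for pos in range(L_fixed):
        for i in range(0, d_model, 2):
            angle = pos / (10000 ** (i / d_model))
            pos_enc[pos, i] = np.sin(angle)
            if i + 1 < d_model:
                pos_enc[pos, i + 1] = np.cos(angle)
    return pos_enc.flatten()  # 29 * 16 = 464 values, the pos_enc block of the feature vector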