#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""Sequence feature extraction for binary peptide classification.

Each sequence is turned into a fixed-length vector made of:

* the mean-pooled ProtT5 residue embedding (PROTT5_DIM = 1024 values),
* amino-acid composition (20),
* the fraction of C/M/W residues (1),
* GRAVY, isoelectric point and molecular weight from Biopython (3),
* two heuristic "electronic" descriptors (2),
* dipeptide frequencies (400),
* a sinusoidal positional-encoding block (L_fixed * d_model_pe, 464 by default).

``prepare_features`` loads negative/positive FASTA files, splits them into
train/validation sets, extracts these features and scales them with ``RobustScaler``.
"""
import math
import os
import random

import numpy as np
import pandas as pd  # noqa: F401 - not used in this module
import torch
from Bio.SeqUtils.ProtParam import ProteinAnalysis
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import RobustScaler  # robust to outliers in the hand-crafted features
from transformers import T5EncoderModel, T5Tokenizer

# The 20 standard amino acids in alphabetical order. This order fixes the layout of the
# composition block and of the 400-wide dipeptide block (index = 20 * first + second).
AMINO_ACIDS = "ACDEFGHIKLMNPQRSTVWY"
_AA_INDEX = {aa: i for i, aa in enumerate(AMINO_ACIDS)}

# Hidden size assumed for the ProtT5 encoder (ProtT5-XL uses 1024); used for zero fallbacks.
PROTT5_DIM = 1024

# ProtT5 was trained with the rare residues U, Z, O and B mapped to X.
_PROTT5_RARE_TO_X = str.maketrans("UZOB", "XXXX")

# Residues that compute_physicochemical_properties accepts, and how the non-standard codes
# are rewritten to standard residues before Biopython sees them (ProteinAnalysis may fail
# on non-standard codes): X->A, U->C, B->N, Z->Q, O (pyrrolysine)->K, J (I/L)->L.
_PHYSCHEM_ALLOWED = frozenset("ACDEFGHIKLMNPQRSTVWYXUBZOJ")
_PHYSCHEM_SUBSTITUTIONS = str.maketrans("XUBZOJ", "ACNQKL")

# Residues counted by compute_reducing_aa_ratio and kept fixed by perturb_sequence.
_REDUCING_RESIDUES = ("C", "M", "W")

# Per-residue values used by compute_electronic_features (unknown residues default to 2.5).
_ELECTRONEGATIVITY = {
    "A": 1.8, "C": 2.5, "D": 3.0, "E": 3.2, "F": 2.8, "G": 1.6, "H": 2.4,
    "I": 4.5, "K": 3.0, "L": 4.2, "M": 4.5, "N": 2.0, "P": 3.5, "Q": 3.5,
    "R": 2.5, "S": 1.8, "T": 2.5, "V": 4.0, "W": 5.0, "Y": 4.0,
}


class ProtT5Model:
    """Load a ProtT5 encoder and produce per-residue embeddings.

    The tokenizer and model are first loaded from local files / the Hugging Face cache
    only. If that fails, ``model_path`` is treated as a Hub repository ID and downloaded
    (if transformers is allowed to). When ``finetuned_model_file`` points to an existing
    file, its state dict is loaded on top of the pretrained weights with ``strict=False``.
    """

    def __init__(self, model_path, finetuned_model_file=None):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        try:
            self.tokenizer = T5Tokenizer.from_pretrained(
                model_path, do_lower_case=False, local_files_only=True
            )
            self.model = T5EncoderModel.from_pretrained(model_path, local_files_only=True)
        except OSError:
            # Raised when the path/ID cannot be resolved from local files.
            print(f"警告: 无法从本地路径 {model_path} 加载ProtT5模型/分词器。尝试从HuggingFace Hub下载(如果transformers配置允许)。")
            # Hub IDs have the form "<org>/<name>", so the identifier is passed through
            # unchanged; dropping the organisation prefix produces an ID that does not exist
            # (e.g. "Rostlab/prot_t5_xl_uniref50" -> "prot_t5_xl_uniref50").
            self.tokenizer = T5Tokenizer.from_pretrained(model_path, do_lower_case=False)
            self.model = T5EncoderModel.from_pretrained(model_path)

        if finetuned_model_file is not None:
            self._load_finetuned_weights(finetuned_model_file)

        self.model.to(self.device)
        self.model.eval()

    def _load_finetuned_weights(self, finetuned_model_file):
        """Best-effort load of a fine-tuned state dict; problems are reported, not raised."""
        if not os.path.exists(finetuned_model_file):
            print(f"警告: 微调权重文件 {finetuned_model_file} 不存在,跳过加载。")
            return
        try:
            # NOTE(security): torch.load unpickles the file -- only load trusted checkpoints.
            # Tensors are loaded onto the CPU because the model has not been moved to
            # self.device yet; load_state_dict copies them into the existing parameters.
            state_dict = torch.load(finetuned_model_file, map_location="cpu")
            missing_keys, unexpected_keys = self.model.load_state_dict(state_dict, strict=False)
            print(f"加载微调权重 {finetuned_model_file}:缺失键 {missing_keys}, 意外键 {unexpected_keys}")
        except Exception as e:  # deliberate best effort: keep the pretrained weights
            print(f"加载微调权重 {finetuned_model_file} 失败: {e}")

    def encode(self, sequence):
        """Return per-residue embeddings for *sequence*.

        Args:
            sequence: amino-acid string (one-letter codes, any case).

        Returns:
            float32 array of shape ``(n_residues, hidden_dim)``. The end-of-sequence token
            appended by the tokenizer is excluded, and long sequences are truncated to the
            tokenizer's ``max_length``. Invalid input, or any tokenizer/model failure,
            yields a ``(1, PROTT5_DIM)`` zero array instead of raising.
        """
        zeros = np.zeros((1, PROTT5_DIM), dtype=np.float32)
        if not isinstance(sequence, str) or not sequence:
            print(f"警告: ProtT5Model.encode 接收到无效序列: {sequence}")
            return zeros

        # ProtT5 expects upper-case residues separated by spaces, with U/Z/O/B mapped to X;
        # lower-case letters are not in its vocabulary.
        seq_spaced = " ".join(sequence.upper().translate(_PROTT5_RARE_TO_X))
        try:
            encoded_input = self.tokenizer(
                seq_spaced, return_tensors="pt", padding=True, truncation=True, max_length=1022
            )
        except Exception as e:
            print(f"分词失败序列 '{sequence[:30]}...': {e}")
            return zeros

        encoded_input = {k: v.to(self.device) for k, v in encoded_input.items()}
        with torch.no_grad():
            try:
                # (batch_size=1, n_tokens, hidden_dim)
                hidden = self.model(**encoded_input).last_hidden_state
            except Exception as e:
                print(f"ProtT5模型推理失败序列 '{sequence[:30]}...': {e}")
                return zeros

        # The tokenizer appends </s>; drop it so only residue positions remain.
        mask = encoded_input.get("attention_mask")
        n_tokens = int(mask[0].sum().item()) if mask is not None else hidden.shape[1]
        n_residues = n_tokens - 1
        if n_residues <= 0:
            return zeros
        # .float() guards against half-precision weights, which numpy cannot always convert.
        return hidden[0, :n_residues].float().cpu().numpy()


def load_fasta(fasta_file):
    """Read *fasta_file* and return its sequences in file order (headers are discarded).

    Multi-line records are joined; headers with no sequence lines are skipped.
    A missing file is reported and yields an empty list.
    """
    sequences = []
    current_lines = []
    try:
        with open(fasta_file, "r", encoding="utf-8") as handle:
            for raw in handle:
                line = raw.strip()
                if not line:
                    continue
                if line.startswith(">"):
                    if current_lines:
                        sequences.append("".join(current_lines))
                        current_lines = []
                else:
                    current_lines.append(line)
    except FileNotFoundError:
        print(f"文件未找到: {fasta_file}")
        return []
    if current_lines:
        sequences.append("".join(current_lines))
    return sequences


def load_fasta_with_labels(fasta_file):
    """Read a FASTA file whose headers carry the class label right after ">".

    A header such as ">1_xyz" gives label 1 and ">0..." gives label 0; any other header
    (and sequence lines before the first header) gets label 0. Headers with no sequence
    lines are skipped. Returns ``(sequences, labels)``; a missing file is reported and
    yields ``([], [])``.
    """
    sequences, labels = [], []
    try:
        with open(fasta_file, "r", encoding="utf-8") as handle:
            current_lines, current_label = [], None
            for raw in handle:
                line = raw.strip()
                if not line:
                    continue
                if line.startswith(">"):
                    if current_lines:
                        sequences.append("".join(current_lines))
                        labels.append(0 if current_label is None else current_label)
                        current_lines = []
                    current_label = int(line[1]) if len(line) > 1 and line[1] in ("0", "1") else 0
                else:
                    current_lines.append(line)
            if current_lines:
                sequences.append("".join(current_lines))
                labels.append(0 if current_label is None else current_label)
    except FileNotFoundError:
        print(f"文件未找到: {fasta_file}")
        return [], []
    return sequences, labels


def compute_amino_acid_composition(seq):
    """Return ``{residue: fraction}`` for the 20 standard residues (case-insensitive).

    The denominator is the full sequence length, so non-standard residues lower every
    fraction. An empty sequence gives all zeros.
    """
    if not seq:
        return {aa: 0.0 for aa in AMINO_ACIDS}
    upper = seq.upper()
    length = len(seq)
    return {aa: upper.count(aa) / length for aa in AMINO_ACIDS}


def compute_reducing_aa_ratio(seq):
    """Return the fraction of C, M and W residues in *seq* (0.0 for an empty sequence)."""
    if not seq:
        return 0.0
    upper = seq.upper()
    return sum(upper.count(aa) for aa in _REDUCING_RESIDUES) / len(seq)


def compute_physicochemical_properties(seq):
    """Return ``(gravy, isoelectric_point, molecular_weight)`` computed with Biopython.

    Non-standard codes X/U/B/Z/O/J are first rewritten to A/C/N/Q/K/L. An empty sequence or
    one containing any other character returns ``(0.0, 0.0, 0.0)``. If Biopython raises,
    the rough defaults ``(0.0, 7.0, 110.0 * len(seq))`` are returned.
    """
    if not seq or not all(c.upper() in _PHYSCHEM_ALLOWED for c in seq):
        return 0.0, 0.0, 0.0
    try:
        analysis = ProteinAnalysis(str(seq).upper().translate(_PHYSCHEM_SUBSTITUTIONS))
        return analysis.gravy(), analysis.isoelectric_point(), analysis.molecular_weight()
    except Exception:  # any Biopython failure falls back to rough defaults
        return 0.0, 7.0, 110.0 * len(seq)


def compute_electronic_features(seq):
    """Return two heuristic "electronic" descriptors ``(HOMO, LUMO)`` for *seq*.

    Both are the mean per-residue value from ``_ELECTRONEGATIVITY`` (unknown residues
    count as 2.5), shifted by +0.1 and -0.1 respectively. They are proxies rather than
    quantum-chemical orbital energies, and they differ by a constant, so they carry the
    same information. An empty sequence returns ``(0.0, 0.0)``.
    """
    if not seq:
        return 0.0, 0.0
    values = [_ELECTRONEGATIVITY.get(aa.upper(), 2.5) for aa in seq]
    avg_val = sum(values) / len(values)
    return avg_val + 0.1, avg_val - 0.1


def compute_dimer_frequency(seq):
    """Return the 400 dipeptide frequencies of *seq* (case-insensitive).

    Entry ``20 * i + j`` is the count of the pair (AMINO_ACIDS[i], AMINO_ACIDS[j]) divided
    by ``len(seq) - 1``; pairs involving a non-standard residue are not counted.
    Sequences shorter than 2 give all zeros.
    """
    n_aa = len(AMINO_ACIDS)
    if len(seq) < 2:
        return np.zeros(n_aa * n_aa)
    counts = np.zeros(n_aa * n_aa)
    for first, second in zip(seq, seq[1:]):
        i = _AA_INDEX.get(first.upper())
        j = _AA_INDEX.get(second.upper())
        if i is not None and j is not None:
            counts[i * n_aa + j] += 1
    return counts / (len(seq) - 1)


def positional_encoding(seq_len_actual, L_fixed=29, d_model=16):
    """Return a flattened ``(L_fixed, d_model)`` sinusoidal positional-encoding table.

    Even dimensions use sin and odd dimensions use cos of ``pos / 10000**(2*(i//2)/d_model)``.

    NOTE(review): ``seq_len_actual`` is currently ignored, so this block is identical for
    every sequence; after RobustScaler these columns become all zeros. Masking positions
    ``>= seq_len_actual`` would make it informative -- confirm intended behaviour.
    """
    positions = np.arange(L_fixed)[:, None]
    dims = np.arange(d_model)
    angle = positions / (10000 ** (2 * (dims // 2) / d_model))
    table = np.where(dims % 2 == 0, np.sin(angle), np.cos(angle))
    return table.flatten()


def perturb_sequence(seq, perturb_rate=0.1, critical=_REDUCING_RESIDUES):
    """Return a copy of *seq* with random point substitutions for data augmentation.

    Each residue whose upper-case form is not in *critical* is, with probability
    *perturb_rate*, replaced by a different standard residue chosen uniformly (using the
    global ``random`` module). Critical residues are never changed.
    """
    if not seq:
        return ""
    mutated = list(seq)
    for idx, aa in enumerate(mutated):
        residue = aa.upper()
        if residue not in critical and random.random() < perturb_rate:
            mutated[idx] = random.choice([x for x in AMINO_ACIDS if x != residue])
    return "".join(mutated)


def extract_features(seq, prott5_model_instance, L_fixed=29, d_model_pe=16):
    """Build the full feature vector for one sequence.

    Args:
        seq: amino-acid string.
        prott5_model_instance: a ``ProtT5Model`` (anything with a compatible ``encode``).
        L_fixed, d_model_pe: shape of the positional-encoding block.

    Returns:
        1-D array of length ``PROTT5_DIM + 20 + 1 + 3 + 2 + 400 + L_fixed * d_model_pe``
        (1914 with the defaults). Empty or non-string input returns all zeros.
    """
    n_aa = len(AMINO_ACIDS)
    if not isinstance(seq, str) or not seq:
        print(f"警告: extract_features 接收到空或无效序列。返回零特征。")
        return np.zeros(PROTT5_DIM + n_aa + 1 + 3 + 2 + n_aa * n_aa + (L_fixed * d_model_pe))

    # Mean-pool the per-residue embedding; fall back to zeros for any unexpected shape so
    # every row has the same width.
    embedding = prott5_model_instance.encode(seq)
    prot_embed = np.mean(embedding, axis=0) if embedding.shape[0] > 0 else None
    if prot_embed is None or prot_embed.ndim != 1 or prot_embed.shape[0] != PROTT5_DIM:
        prot_embed = np.zeros(PROTT5_DIM)

    aa_comp = compute_amino_acid_composition(seq)
    aa_comp_vector = np.array([aa_comp[aa] for aa in AMINO_ACIDS])
    red_ratio = np.array([compute_reducing_aa_ratio(seq)])
    phys_props = np.array(compute_physicochemical_properties(seq))
    electronic = np.array(compute_electronic_features(seq))
    dimer_vector = compute_dimer_frequency(seq)
    pos_enc = positional_encoding(len(seq), L_fixed, d_model_pe)

    return np.concatenate(
        [prot_embed, aa_comp_vector, red_ratio, phys_props, electronic, dimer_vector, pos_enc]
    )


##############################################
# Main entry point: prepare_features
##############################################
def _can_stratify(labels, test_size):
    """Return True when ``train_test_split`` can stratify *labels* with *test_size*.

    Stratification needs at least two classes, at least two samples per class, and both
    the test and train splits must hold at least one sample per class.
    """
    classes, counts = np.unique(labels, return_counts=True)
    if len(classes) < 2 or counts.min() < 2:
        return False
    n_samples = len(labels)
    # Mirrors scikit-learn: a float test_size is a fraction (rounded up), an int is a count.
    if np.asarray(test_size).dtype.kind == "f":
        n_test = math.ceil(test_size * n_samples)
    else:
        n_test = int(test_size)
    n_train = n_samples - n_test
    return n_test >= len(classes) and n_train >= len(classes)


def prepare_features(neg_fasta, pos_fasta, prott5_model_path, additional_params=None):
    """Load, split, featurize and scale negative/positive peptide sequences.

    Args:
        neg_fasta: FASTA file of negative (label 0) sequences.
        pos_fasta: FASTA file of positive (label 1) sequences.
        prott5_model_path: local directory or Hugging Face Hub ID of the ProtT5 encoder.
        additional_params: optional dict with keys
            ``augment`` (bool, default False) -- append one perturbed copy of every
            training sequence; ``perturb_rate`` (default 0.1); ``finetuned_model_file``
            (default None); ``test_size`` (default 0.1); ``random_state`` (default 42).

    Returns:
        ``(X_train_scaled, X_val_scaled, y_train, y_val, scaler)``, where ``scaler`` is
        the ``RobustScaler`` fitted on the training features.

    Raises:
        ValueError: if no sequences could be loaded, or a split is empty after feature
            extraction.
    """
    params = additional_params or {}

    neg_seqs = load_fasta(neg_fasta)
    pos_seqs = load_fasta(pos_fasta)
    if not neg_seqs and not pos_seqs:
        raise ValueError("未能从FASTA文件加载任何序列。请检查文件路径和内容。")

    sequences = neg_seqs + pos_seqs
    labels = [0] * len(neg_seqs) + [1] * len(pos_seqs)

    # train_test_split shuffles on its own, seeded by random_state; an additional unseeded
    # shuffle beforehand would make the split irreproducible.
    test_size = params.get("test_size", 0.1)
    random_state = params.get("random_state", 42)
    train_seqs, val_seqs, train_labels, val_labels = train_test_split(
        sequences,
        labels,
        test_size=test_size,
        random_state=random_state,
        stratify=labels if _can_stratify(labels, test_size) else None,
    )
    train_seqs, val_seqs = list(train_seqs), list(val_seqs)
    train_labels, val_labels = list(train_labels), list(val_labels)
    print("训练集原始样本数:", len(train_seqs))
    print("验证集原始样本数:", len(val_seqs))

    if params.get("augment", False):
        # Only the training split is augmented; each perturbed copy keeps its source label.
        perturb_rate = params.get("perturb_rate", 0.1)
        augmented_seqs = [perturb_sequence(s, perturb_rate=perturb_rate) for s in train_seqs]
        train_seqs = train_seqs + augmented_seqs
        train_labels = train_labels + train_labels
        print("数据增强后训练集样本数:", len(train_seqs))

    prott5_model_instance = ProtT5Model(
        prott5_model_path, finetuned_model_file=params.get("finetuned_model_file")
    )

    def process_data(seqs_list):
        """Stack the feature vectors of *seqs_list* into a 2-D array (one row per sequence)."""
        return np.array([extract_features(s, prott5_model_instance) for s in seqs_list])

    X_train = process_data(train_seqs)
    X_val = process_data(val_seqs)
    if X_train.shape[0] == 0 or X_val.shape[0] == 0:
        raise ValueError("特征提取后训练集或验证集为空。请检查序列数据和特征提取过程。")

    # Fit on the training split only so validation statistics do not leak into scaling.
    scaler = RobustScaler()
    print("使用 RobustScaler 进行特征归一化。")
    X_train_scaled = scaler.fit_transform(X_train)
    X_val_scaled = scaler.transform(X_val)

    return X_train_scaled, X_val_scaled, np.array(train_labels), np.array(val_labels), scaler


if __name__ == "__main__":
    # Smoke test: create tiny FASTA files if they are missing and run the whole pipeline.
    neg_fasta_test = "dummy_data/test_neg.fasta"
    pos_fasta_test = "dummy_data/test_pos.fasta"
    # A usable local model directory needs config.json, the weights and the tokenizer files.
    prott5_path_test = "dummy_prott5_model/"

    os.makedirs("dummy_data", exist_ok=True)
    os.makedirs(prott5_path_test, exist_ok=True)

    if not os.path.exists(neg_fasta_test):
        with open(neg_fasta_test, "w") as f:
            f.write(">neg1\nKALKALKALK\n>neg2\nPEPTPEPT\n")
    if not os.path.exists(pos_fasta_test):
        with open(pos_fasta_test, "w") as f:
            f.write(">pos1\nAOPPEPTIDE\n>pos2\nTRYTRYTRY\n")

    if not os.listdir(prott5_path_test):
        # Empty directory: the demo below passes a Hub ID instead, which may be downloaded.
        print(f"警告: {prott5_path_test} 为空。ProtT5Model可能尝试从HuggingFace Hub下载模型。")
        print(f"请确保您已下载Rostlab/ProstT5-XL-UniRef50或类似模型到该路径,或使用其HuggingFace名称。")

    additional_params_test = {
        "augment": False,
        "perturb_rate": 0.1,
        "finetuned_model_file": None,  # no fine-tuned weights in the smoke test
    }

    print("开始测试 prepare_features (使用RobustScaler)...")
    try:
        X_train_t, X_val_t, y_train_t, y_val_t, scaler_t = prepare_features(
            neg_fasta_test,
            pos_fasta_test,
            # NOTE(review): the ProtT5-XL UniRef50 encoder appears on the Hub as
            # "Rostlab/prot_t5_xl_uniref50"; confirm this ID exists before relying on it.
            "Rostlab/ProstT5-XL-UniRef50",
            additional_params_test,
        )
        print("prepare_features 测试完成。")
        print("训练集样本数:", X_train_t.shape[0])
        print("验证集样本数:", X_val_t.shape[0])
        if X_train_t.shape[0] > 0:
            print("训练集特征维度:", X_train_t.shape[1])
            print("一个缩放后的训练样本 (前5个特征):", X_train_t[0, :5])
        if scaler_t:
            print("Scaler类型:", type(scaler_t))
    except Exception as e:
        print(f"prepare_features 测试失败: {e}")
        print("这可能是由于无法加载ProtT5模型或FASTA文件处理问题。请检查路径和文件内容。")