Spaces:
Runtime error
Runtime error
import os | |
import json | |
import sys | |
import tempfile | |
import gradio as gr | |
from typing import Dict, List, Optional | |
from pathlib import Path | |
from Bio import SeqIO | |
from io import StringIO | |
# 添加必要的路径 | |
root_path = os.path.dirname(os.path.abspath(__file__)) | |
sys.path.append(root_path) | |
sys.path.append(os.path.join(root_path, "Models/ProTrek")) | |
# 导入所需模块 | |
from interproscan import InterproScan | |
from Bio.Blast.Applications import NcbiblastpCommandline | |
from utils.utils import extract_interproscan_metrics, get_seqnid, extract_blast_metrics, rename_interproscan_keys | |
from go_integration_pipeline import GOIntegrationPipeline | |
from utils.openai_access import call_chatgpt | |
from utils.prompts import FUNCTION_PROMPT | |
def get_prompt_template(selected_info_types=None):
    """
    Return the Jinja2 source text of the protein-analysis prompt.

    The result is FUNCTION_PROMPT, a newline, and a template body with
    conditional sections for motif data, GO terms and any further info types.
    Which sections appear is decided when the template is rendered: the
    template itself reads a ``selected_info_types`` variable from the render
    context.

    Args:
        selected_info_types: info types to include, e.g.
            ['motif', 'go', 'superfamily', 'panther']. Defaults to
            ['motif', 'go'] when None. Note that the value is only defaulted
            here; it is not used to build the returned string.
    """
    # Default selection: motif and GO information
    selected_info_types = ['motif', 'go'] if selected_info_types is None else selected_info_types

    template_body = """
input information:
{%- if 'motif' in selected_info_types and motif_pfam %}
motif:{% for motif_id, motif_info in motif_pfam.items() %}
{{motif_id}}: {{motif_info}}
{% endfor %}
{%- endif %}
{%- if 'go' in selected_info_types and go_data.status == 'success' %}
GO:{% for go_entry in go_data.go_annotations %}
▢ GO term{{loop.index}}: {{go_entry.go_id}}
• definition: {{ go_data.all_related_definitions.get(go_entry.go_id, 'not found definition') }}
{% endfor %}
{%- endif %}
{%- for info_type in selected_info_types %}
{%- if info_type not in ['motif', 'go'] and interpro_descriptions.get(info_type) %}
{{info_type}}:{% for ipr_id, ipr_info in interpro_descriptions[info_type].items() %}
▢ {{ipr_id}}: {{ipr_info.name}}
• description: {{ipr_info.abstract}}
{% endfor %}
{%- endif %}
{%- endfor %}
question: \n {{question}}
"""
    return FUNCTION_PROMPT + '\n' + template_body
class ProteinAnalysisDemo:
    """
    Single-sequence protein analysis pipeline used by the Gradio demo.

    For one input sequence it runs BLAST and InterProScan, merges GO terms
    through GOIntegrationPipeline, renders a prompt from the collected
    evidence and passes that prompt to ``call_chatgpt``.
    """

    # Single-letter codes accepted from the free-text sequence box
    _VALID_AMINO_ACIDS = frozenset('ACDEFGHIKLMNPQRSTVWY')

    def __init__(self):
        """
        Set the tool/database configuration and build the GO pipeline.
        """
        self.blast_database = "uniprot_swissprot"
        self.expect_value = 0.01
        self.interproscan_path = "interproscan/interproscan-5.75-106.0/interproscan.sh"
        self.interproscan_libraries = [
            "PFAM", "PIRSR", "PROSITE_PROFILES", "SUPERFAMILY", "PRINTS",
            "PANTHER", "CDD", "GENE3D", "NCBIFAM", "SFLM", "MOBIDB_LITE",
            "COILS", "PROSITE_PATTERNS", "FUNFAM", "SMART"
        ]
        self.go_topk = 2
        self.selected_info_types = ['motif', 'go']

        # Data file locations (relative to the working directory)
        self.pfam_descriptions_path = 'data/raw_data/all_pfam_descriptions.json'
        self.go_info_path = 'data/raw_data/go.json'
        self.interpro_data_path = 'data/raw_data/interpro_data.json'

        # GO integration pipeline
        self.go_pipeline = GOIntegrationPipeline(topk=self.go_topk)

        # The InterPro manager is only needed when info types other than
        # motif/go are selected; it stays None otherwise or when loading fails.
        self.interpro_manager = None
        other_types = [t for t in self.selected_info_types if t not in ['motif', 'go']]
        if other_types and os.path.exists(self.interpro_data_path):
            try:
                from utils.generate_protein_prompt import get_interpro_manager
                self.interpro_manager = get_interpro_manager(self.interpro_data_path, None)
            except Exception as e:
                print(f"初始化InterPro管理器失败: {e}")

    @staticmethod
    def _normalize_sequence(sequence: str) -> str:
        """Drop every whitespace character (spaces, tabs, CR/LF) and upper-case."""
        return "".join(sequence.split()).upper()

    @staticmethod
    def _read_uploaded_text(uploaded) -> str:
        """
        Return the UTF-8 text of an uploaded file.

        Accepts raw bytes, a filesystem path (``str`` or ``os.PathLike``), a
        file-like object exposing ``read()``, or an object whose ``name``
        attribute is a path, so the handler works whichever of these forms
        the upload widget delivers.

        Raises:
            TypeError: if ``uploaded`` is none of the supported forms.
        """
        if isinstance(uploaded, (bytes, bytearray)):
            return bytes(uploaded).decode('utf-8')
        if isinstance(uploaded, (str, os.PathLike)):
            with open(uploaded, 'r', encoding='utf-8') as handle:
                return handle.read()
        if hasattr(uploaded, 'read'):
            data = uploaded.read()
            return data.decode('utf-8') if isinstance(data, (bytes, bytearray)) else data
        path = getattr(uploaded, 'name', None)
        if isinstance(path, str):
            with open(path, 'r', encoding='utf-8') as handle:
                return handle.read()
        raise TypeError(f"unsupported upload object: {type(uploaded).__name__}")

    def validate_protein_sequence(self, sequence: str) -> bool:
        """
        Check that ``sequence`` is a non-empty string of standard amino acids.

        All whitespace is ignored and case is folded before the check, so
        wrapped or pasted sequences (including CRLF line endings) validate.
        """
        if not sequence:
            return False
        cleaned = self._normalize_sequence(sequence)
        return bool(cleaned) and set(cleaned) <= self._VALID_AMINO_ACIDS

    def parse_fasta_content(self, fasta_content: str) -> tuple:
        """
        Parse FASTA text that must contain exactly one record.

        Returns:
            ``(sequence, message)`` on success, ``(None, error_message)`` when
            there is no record, more than one record, an empty sequence, or a
            parsing error.
        """
        try:
            records = list(SeqIO.parse(StringIO(fasta_content), "fasta"))
            if not records:
                return None, "FASTA文件中没有找到有效的序列"
            if len(records) > 1:
                return None, "演示版本只支持单一序列,检测到多个序列"
            record = records[0]
            sequence = str(record.seq)
            if not sequence:
                # A header without residues is not an analysable sequence
                return None, "FASTA文件中没有找到有效的序列"
            return sequence, f"成功解析序列 ID: {record.id}"
        except Exception as e:
            return None, f"解析FASTA文件出错: {e}"

    def create_temp_fasta(self, sequence: str, seq_id: str = "demo_protein",
                          directory: Optional[str] = None) -> str:
        """
        Write a one-record FASTA file and return its path.

        Args:
            sequence: residues to write.
            seq_id: record identifier placed on the header line.
            directory: folder to create the file in; the system temp
                directory is used when None.

        The file is not deleted automatically (``delete=False``); the caller
        removes it or places it in a directory that is cleaned up.
        """
        with tempfile.NamedTemporaryFile(mode='w', suffix='.fasta', delete=False,
                                         dir=directory, encoding='utf-8') as temp_file:
            temp_file.write(f">{seq_id}\n{sequence}\n")
        return temp_file.name

    def run_blast_analysis(self, fasta_file: str, temp_dir: str) -> Dict:
        """
        Run blastp for ``fasta_file`` and collect the hits per sequence ID.

        Returns:
            ``{uid: {"sequence": ..., "blast_results": ...}}``, or ``{}`` when
            any step fails (the error is printed).
        """
        blast_xml = os.path.join(temp_dir, "blast_results.xml")
        try:
            # NOTE(review): the Bio.Blast.Applications wrappers are deprecated in
            # recent Biopython releases - confirm the pinned version still has them.
            blast_cmd = NcbiblastpCommandline(
                query=fasta_file,
                db=self.blast_database,
                out=blast_xml,
                outfmt=5,  # XML output
                evalue=self.expect_value
            )
            blast_cmd()

            blast_results = extract_blast_metrics(blast_xml)
            seq_dict = get_seqnid(fasta_file)
            return {
                uid: {"sequence": seq_dict[uid], "blast_results": info}
                for uid, info in blast_results.items()
            }
        except Exception as e:
            print(f"BLAST分析出错: {e}")
            return {}
        finally:
            if os.path.exists(blast_xml):
                os.remove(blast_xml)

    def run_interproscan_analysis(self, fasta_file: str, temp_dir: str) -> Dict:
        """
        Run InterProScan for ``fasta_file`` and collect the results per sequence ID.

        Returns:
            ``{seq_id: {"sequence": ..., "interproscan_results": ...}}``, or
            ``{}`` when any step fails (the error is printed).
        """
        interproscan_json = os.path.join(temp_dir, "interproscan_results.json")
        try:
            interproscan = InterproScan(self.interproscan_path)
            interproscan.run(
                fasta_file=fasta_file,
                goterms=True,
                pathways=True,
                save_dir=interproscan_json,
            )

            interproscan_results = extract_interproscan_metrics(
                interproscan_json,
                librarys=self.interproscan_libraries
            )

            # Results are looked up by sequence string, then re-keyed by ID
            seq_dict = get_seqnid(fasta_file)
            interproscan_info = {}
            for seq_id, seq in seq_dict.items():
                info = rename_interproscan_keys(interproscan_results[seq])
                interproscan_info[seq_id] = {"sequence": seq, "interproscan_results": info}
            return interproscan_info
        except Exception as e:
            print(f"InterProScan分析出错: {e}")
            return {}
        finally:
            if os.path.exists(interproscan_json):
                os.remove(interproscan_json)

    def generate_prompt(self, protein_id: str, interproscan_info: Dict,
                        protein_go_dict: Dict, question: str) -> str:
        """
        Render the LLM prompt from in-memory results, with motif and GO definitions.

        Falls back to ``_generate_fallback_prompt`` when rendering fails for
        any reason (the error is printed).
        """
        try:
            from utils.protein_go_analysis import get_go_definition
            from jinja2 import Template

            # GO terms selected for this protein
            go_ids = protein_go_dict.get(protein_id) or []
            go_annotations = []
            all_related_definitions = {}
            go_info_available = os.path.exists(self.go_info_path)
            for go_id in go_ids:
                # Keep only the part after the last ':' (whole ID when there is none)
                clean_go_id = go_id.split(":")[-1]
                go_annotations.append({"go_id": clean_go_id})
                if go_info_available:
                    definition = get_go_definition(clean_go_id, self.go_info_path)
                    if definition:
                        all_related_definitions[clean_go_id] = definition

            # Pfam motif descriptions for the IDs reported by InterProScan
            motif_pfam = {}
            if os.path.exists(self.pfam_descriptions_path):
                try:
                    interproscan_results = interproscan_info[protein_id].get('interproscan_results', {})
                    pfam_entries = interproscan_results.get('pfam_id', [])

                    with open(self.pfam_descriptions_path, 'r', encoding='utf-8') as f:
                        pfam_descriptions = json.load(f)

                    for entry in pfam_entries:
                        for pfam_id in entry:
                            if pfam_id and pfam_id in pfam_descriptions:
                                motif_pfam[pfam_id] = pfam_descriptions[pfam_id]['description']
                except Exception as e:
                    # Motif data is optional; continue with whatever was collected
                    print(f"获取motif信息时出错: {e}")

            # InterPro descriptions for any info type beyond motif/go
            interpro_descriptions = {}
            other_types = [t for t in self.selected_info_types if t not in ['motif', 'go']]
            if other_types and self.interpro_manager:
                interpro_descriptions = self.interpro_manager.get_description(protein_id, other_types)

            template_data = {
                "protein_id": protein_id,
                "selected_info_types": self.selected_info_types,
                "go_data": {
                    "status": "success" if go_annotations else "no_data",
                    "go_annotations": go_annotations,
                    "all_related_definitions": all_related_definitions
                },
                "motif_pfam": motif_pfam,
                "interpro_descriptions": interpro_descriptions,
                "question": question
            }

            # Uses this module's get_prompt_template (the demo does not use lmdb)
            prompt_template = get_prompt_template(self.selected_info_types)
            return Template(prompt_template).render(**template_data)
        except Exception as e:
            print(f"生成prompt时出错 (protein_id: {protein_id}): {e}")
            # Degrade to the simplified prompt
            return self._generate_fallback_prompt(protein_id, interproscan_info, protein_go_dict, question)

    def _generate_fallback_prompt(self, protein_id: str, interproscan_info: Dict,
                                  protein_go_dict: Dict, question: str) -> str:
        """
        Build a simplified prompt, used when ``generate_prompt`` fails.

        Lists Pfam entry values and up to ten GO IDs with placeholder
        descriptions. Missing data for ``protein_id`` yields empty sections
        instead of an exception, since this is the error-recovery path.
        """
        from utils.prompts import FUNCTION_PROMPT

        prompt_parts = [FUNCTION_PROMPT, "\ninput information:"]

        # Motif section
        if 'motif' in self.selected_info_types:
            protein_entry = interproscan_info.get(protein_id) or {}
            pfam_entries = protein_entry.get('interproscan_results', {}).get('pfam_id', [])
            if pfam_entries:
                prompt_parts.append("\nmotif:")
                for entry in pfam_entries:
                    # NOTE(review): generate_prompt keys motifs by the entry key,
                    # this prints the entry value - confirm which one is intended.
                    for value in entry.values():
                        if value:
                            prompt_parts.append(f"{value}: motif information")

        # GO section
        if 'go' in self.selected_info_types:
            go_ids = protein_go_dict.get(protein_id) or []
            if go_ids:
                prompt_parts.append("\nGO:")
                for i, go_id in enumerate(go_ids[:10], 1):
                    prompt_parts.append(f"▢ GO term{i}: {go_id}")
                    prompt_parts.append("• definition: GO term definition")

        # User question
        prompt_parts.append(f"\nquestion: \n{question}")
        return "\n".join(prompt_parts)

    def analyze_protein(self, sequence_input: str, fasta_file, question: str) -> str:
        """
        Analyse one protein sequence and answer ``question`` about it.

        Args:
            sequence_input: raw sequence typed by the user (may be empty/None).
            fasta_file: uploaded FASTA (bytes, path, or file-like object) or
                None; takes precedence over ``sequence_input`` when given.
            question: the user's question.

        Returns:
            A report string, or a user-facing error message. No exception is
            raised to the caller for analysis failures.
        """
        if not question or not question.strip():
            return "请输入您的问题"

        # Decide which input supplies the sequence; an uploaded file wins
        final_sequence = None
        sequence_source = ""
        if fasta_file is not None:
            try:
                fasta_content = self._read_uploaded_text(fasta_file)
                final_sequence, parse_msg = self.parse_fasta_content(fasta_content)
                if final_sequence is None:
                    return f"文件解析错误: {parse_msg}"
                sequence_source = f"来自上传文件: {parse_msg}"
            except Exception as e:
                return f"读取上传文件出错: {e}"
        elif sequence_input and sequence_input.strip():
            if not self.validate_protein_sequence(sequence_input):
                return "输入的序列格式不正确,请输入有效的蛋白质序列"
            # Same normalisation the validator applied
            final_sequence = self._normalize_sequence(sequence_input)
            sequence_source = "来自文本框输入"
        else:
            return "请输入蛋白质序列或上传FASTA文件"

        # Every intermediate file lives in temp_dir and is removed with it
        with tempfile.TemporaryDirectory() as temp_dir:
            try:
                temp_fasta = self.create_temp_fasta(final_sequence, "demo_protein", directory=temp_dir)

                status_msg = f"序列来源: {sequence_source}\n序列长度: {len(final_sequence)} 氨基酸\n\n正在进行分析...\n"

                # Steps 1-2: BLAST and InterProScan
                status_msg += "步骤1: 运行BLAST分析...\n"
                blast_info = self.run_blast_analysis(temp_fasta, temp_dir)
                status_msg += "步骤2: 运行InterProScan分析...\n"
                interproscan_info = self.run_interproscan_analysis(temp_fasta, temp_dir)
                if not blast_info or not interproscan_info:
                    return status_msg + "分析失败: 无法获取BLAST或InterProScan结果"

                # Step 3: merge GO information
                status_msg += "步骤3: 整合GO信息...\n"
                protein_go_dict = self.go_pipeline.first_level_filtering(interproscan_info, blast_info)

                # Step 4: build the prompt
                status_msg += "步骤4: 生成分析prompt...\n"
                protein_id = "demo_protein"
                prompt = self.generate_prompt(protein_id, interproscan_info, protein_go_dict, question)

                # Step 5: ask the LLM
                status_msg += "步骤5: 生成答案...\n"
                llm_response = call_chatgpt(prompt)

                blast_hits = len(blast_info.get(protein_id, {}).get('blast_results', []))
                pfam_domains = len(
                    interproscan_info.get(protein_id, {}).get('interproscan_results', {}).get('pfam_id', [])
                )
                go_terms = len(protein_go_dict.get(protein_id, []))

                result = f"""
{status_msg}
=== 分析完成 ===
问题: {question}
答案: {llm_response}
=== 分析详情 ===
- BLAST匹配数: {blast_hits}
- InterProScan域数: {pfam_domains}
- GO术语数: {go_terms}
"""
                return result
            except Exception as e:
                return f"分析过程中出错: {e}"
def create_demo():
    """
    Build the Gradio Blocks interface for the protein analysis demo.

    The left column collects a sequence (text box or FASTA upload) and a
    question; the right column shows the report returned by
    ``ProteinAnalysisDemo.analyze_protein``.

    Returns:
        The ``gr.Blocks`` application (not yet launched).
    """
    pipeline = ProteinAnalysisDemo()

    # Rows of (sequence, question) offered as clickable examples.
    # NOTE(review): the second sequence ends in 'J', which is outside the
    # alphabet accepted by validate_protein_sequence, and the third consists
    # only of A/C/G/T and looks like a nucleotide sequence - confirm intent.
    example_rows = [
        ["MKALIVLGLVLLSVTVQGKVFERCELARTLKRLGMDGYRGISLANWMCLAKWESGYNTRATNYNAGDRSTDYGIFQINSRYWCNDGKTPGAVNACHLSCSALLQDNIADAVACAKRVVRDPQGIRAWVAWRNRCQNRDVRQYVQGCGV", "这个蛋白质的主要功能是什么?"],
        ["MGSSHHHHHHSSGLVPRGSHMRGPNPTAASLEASAGPFTVRSFTVSRPSGYGAGTVYYPTNAGGTVGAIAIVPGYTARQSSIKWWGPRLASHGFVVITIDTNSTLDQPSSRSSQQMAALRQVASLNGTSSSPIYGKVDTARMGVMGWSMGGGGSLISAANNPSLKAAAPQAPWDSSTNFSSVTVPTLIFACENDSIAPVNSSALPIYDSMSRNAKQFLEINGGSHSCANSGNSNQALIGKKGVAWMKRFPTSREJ", "这个蛋白质可能参与哪些生物学过程?"],
        ["ATGAGTGAACGTCTGAAATCTATCATCACCGTCGACGACGAGAACGTCAAGCTGATCGACAAGATCCTGGCCTCCATCAAGGACCTGAACGAGCTGGTGGACATGATCGACGAGATCAAGAACGTCGACGACGAGCTGATCGACAAGATCCTGGCC", "这个序列编码的蛋白质具有什么结构特征?"],
    ]

    with gr.Blocks(title="蛋白质功能分析演示") as app:
        gr.Markdown("# 🧬 蛋白质功能分析演示")
        gr.Markdown("输入蛋白质序列和问题,AI将基于BLAST、InterProScan和GO信息为您提供专业分析")

        with gr.Row():
            # Input column
            with gr.Column(scale=1):
                gr.Markdown("### 📝 序列输入")
                seq_box = gr.Textbox(
                    label="蛋白质序列",
                    placeholder="请输入蛋白质序列(单字母氨基酸代码)...",
                    lines=5,
                    max_lines=10,
                )
                gr.Markdown("**或者**")
                upload_box = gr.File(
                    label="上传FASTA文件",
                    file_types=[".fasta", ".fa", ".fas"],
                    file_count="single",
                )
                gr.Markdown("### ❓ 您的问题")
                question_box = gr.Textbox(
                    label="问题",
                    placeholder="请输入关于该蛋白质的问题,例如:这个蛋白质的主要功能是什么?",
                    lines=3,
                )
                run_button = gr.Button("🔍 开始分析", variant="primary", size="lg")

            # Result column
            with gr.Column(scale=2):
                gr.Markdown("### 📊 分析结果")
                report_box = gr.Textbox(
                    label="分析结果",
                    lines=20,
                    max_lines=30,
                    show_copy_button=True,
                )

        # Examples fill the sequence and question boxes only
        gr.Markdown("### 💡 示例")
        gr.Examples(examples=example_rows, inputs=[seq_box, question_box])

        run_button.click(
            fn=pipeline.analyze_protein,
            inputs=[seq_box, upload_box, question_box],
            outputs=[report_box],
        )

    return app
if __name__ == "__main__":
    # Script entry point: build the UI, then serve it on every interface at
    # port 30002 with a public share link and debug mode off.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 30002,
        "share": True,
        "debug": False,
    }
    demo = create_demo()
    demo.launch(**launch_options)