# Author: gradsyntax
# Commit 879cd81 (verified): "Integrated DNABERT for real mutation detection"
import gradio as gr
import pandas as pd
import torch
from transformers import AutoTokenizer, BertForSequenceClassification
# Mutation classes, keyed by the model's output index (example mapping —
# update it to match the label order used during fine-tuning).
mutation_map = {
    0: "No Mutation",
    1: "SNV",
    2: "Insertion",
    3: "Deletion",
}

# Load the DNABERT-2 tokenizer and a sequence-classification model.
tokenizer = AutoTokenizer.from_pretrained("zhihan1996/DNABERT-2-117M", trust_remote_code=True)
# Size the classification head from the label map. Without `num_labels` the
# head defaults to two outputs, so classes 2 and 3 could never be predicted.
# NOTE(review): this checkpoint looks like a base (not fine-tuned) model, in
# which case the head weights are freshly initialised — confirm before
# trusting predictions.
model = BertForSequenceClassification.from_pretrained(
    "zhihan1996/DNABERT-2-117M",
    num_labels=len(mutation_map),
)
# Runs mutation classification on each sequence using the loaded model
def analyze_sequences(input_df):
    """Classify every DNA sequence in ``input_df``.

    Each usable value of the ``DNA_Sequence`` column is tokenized with the
    module-level ``tokenizer`` and passed through the module-level ``model``;
    the arg-max class is translated through ``mutation_map``.

    Args:
        input_df: DataFrame with a ``DNA_Sequence`` column, or ``None``.

    Returns:
        DataFrame with the columns ``Sequence``, ``Predicted Mutation`` and
        ``Confidence Score`` (softmax probability of the predicted class).
        It is empty, but still carries those columns, when there is nothing
        to analyze.

    Raises:
        KeyError: if a non-empty ``input_df`` has no ``DNA_Sequence`` column.
    """
    columns = ["Sequence", "Predicted Mutation", "Confidence Score"]
    if input_df is None or input_df.empty:
        return pd.DataFrame(columns=columns)
    if "DNA_Sequence" not in input_df.columns:
        # Fail once, up front, with a message that names the expected schema.
        raise KeyError("Input data must contain a 'DNA_Sequence' column")

    results = []
    for raw_value in input_df["DNA_Sequence"]:
        # Blank CSV cells arrive as NaN; they are not sequences, so skip them
        # instead of handing a non-string to the tokenizer.
        if pd.isna(raw_value):
            continue
        seq = str(raw_value).strip()
        if not seq:
            continue

        # Tokenize and run inference without tracking gradients.
        inputs = tokenizer(seq, return_tensors="pt", padding=True, truncation=True)
        with torch.no_grad():
            logits = model(**inputs).logits

        predicted_class = torch.argmax(logits, dim=1).item()
        probabilities = torch.softmax(logits, dim=1)[0]
        confidence = float(probabilities[predicted_class].item())

        results.append({
            "Sequence": seq,
            # Map prediction to mutation type; unmapped indices are "Unknown".
            "Predicted Mutation": mutation_map.get(predicted_class, "Unknown"),
            "Confidence Score": confidence,
        })

    # Passing `columns` keeps the schema stable even if every row was skipped.
    return pd.DataFrame(results, columns=columns)
# Builds the bundled example sequences and runs them through the analyzer
def load_example_data():
    """Return the analysis results for four hard-coded example sequences.

    The sequences are wrapped in a one-column ``DNA_Sequence`` DataFrame and
    passed to ``analyze_sequences``, whose result is returned unchanged.
    """
    example_sequences = ["AGCTAGCTA", "GATCGATCG", "TTAGCTAGCT", "ATGCGTAGC"]
    example_df = pd.DataFrame({"DNA_Sequence": example_sequences})
    return analyze_sequences(example_df)
# Converts DataFrame to CSV string
def dataframe_to_csv(df):
    """Serialize ``df`` to CSV text without the index column.

    Args:
        df: DataFrame to serialize, or ``None``.

    Returns:
        The CSV document as a string, or ``""`` when ``df`` is ``None`` or
        empty.
    """
    if df is None or df.empty:
        return ""
    # With no target given, `to_csv` returns the CSV text directly, so no
    # intermediate StringIO buffer (which this file never imported) is needed.
    return df.to_csv(index=False)
# Generate mutation statistics summary and chart
def get_mutation_stats(result_df):
    """Summarize how often each predicted mutation type occurs.

    Args:
        result_df: DataFrame with a ``Predicted Mutation`` column, or
            ``None``.

    Returns:
        A ``(summary_text, chart)`` tuple. ``summary_text`` lists one
        ``- <mutation>: <count>`` line per mutation type and ``chart`` is a
        ``gr.BarPlot`` of the same counts. For ``None`` or empty input the
        tuple is ``("No data available.", None)``.
    """
    if result_df is None or result_df.empty:
        return "No data available.", None

    # Count mutations (most frequent first).
    mutation_counts = result_df["Predicted Mutation"].value_counts()

    summary_text = "📊 Mutation Statistics:\n" + "".join(
        f"- {mutation}: {count}\n" for mutation, count in mutation_counts.items()
    )

    # Name both columns explicitly: a bare `reset_index()` only produces a
    # "count" column on pandas 2.x, while the chart below always needs
    # "Predicted Mutation" and "count".
    counts_table = mutation_counts.rename_axis("Predicted Mutation").reset_index(name="count")

    # Create bar chart
    chart = gr.BarPlot(
        counts_table,
        x="Predicted Mutation",
        y="count",
        title="Mutation Frequency",
        color="Predicted Mutation",
        tooltip=["Predicted Mutation", "count"],
        vertical=False,
        height=200,
    )
    return summary_text, chart
# Unified function to process and return all outputs
def process_and_get_stats(file=None):
    """Analyze an uploaded CSV, or the example data, and build all outputs.

    Args:
        file: The uploaded CSV as delivered by the upload component, or
            ``None`` to analyze the bundled example sequences instead.

    Returns:
        A ``(result_df, summary, chart)`` tuple: the analysis table followed
        by the two values returned by ``get_mutation_stats``.
    """
    if file is not None:
        # `analyze_sequences` needs a DataFrame, so the CSV must be parsed
        # first. The upload is presumably either a path string or a temp-file
        # wrapper exposing `.name`, depending on the Gradio version; both are
        # handled here.
        csv_path = getattr(file, "name", file)
        result_df = analyze_sequences(pd.read_csv(csv_path))
    else:
        result_df = load_example_data()

    summary, chart = get_mutation_stats(result_df)
    return result_df, summary, chart
# Gradio Interface
def _results_to_csv_file(result_df):
    """Write the results table to a temporary CSV file and return its path.

    A file component needs a path on disk rather than CSV text. Returns
    ``None`` (nothing to download) when the table is ``None`` or empty.
    """
    import tempfile

    if result_df is None or result_df.empty:
        return None
    with tempfile.NamedTemporaryFile(
        mode="w",
        suffix=".csv",
        prefix="mutatex_results_",
        delete=False,  # the file must outlive this function to be served
        newline="",
        encoding="utf-8",
    ) as handle:
        result_df.to_csv(handle, index=False)
        return handle.name


with gr.Blocks(theme="default") as demo:
    gr.Markdown("""
# 🧬 MutateX – Liquid Biopsy Mutation Detection Tool
Upload a CSV file with DNA sequences to simulate mutation detection.
*Developed by [GradSyntax](https://www.gradsyntax.com )*
""")

    # Inputs: upload a CSV, or run the bundled example sequences.
    with gr.Row(equal_height=True):
        upload_btn = gr.File(label="📁 Upload CSV File", file_types=[".csv"])
        example_btn = gr.Button("🧪 Load Example Data")

    # Outputs
    output_table = gr.DataFrame(
        label="Analysis Results",
        headers=["Sequence", "Predicted Mutation", "Confidence Score"],
    )
    stats_text = gr.Textbox(label="Mutation Statistics Summary")
    # `get_mutation_stats` returns a `gr.BarPlot`, so the receiving component
    # has to be a BarPlot as well (a generic `gr.Plot` expects a figure).
    stats_chart = gr.BarPlot(
        label="Mutation Frequency Chart",
        x="Predicted Mutation",
        y="count",
    )
    download_btn = gr.File(label="⬇️ Download Results as CSV")

    # Function calls: run the analysis, then refresh the downloadable CSV
    # from the table that was just produced.
    upload_btn.upload(
        fn=process_and_get_stats,
        inputs=upload_btn,
        outputs=[output_table, stats_text, stats_chart],
    ).then(fn=_results_to_csv_file, inputs=output_table, outputs=download_btn)
    example_btn.click(
        fn=process_and_get_stats,
        inputs=None,
        outputs=[output_table, stats_text, stats_chart],
    ).then(fn=_results_to_csv_file, inputs=output_table, outputs=download_btn)

demo.launch()