"""MutateX – a Gradio demo that runs DNABERT-2 over DNA sequences from a CSV."""

import os
import tempfile
from io import StringIO

import gradio as gr
import pandas as pd
import torch
from transformers import (  # noqa: F401 - BertForSequenceClassification kept for compatibility
    AutoModelForSequenceClassification,
    AutoTokenizer,
    BertForSequenceClassification,
)

MODEL_NAME = "zhihan1996/DNABERT-2-117M"

# Name of the column the uploaded CSV must provide.
SEQUENCE_COLUMN = "DNA_Sequence"

# Columns of the results table (also used as the UI table headers).
RESULT_COLUMNS = ["Sequence", "Predicted Mutation", "Confidence Score"]

# Mutation classes (example mapping — update based on your fine-tuning).
mutation_map = {
    0: "No Mutation",
    1: "SNV",
    2: "Insertion",
    3: "Deletion",
}

# Load DNABERT-2 tokenizer and model. The model repo ships custom modelling
# code, so it needs trust_remote_code=True just like the tokenizer.
# num_labels sizes the classification head to mutation_map so every class
# listed there can actually be predicted.
# NOTE(review): unless MODEL_NAME points at a checkpoint fine-tuned for these
# classes, the classification head is freshly initialised and predictions are
# not meaningful (the UI describes this as a simulation) — confirm.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
model = AutoModelForSequenceClassification.from_pretrained(
    MODEL_NAME, trust_remote_code=True, num_labels=len(mutation_map)
)
model.eval()


def analyze_sequences(input_df):
    """Classify every sequence in ``input_df["DNA_Sequence"]`` with the model.

    Args:
        input_df: DataFrame with a ``DNA_Sequence`` column, or None.

    Returns:
        DataFrame with columns Sequence / Predicted Mutation / Confidence Score,
        one row per non-blank input sequence. Empty (with those columns) when
        there is no input.

    Raises:
        gr.Error: if the ``DNA_Sequence`` column is missing (shown in the UI).
    """
    if input_df is None or input_df.empty:
        return pd.DataFrame(columns=RESULT_COLUMNS)
    if SEQUENCE_COLUMN not in input_df.columns:
        raise gr.Error(f"Input CSV must contain a '{SEQUENCE_COLUMN}' column.")

    results = []
    with torch.no_grad():
        # Blank cells come back from read_csv as NaN floats; drop them
        # rather than handing a non-string to the tokenizer.
        for raw_seq in input_df[SEQUENCE_COLUMN].dropna().astype(str):
            seq = raw_seq.strip()
            if not seq:
                continue
            inputs = tokenizer(seq, return_tensors="pt", padding=True, truncation=True)
            logits = model(**inputs).logits
            probabilities = torch.softmax(logits, dim=1)[0]
            predicted_class = int(torch.argmax(probabilities).item())
            results.append({
                "Sequence": seq,
                "Predicted Mutation": mutation_map.get(predicted_class, "Unknown"),
                "Confidence Score": float(probabilities[predicted_class].item()),
            })
    return pd.DataFrame(results, columns=RESULT_COLUMNS)


def load_example_data():
    """Analyse a small built-in set of example sequences."""
    df = pd.DataFrame({
        SEQUENCE_COLUMN: [
            "AGCTAGCTA",
            "GATCGATCG",
            "TTAGCTAGCT",
            "ATGCGTAGC",
        ]
    })
    return analyze_sequences(df)


def dataframe_to_csv(df):
    """Return *df* serialised as CSV text (no index), or "" if it is empty/None."""
    if df is None or df.empty:
        return ""
    csv_buffer = StringIO()
    df.to_csv(csv_buffer, index=False)
    return csv_buffer.getvalue()


def export_results_csv(df):
    """Write *df* to a temporary CSV file and return its path for gr.File.

    Returns None (clearing the download box) when there is nothing to export.
    """
    csv_text = dataframe_to_csv(df)
    if not csv_text:
        return None
    # delete=False: the file must outlive this call so Gradio can serve it.
    with tempfile.NamedTemporaryFile(
        mode="w",
        suffix=".csv",
        prefix="mutatex_results_",
        delete=False,
        encoding="utf-8",
        newline="",
    ) as handle:
        handle.write(csv_text)
    return handle.name


def get_mutation_stats(result_df):
    """Summarise predicted mutation counts as text plus bar-chart data.

    Returns:
        ``(summary_text, chart_df)``. ``chart_df`` has the columns
        "Predicted Mutation" and "count" expected by the gr.BarPlot output.
        Returns ``("No data available.", None)`` when there are no results.
    """
    if result_df is None or result_df.empty:
        return "No data available.", None

    mutation_counts = result_df["Predicted Mutation"].value_counts()
    summary_lines = ["📊 Mutation Statistics:"]
    summary_lines += [f"- {mutation}: {count}" for mutation, count in mutation_counts.items()]
    summary_text = "\n".join(summary_lines) + "\n"

    # Name both columns explicitly: a bare reset_index() labels them
    # ("index", "Predicted Mutation") on pandas < 2.0 but
    # ("Predicted Mutation", "count") on pandas >= 2.0.
    chart_df = mutation_counts.rename_axis("Predicted Mutation").reset_index(name="count")
    return summary_text, chart_df


def process_and_get_stats(file=None):
    """Analyse an uploaded CSV (or the built-in examples) and build the stats.

    Args:
        file: Value of the gr.File upload, or None to use the example
            sequences. Depending on the Gradio version this is a file path or
            a tempfile wrapper exposing ``.name``.

    Returns:
        ``(result_df, summary_text, chart_df)`` for the table, textbox and chart.

    Raises:
        gr.Error: if the upload cannot be parsed as CSV or lacks the
            ``DNA_Sequence`` column.
    """
    if file is None:
        result_df = load_example_data()
    else:
        path = file if isinstance(file, (str, os.PathLike)) else file.name
        try:
            input_df = pd.read_csv(path)
        except (pd.errors.EmptyDataError, pd.errors.ParserError, UnicodeDecodeError) as err:
            raise gr.Error(f"Could not read the uploaded CSV: {err}") from err
        result_df = analyze_sequences(input_df)
    summary, chart = get_mutation_stats(result_df)
    return result_df, summary, chart


HEADER_MARKDOWN = """
# 🧬 MutateX – Liquid Biopsy Mutation Detection Tool
Upload a CSV file with DNA sequences to simulate mutation detection.
*Developed by [GradSyntax](https://www.gradsyntax.com )*
"""

# Gradio Interface
with gr.Blocks(theme="default") as demo:
    gr.Markdown(HEADER_MARKDOWN)

    with gr.Row(equal_height=True):
        upload_btn = gr.File(label="📁 Upload CSV File", file_types=[".csv"])
        example_btn = gr.Button("🧪 Load Example Data")

    output_table = gr.DataFrame(label="Analysis Results", headers=RESULT_COLUMNS)
    stats_text = gr.Textbox(label="Mutation Statistics Summary")
    # BarPlot (not Plot) so the DataFrame returned by get_mutation_stats renders.
    stats_chart = gr.BarPlot(
        x="Predicted Mutation",
        y="count",
        title="Mutation Frequency",
        color="Predicted Mutation",
        tooltip=["Predicted Mutation", "count"],
        vertical=False,
        height=200,
        label="Mutation Frequency Chart",
    )
    # Output-only: filled with a CSV file after each analysis.
    download_btn = gr.File(label="⬇️ Download Results as CSV", interactive=False)

    analysis_outputs = [output_table, stats_text, stats_chart]

    # Function calls: run the analysis, then export the refreshed table as a
    # downloadable CSV file.
    upload_btn.upload(
        fn=process_and_get_stats, inputs=upload_btn, outputs=analysis_outputs
    ).then(fn=export_results_csv, inputs=output_table, outputs=download_btn)
    example_btn.click(
        fn=process_and_get_stats, inputs=None, outputs=analysis_outputs
    ).then(fn=export_results_csv, inputs=output_table, outputs=download_btn)

if __name__ == "__main__":
    demo.launch()