ncfrey commited on
Commit
78669aa
·
verified ·
1 Parent(s): 11ea176

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +127 -0
app.py ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app.py
2
+
3
+ import gradio as gr
4
+ import torch
5
+ import plotly.express as px
6
+ import numpy as np
7
+ import pandas as pd
8
+ from sklearn.metrics.pairwise import cosine_similarity
9
+ from sklearn.decomposition import PCA
10
+ from transformers import AutoTokenizer, AutoModel
11
+
12
+ # Load model once
13
+ model_name = "karina-zadorozhny/ume"
14
+ model = AutoModel.from_pretrained(model_name, trust_remote_code=True)
15
+ model.eval()
16
+
17
+ # Load all 3 tokenizers
18
+ tokenizer_aa = AutoTokenizer.from_pretrained(model_name, subfolder="tokenizer_amino_acid", trust_remote_code=True)
19
+ tokenizer_nt = AutoTokenizer.from_pretrained(model_name, subfolder="tokenizer_nucleotide", trust_remote_code=True)
20
+ tokenizer_sm = AutoTokenizer.from_pretrained(model_name, subfolder="tokenizer_smiles", trust_remote_code=True)
21
+
22
+
23
+ def detect_modality(seq):
24
+ seq = seq.strip().upper()
25
+ if all(c in "ATGCUN" for c in seq): # DNA/RNA
26
+ return "nucleotide"
27
+ elif all(c in "ACDEFGHIKLMNPQRSTVWYBXZJUO" for c in seq): # Protein
28
+ return "amino_acid"
29
+ else:
30
+ return "smiles"
31
+
32
+
33
+ def compute_embeddings(sequences):
34
+ embeddings = []
35
+
36
+ for seq in sequences:
37
+ modality = detect_modality(seq)
38
+ if modality == "amino_acid":
39
+ tokenizer = tokenizer_aa
40
+ elif modality == "nucleotide":
41
+ tokenizer = tokenizer_nt
42
+ else:
43
+ tokenizer = tokenizer_sm
44
+
45
+ inputs = tokenizer([seq], return_tensors="pt", padding=True, truncation=True)
46
+ with torch.no_grad():
47
+ emb = model(inputs["input_ids"].unsqueeze(1), inputs["attention_mask"].unsqueeze(1))
48
+ embeddings.append(emb.squeeze(0).squeeze(0).numpy())
49
+
50
+ return np.vstack(embeddings)
51
+
52
+ def visualize_embeddings(sequences):
53
+ embeddings = compute_embeddings(sequences)
54
+
55
+ # PCA for 2D and 3D
56
+ pca_2d = PCA(n_components=2).fit_transform(embeddings)
57
+ pca_3d = PCA(n_components=3).fit_transform(embeddings)
58
+
59
+ df_2d = pd.DataFrame(pca_2d, columns=["PC1", "PC2"])
60
+ df_2d["Sequence"] = sequences
61
+
62
+ df_3d = pd.DataFrame(pca_3d, columns=["X", "Y", "Z"])
63
+ df_3d["Sequence"] = sequences
64
+
65
+ fig_2d = px.scatter(df_2d, x="PC1", y="PC2", text="Sequence",
66
+ title="2D PCA of UME Embeddings", color="Sequence",
67
+ color_discrete_sequence=px.colors.qualitative.Bold)
68
+
69
+ fig_3d = px.scatter_3d(df_3d, x="X", y="Y", z="Z", text="Sequence",
70
+ title="3D PCA of UME Embeddings", color="Sequence",
71
+ color_discrete_sequence=px.colors.qualitative.Vivid)
72
+
73
+ return fig_2d, fig_3d
74
+
75
+
76
+ def similarity_matrix(sequences):
77
+ embeddings = compute_embeddings(sequences)
78
+ sim_matrix = cosine_similarity(embeddings)
79
+ sim_df = pd.DataFrame(sim_matrix, index=sequences, columns=sequences)
80
+ fig = px.imshow(sim_df, text_auto=True, color_continuous_scale='Viridis',
81
+ title="Cosine Similarity Matrix")
82
+ return fig
83
+
84
+
85
+ description = """
86
+ # 🧬 UME Explorer: Biosequence Embedding Playground
87
+ Welcome to **UME Explorer**, an interactive space to explore representations of molecules using the UME model.
88
+
89
+ Paste in your DNA, amino acid, or SMILES sequences and:
90
+ - ✨ Visualize embeddings in 2D and 3D
91
+ - 🔬 Explore pairwise similarities
92
+ - 🎨 Enjoy colorful, educational plots!
93
+
94
+ > **Tip**: Keep input sequences short and between 3–20 items for better visuals on CPU.
95
+ """
96
+
97
+ with gr.Blocks(theme=gr.themes.Monochrome(), css="footer {display: none}") as demo:
98
+ gr.Markdown(description)
99
+
100
+ gr.Markdown("""
101
+ ℹ️ <b>How sequence type is detected:</b><br>
102
+ - 🧬 <b>Nucleotide (DNA/RNA):</b> Only uses A, T, G, C, U, N<br>
103
+ - 🔹 <b>Protein (Amino Acid):</b> Includes letters like M, K, V, L, etc.<br>
104
+ - 🧪 <b>SMILES (Chemical):</b> Includes characters like =, (, ), C, O, etc.<br>
105
+ <small>👉 Detection is automatic. You can mix sequence types in one run!</small>
106
+ """)
107
+
108
+ with gr.Row():
109
+ seq_input = gr.Textbox(lines=8, placeholder="Enter sequences, one per line...", label="Input Sequences")
110
+ submit_btn = gr.Button("Compute Embeddings & Visualize")
111
+
112
+ with gr.Row():
113
+ out2d = gr.Plot(label="2D Plot")
114
+ out3d = gr.Plot(label="3D Plot")
115
+
116
+ sim_out = gr.Plot(label="Similarity Heatmap")
117
+
118
+ def process_input(text):
119
+ seqs = [s.strip() for s in text.splitlines() if s.strip()]
120
+ fig2d, fig3d = visualize_embeddings(seqs)
121
+ sim_fig = similarity_matrix(seqs)
122
+ return fig2d, fig3d, sim_fig
123
+
124
+ submit_btn.click(fn=process_input, inputs=seq_input, outputs=[out2d, out3d, sim_out])
125
+
126
+ demo.launch()
127
+