lucapinello committed on
Commit
b074e28
·
1 Parent(s): f9cdce6

Add application file

Browse files
Files changed (4) hide show
  1. app.py +251 -0
  2. dna-slot-machine.html +726 -0
  3. dna_diffusion_model.py +283 -0
  4. requirements.txt +14 -0
app.py ADDED
@@ -0,0 +1,251 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ DNA-Diffusion Gradio Application
3
+ Interactive DNA sequence generation with slot machine visualization
4
+ """
5
+
6
+ import gradio as gr
7
+ import logging
8
+ import json
9
+ import os
10
+ from typing import Dict, Any, Tuple
11
+ import html
12
+
13
+ # Configure logging
14
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
15
+ logger = logging.getLogger(__name__)
16
+
17
+ # Try to import model, but allow app to run without it for UI development
18
+ try:
19
+ from dna_diffusion_model import DNADiffusionModel, get_model
20
+ MODEL_AVAILABLE = True
21
+ logger.info("DNA-Diffusion model module loaded successfully")
22
+ except ImportError as e:
23
+ logger.warning(f"DNA-Diffusion model not available: {e}")
24
+ MODEL_AVAILABLE = False
25
+
26
# Load the HTML interface.
# The file is looked up next to app.py first, so the app works regardless of
# the current working directory, and relative to the working directory second
# (the original behaviour, kept for backward compatibility).
HTML_FILE = "dna-slot-machine.html"
_HTML_CANDIDATES = (
    os.path.join(os.path.dirname(os.path.abspath(__file__)), HTML_FILE),
    HTML_FILE,
)
_HTML_PATH = next((path for path in _HTML_CANDIDATES if os.path.exists(path)), None)
if _HTML_PATH is None:
    raise FileNotFoundError(f"HTML interface file '{HTML_FILE}' not found. Please ensure it exists in the same directory as app.py")

# The page declares charset UTF-8 and contains non-ASCII characters, so decode
# explicitly instead of relying on the platform default encoding.
with open(_HTML_PATH, "r", encoding="utf-8") as f:
    SLOT_MACHINE_HTML = f.read()
33
+
34
class DNADiffusionApp:
    """Main application class for the DNA-Diffusion Gradio interface.

    Holds the lazily loaded model plus its loading/error state and exposes
    the callbacks that the Gradio UI is wired to. When the model module is
    unavailable, or the model has not been loaded, sequence generation falls
    back to random mock data.
    """

    # Number of bases in a mock sequence.
    _MOCK_SEQUENCE_LENGTH = 200
    # Artificial delay (seconds) used by, and reported for, mock generation.
    _MOCK_GENERATION_SECONDS = 2.0

    def __init__(self):
        self.model = None            # loaded model object, or None
        self.model_loading = False   # True while initialize_model is running
        self.model_error = None      # last initialization error message, or None

    def initialize_model(self):
        """Initialize the DNA-Diffusion model.

        Does nothing when the model module is unavailable (recording an
        error message instead), when the model is already loaded, or when a
        load is currently in progress. On failure the error text is stored
        in ``model_error`` and ``model`` stays ``None``; no exception is
        propagated.
        """
        if not MODEL_AVAILABLE:
            self.model_error = "DNA-Diffusion model module not available. Please install dependencies."
            return

        # This method is registered as the page-load callback, so it runs on
        # every page load; skip the work once the model is available.
        # NOTE(review): the flag check is not atomic across threads.
        if self.model is not None or self.model_loading:
            return

        self.model_loading = True
        try:
            logger.info("Starting model initialization...")
            self.model = get_model()
            logger.info("Model initialized successfully!")
            self.model_error = None
        except Exception as e:
            # logger.exception keeps the traceback, which str(e) alone loses.
            logger.exception(f"Failed to initialize model: {e}")
            self.model_error = str(e)
            self.model = None
        finally:
            self.model_loading = False

    def generate_sequence(self, cell_type: str, guidance_scale: float = 1.0) -> Tuple[str, Dict[str, Any]]:
        """Generate a DNA sequence using the model or mock data.

        Args:
            cell_type: Cell type label forwarded to the model.
            guidance_scale: Guidance scale forwarded to the model.

        Returns:
            A ``(sequence, metadata)`` tuple. In mock mode the metadata
            carries ``'mock': True``.

        Raises:
            Exception: Whatever the model's ``generate`` call raises is
                logged and re-raised unchanged.
        """
        # Use mock generation if model is not available
        if not MODEL_AVAILABLE or self.model is None:
            logger.warning("Using mock sequence generation")
            import random
            import time

            sequence = ''.join(random.choices('ATCG', k=self._MOCK_SEQUENCE_LENGTH))
            metadata = {
                'cell_type': cell_type,
                'guidance_scale': guidance_scale,
                'generation_time': self._MOCK_GENERATION_SECONDS,
                'mock': True,
            }
            # Simulate generation time
            time.sleep(self._MOCK_GENERATION_SECONDS)
            return sequence, metadata

        # Use real model
        try:
            result = self.model.generate(cell_type, guidance_scale)
            return result['sequence'], result['metadata']
        except Exception as e:
            logger.exception(f"Generation failed: {e}")
            raise

    def handle_generation_request(self, cell_type: str, guidance_scale: float):
        """Handle a sequence generation request from Gradio.

        Returns:
            ``(sequence, metadata_json)``. On any failure the sequence is an
            empty string and the JSON is ``{"error": <message>}``, so the
            front end always receives a parseable payload.
        """
        try:
            logger.info(f"Generating sequence for cell type: {cell_type}")
            sequence, metadata = self.generate_sequence(cell_type, guidance_scale)
            # Serialization stays inside the try so non-JSON-able metadata is
            # reported as an error payload rather than raised.
            return sequence, json.dumps(metadata)
        except Exception as e:
            error_msg = str(e)
            logger.error(f"Generation request failed: {error_msg}")
            return "", json.dumps({"error": error_msg})
102
+
103
+ # Create single app instance
104
+ app = DNADiffusionApp()
105
+
106
def create_demo():
    """Create the Gradio demo interface.

    Builds a ``gr.Blocks`` app whose main UI is the slot-machine page
    embedded in an ``<iframe srcdoc>``. The iframe and the Gradio page talk
    through ``window.postMessage``: the iframe sends ``generate_request``,
    the page selects the cell type and clicks the backend "Generate" button,
    and once the Python callback has filled its outputs the page posts
    ``sequence_generated`` or ``generation_error`` back to the iframe.

    Returns:
        The configured ``gr.Blocks`` instance (not yet launched).
    """

    # CSS to hide backend controls
    css = """
    #hidden-controls { display: none !important; }
    """

    # JavaScript for handling communication between iframe and Gradio
    js = """
    function() {
        console.log('Initializing DNA-Diffusion Gradio interface...');

        // The slot machine lives in an iframe rendered by the gr.HTML component
        const findSlotFrame = () => document.querySelector('#dna-frame iframe');

        // Set up message listener to receive requests from iframe
        window.addEventListener('message', function(event) {
            const data = event.data;
            const frame = findSlotFrame();

            // Ignore malformed payloads and messages that do not come from
            // our own iframe (any window may post messages to this page)
            if (!data || typeof data !== 'object') {
                return;
            }
            if (!frame || event.source !== frame.contentWindow) {
                return;
            }

            console.log('Parent received message:', data);

            if (data.type === 'generate_request') {
                console.log('Triggering generation for cell type:', data.cellType);

                // Update the hidden cell type input
                const radioInputs = document.querySelectorAll('#cell-type-input input[type="radio"]');
                radioInputs.forEach(input => {
                    if (input.value === data.cellType) {
                        input.checked = true;
                        // Trigger change event
                        input.dispatchEvent(new Event('change'));
                    }
                });

                // Small delay to ensure radio button update is processed
                setTimeout(() => {
                    const button = document.querySelector('#generate-btn');
                    if (button) {
                        button.click();
                    } else {
                        // Report back so the iframe stops spinning instead of waiting forever
                        console.error('Could not find generate button');
                        frame.contentWindow.postMessage({
                            type: 'generation_error',
                            error: 'Generate control not found'
                        }, '*');
                    }
                }, 100);
            }
        });

        // Function to send sequence to iframe
        window.sendSequenceToIframe = function(sequence, metadata) {
            console.log('Sending sequence to iframe:', sequence);
            const iframe = findSlotFrame();
            if (iframe && iframe.contentWindow) {
                try {
                    const meta = JSON.parse(metadata);
                    if (meta.error) {
                        iframe.contentWindow.postMessage({
                            type: 'generation_error',
                            error: meta.error
                        }, '*');
                    } else {
                        iframe.contentWindow.postMessage({
                            type: 'sequence_generated',
                            sequence: sequence,
                            metadata: meta
                        }, '*');
                    }
                } catch (e) {
                    console.error('Failed to parse metadata:', e);
                    // If parsing fails, still send the sequence
                    iframe.contentWindow.postMessage({
                        type: 'sequence_generated',
                        sequence: sequence,
                        metadata: {}
                    }, '*');
                }
            } else {
                console.error('Could not find iframe');
            }
        };
    }
    """

    with gr.Blocks(css=css, js=js, theme=gr.themes.Base()) as demo:

        # Hidden controls for backend processing
        # NOTE(review): indentation was not recoverable from the reviewed
        # copy; the two output textboxes are assumed to belong to this hidden
        # column. Confirm against the original file.
        with gr.Column(elem_id="hidden-controls", visible=False):
            cell_type_input = gr.Radio(
                ["K562", "GM12878", "HepG2"],
                value="K562",
                label="Cell Type",
                elem_id="cell-type-input"
            )
            guidance_input = gr.Slider(
                minimum=1.0,
                maximum=10.0,
                value=1.0,
                step=0.5,
                label="Guidance Scale",
                elem_id="guidance-input"
            )
            generate_btn = gr.Button("Generate", elem_id="generate-btn")

            sequence_output = gr.Textbox(label="Sequence", elem_id="sequence-output")
            metadata_output = gr.Textbox(label="Metadata", elem_id="metadata-output")

        # Main interface - the slot machine in an iframe.
        # The page is embedded through srcdoc, so it must be escaped to
        # survive inside a double-quoted HTML attribute.
        escaped_html = html.escape(SLOT_MACHINE_HTML, quote=True)
        iframe_html = f'<iframe srcdoc="{escaped_html}" style="width: 100%; height: 100vh; border: none;"></iframe>'

        gr.HTML(
            iframe_html,
            elem_id="dna-frame"
        )

        # Wire up the generation: run the Python callback, then hand its two
        # outputs to the browser-side function that forwards them to the iframe.
        generate_btn.click(
            fn=app.handle_generation_request,
            inputs=[cell_type_input, guidance_input],
            outputs=[sequence_output, metadata_output]
        ).then(
            fn=None,
            inputs=[sequence_output, metadata_output],
            outputs=None,
            js="(seq, meta) => sendSequenceToIframe(seq, meta)"
        )

        # Initialize model on load
        demo.load(
            fn=app.initialize_model,
            inputs=None,
            outputs=None
        )

    return demo
231
+
232
# Launch the app
if __name__ == "__main__":
    import argparse

    # Take the defaults from Gradio's standard environment variables so that
    # hosted deployments which set them are honoured. Values passed
    # explicitly to launch() take precedence over those variables, so a
    # hard-coded "127.0.0.1" default would silently override them.
    default_host = os.environ.get("GRADIO_SERVER_NAME", "127.0.0.1")
    default_port = int(os.environ.get("GRADIO_SERVER_PORT", "7860"))

    # Parse any command line arguments
    parser = argparse.ArgumentParser(description="DNA-Diffusion Gradio App")
    parser.add_argument("--share", action="store_true", help="Create a public shareable link")
    parser.add_argument("--port", type=int, default=default_port, help="Port to run the app on")
    parser.add_argument("--host", type=str, default=default_host, help="Host to run the app on")
    args = parser.parse_args()

    # Build the UI only after argument parsing, so --help and invalid flags
    # exit immediately without constructing the interface.
    demo = create_demo()

    logger.info(f"Starting DNA-Diffusion Gradio app on {args.host}:{args.port}")

    demo.launch(
        share=args.share,
        server_name=args.host,
        server_port=args.port,
        inbrowser=True
    )
dna-slot-machine.html ADDED
@@ -0,0 +1,726 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <!DOCTYPE html>
2
+ <html lang="en">
3
+ <head>
4
+ <meta charset="UTF-8">
5
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
6
+ <title>DNA Slot Machine</title>
7
+ <style>
8
+ * {
9
+ margin: 0;
10
+ padding: 0;
11
+ box-sizing: border-box;
12
+ }
13
+
14
+ body {
15
+ background: #0a0a0a;
16
+ color: #fff;
17
+ font-family: 'Courier New', monospace;
18
+ overflow-x: hidden;
19
+ display: flex;
20
+ flex-direction: column;
21
+ align-items: center;
22
+ justify-content: center;
23
+ min-height: 100vh;
24
+ position: relative;
25
+ }
26
+
27
+ body::before {
28
+ content: '';
29
+ position: fixed;
30
+ top: 0;
31
+ left: 0;
32
+ width: 100%;
33
+ height: 100%;
34
+ background-image:
35
+ repeating-linear-gradient(
36
+ 0deg,
37
+ transparent 0px,
38
+ rgba(255,255,255,0.08) 1px,
39
+ transparent 1px,
40
+ transparent 2px
41
+ ),
42
+ repeating-linear-gradient(
43
+ 90deg,
44
+ transparent 0px,
45
+ rgba(0,0,0,0.05) 1px,
46
+ transparent 1px,
47
+ transparent 2px
48
+ ),
49
+ repeating-linear-gradient(
50
+ 45deg,
51
+ transparent 0px,
52
+ rgba(255,255,255,0.03) 1px,
53
+ transparent 2px,
54
+ transparent 3px
55
+ ),
56
+ repeating-linear-gradient(
57
+ -45deg,
58
+ transparent 0px,
59
+ rgba(0,0,0,0.03) 1px,
60
+ transparent 2px,
61
+ transparent 3px
62
+ );
63
+ background-size: 2px 2px, 2px 2px, 3px 3px, 3px 3px;
64
+ pointer-events: none;
65
+ z-index: 1;
66
+ opacity: 0.8;
67
+ animation: staticNoise 0.1s steps(8) infinite;
68
+ }
69
+
70
+ body::after {
71
+ content: '';
72
+ position: fixed;
73
+ top: 0;
74
+ left: 0;
75
+ width: 100%;
76
+ height: 100%;
77
+ background:
78
+ radial-gradient(circle at 17% 23%, rgba(255,255,255,0.1) 0px, transparent 1px),
79
+ radial-gradient(circle at 67% 71%, rgba(0,0,0,0.08) 0px, transparent 1px),
80
+ radial-gradient(circle at 41% 57%, rgba(255,255,255,0.06) 0px, transparent 1px),
81
+ radial-gradient(circle at 89% 13%, rgba(0,0,0,0.07) 0px, transparent 1px),
82
+ radial-gradient(circle at 23% 89%, rgba(255,255,255,0.05) 0px, transparent 1px);
83
+ background-size: 3px 3px, 2px 2px, 4px 4px, 2px 2px, 3px 3px;
84
+ pointer-events: none;
85
+ z-index: 1;
86
+ animation: staticNoise 0.15s steps(10) infinite reverse;
87
+ }
88
+
89
+ @keyframes staticNoise {
90
+ 0%, 100% { transform: translate(0, 0); }
91
+ 10% { transform: translate(-1px, -1px); }
92
+ 20% { transform: translate(1px, 0px); }
93
+ 30% { transform: translate(0px, 1px); }
94
+ 40% { transform: translate(-1px, 1px); }
95
+ 50% { transform: translate(1px, -1px); }
96
+ 60% { transform: translate(-1px, 0px); }
97
+ 70% { transform: translate(0px, -1px); }
98
+ 80% { transform: translate(1px, 1px); }
99
+ 90% { transform: translate(-1px, -1px); }
100
+ }
101
+
102
+ .machine-container {
103
+ background: linear-gradient(145deg, #1a1a1a, #2d2d2d);
104
+ border-radius: 20px;
105
+ padding: 20px;
106
+ padding-right: 100px;
107
+ box-shadow: 0 20px 40px rgba(0,0,0,0.5),
108
+ inset 0 2px 10px rgba(255,255,255,0.1);
109
+ width: 95vw;
110
+ max-width: 1400px;
111
+ position: relative;
112
+ z-index: 2;
113
+ }
114
+
115
+ .title {
116
+ text-align: center;
117
+ font-size: 2.5rem;
118
+ margin-bottom: 20px;
119
+ background: linear-gradient(45deg, #00ff88, #00ffff, #ff00ff);
120
+ -webkit-background-clip: text;
121
+ -webkit-text-fill-color: transparent;
122
+ text-shadow: 0 0 30px rgba(0,255,136,0.5);
123
+ font-weight: bold;
124
+ letter-spacing: 0.1em;
125
+ }
126
+
127
+ .cell-type-selector {
128
+ display: flex;
129
+ align-items: center;
130
+ justify-content: center;
131
+ gap: 20px;
132
+ margin-bottom: 20px;
133
+ }
134
+
135
+ .cell-type-label {
136
+ font-size: 1.2rem;
137
+ color: #ccc;
138
+ }
139
+
140
+ .radio-group {
141
+ display: flex;
142
+ gap: 20px;
143
+ }
144
+
145
+ .radio-label {
146
+ display: flex;
147
+ align-items: center;
148
+ gap: 8px;
149
+ cursor: pointer;
150
+ font-size: 1.1rem;
151
+ color: #fff;
152
+ transition: color 0.3s ease;
153
+ }
154
+
155
+ .radio-label:hover {
156
+ color: #00ff88;
157
+ }
158
+
159
+ .radio-label input[type="radio"] {
160
+ width: 18px;
161
+ height: 18px;
162
+ accent-color: #00ff88;
163
+ cursor: pointer;
164
+ }
165
+
166
+ .reels-container {
167
+ background: #000;
168
+ border: 3px solid #333;
169
+ border-radius: 10px;
170
+ padding: 20px;
171
+ max-width: 100%;
172
+ position: relative;
173
+ box-shadow: inset 0 0 20px rgba(0,0,0,0.5);
174
+ overflow: visible;
175
+ }
176
+
177
+ .reels-wrapper {
178
+ display: flex;
179
+ gap: 1px;
180
+ min-width: fit-content;
181
+ padding: 5px 0;
182
+ justify-content: center;
183
+ flex-wrap: wrap;
184
+ max-width: 1200px;
185
+ margin: 0 auto;
186
+ }
187
+
188
+ .reel {
189
+ width: 18px;
190
+ height: 40px;
191
+ background: #ffffff;
192
+ border: 1px solid #ddd;
193
+ border-radius: 2px;
194
+ overflow: hidden;
195
+ position: relative;
196
+ box-shadow: inset 0 0 3px rgba(0,0,0,0.1);
197
+ }
198
+
199
+ .reel-strip {
200
+ position: absolute;
201
+ width: 100%;
202
+ transition: transform 0.5s ease-out;
203
+ }
204
+
205
+ .nucleotide {
206
+ height: 40px;
207
+ display: flex;
208
+ align-items: center;
209
+ justify-content: center;
210
+ font-size: 0.9rem;
211
+ font-weight: bold;
212
+ background: #ffffff;
213
+ }
214
+
215
+ .nucleotide.A { color: #00ff00; }
216
+ .nucleotide.T { color: #ff0000; }
217
+ .nucleotide.C { color: #0000ff; }
218
+ .nucleotide.G { color: #ffa500; }
219
+
220
+ .controls {
221
+ display: flex;
222
+ flex-direction: column;
223
+ align-items: center;
224
+ gap: 20px;
225
+ margin-top: 30px;
226
+ }
227
+
228
+ .spin-button {
229
+ background: #4a4a4a;
230
+ border: none;
231
+ padding: 20px 60px;
232
+ font-size: 1.5rem;
233
+ font-weight: bold;
234
+ border-radius: 50px;
235
+ cursor: pointer;
236
+ text-transform: uppercase;
237
+ letter-spacing: 2px;
238
+ box-shadow: 0 10px 20px rgba(0,0,0,0.5);
239
+ transition: all 0.3s ease;
240
+ color: #fff;
241
+ text-shadow: 0 2px 4px rgba(0,0,0,0.3);
242
+ position: relative;
243
+ overflow: hidden;
244
+ }
245
+
246
+ .spin-button::before {
247
+ content: '';
248
+ position: absolute;
249
+ top: 0;
250
+ left: 0;
251
+ width: 100%;
252
+ height: 100%;
253
+ background-image:
254
+ radial-gradient(circle at 20% 30%, #00ff00 0px, transparent 2px),
255
+ radial-gradient(circle at 80% 70%, #ff0000 0px, transparent 2px),
256
+ radial-gradient(circle at 50% 50%, #0000ff 0px, transparent 2px),
257
+ radial-gradient(circle at 30% 80%, #ffa500 0px, transparent 2px),
258
+ radial-gradient(circle at 70% 20%, #00ff00 0px, transparent 2px),
259
+ radial-gradient(circle at 10% 60%, #ff0000 0px, transparent 2px),
260
+ radial-gradient(circle at 90% 40%, #0000ff 0px, transparent 2px),
261
+ radial-gradient(circle at 40% 10%, #ffa500 0px, transparent 2px);
262
+ background-size: 20px 20px, 25px 25px, 30px 30px, 15px 15px,
263
+ 18px 18px, 22px 22px, 28px 28px, 24px 24px;
264
+ opacity: 0.25;
265
+ animation: nucleotideNoise 0.8s steps(6) infinite;
266
+ }
267
+
268
+ .spin-button::after {
269
+ content: '';
270
+ position: absolute;
271
+ top: 0;
272
+ left: 0;
273
+ width: 100%;
274
+ height: 100%;
275
+ background-image:
276
+ radial-gradient(circle at 60% 40%, #00ff00 0px, transparent 1px),
277
+ radial-gradient(circle at 25% 75%, #ff0000 0px, transparent 1px),
278
+ radial-gradient(circle at 85% 15%, #0000ff 0px, transparent 1px),
279
+ radial-gradient(circle at 15% 25%, #ffa500 0px, transparent 1px);
280
+ background-size: 10px 10px, 12px 12px, 14px 14px, 16px 16px;
281
+ opacity: 0.2;
282
+ animation: nucleotideNoise 1.2s steps(8) infinite reverse;
283
+ }
284
+
285
+ @keyframes nucleotideNoise {
286
+ 0% { transform: translate(0, 0) scale(1); }
287
+ 16% { transform: translate(-2px, 1px) scale(1.02); }
288
+ 33% { transform: translate(1px, -2px) scale(0.98); }
289
+ 50% { transform: translate(-1px, 2px) scale(1.01); }
290
+ 66% { transform: translate(2px, -1px) scale(0.99); }
291
+ 83% { transform: translate(-2px, -2px) scale(1.02); }
292
+ 100% { transform: translate(1px, 1px) scale(1); }
293
+ }
294
+
295
+ .spin-button span {
296
+ position: relative;
297
+ z-index: 2;
298
+ }
299
+
300
+ .spin-button:hover {
301
+ transform: translateY(-2px);
302
+ box-shadow: 0 15px 30px rgba(0,0,0,0.6);
303
+ background: #5a5a5a;
304
+ }
305
+
306
+ .spin-button:hover::before {
307
+ opacity: 0.35;
308
+ animation-duration: 0.4s;
309
+ }
310
+
311
+ .spin-button:active {
312
+ transform: translateY(0);
313
+ }
314
+
315
+ .spin-button:disabled {
316
+ background: #444;
317
+ cursor: not-allowed;
318
+ box-shadow: none;
319
+ }
320
+
321
+ .sequence-display {
322
+ background: #0a0a0a;
323
+ border: 2px solid #333;
324
+ border-radius: 10px;
325
+ padding: 25px 30px 15px 30px;
326
+ font-family: 'Courier New', monospace;
327
+ font-size: 0.9rem;
328
+ letter-spacing: 1px;
329
+ width: 100%;
330
+ max-width: 1200px;
331
+ text-align: left;
332
+ word-wrap: break-word;
333
+ line-height: 1.5;
334
+ position: relative;
335
+ margin: 0 auto;
336
+ }
337
+
338
+ .sequence-display::before {
339
+ content: 'SYNTHETIC REGULATORY ELEMENT';
340
+ position: absolute;
341
+ top: -10px;
342
+ left: 50%;
343
+ transform: translateX(-50%);
344
+ background: #0a0a0a;
345
+ padding: 0 15px;
346
+ font-size: 0.7rem;
347
+ color: #00ff88;
348
+ letter-spacing: 2px;
349
+ white-space: nowrap;
350
+ }
351
+
352
+ .info {
353
+ text-align: center;
354
+ margin-top: 20px;
355
+ color: #888;
356
+ font-size: 0.9rem;
357
+ }
358
+
359
+ .lab-credit {
360
+ text-align: center;
361
+ margin-top: 15px;
362
+ font-size: 1.1rem;
363
+ }
364
+
365
+ .lab-credit a {
366
+ color: #00ff88;
367
+ text-decoration: none;
368
+ font-weight: bold;
369
+ letter-spacing: 1px;
370
+ transition: all 0.3s ease;
371
+ padding: 5px 15px;
372
+ border: 1px solid transparent;
373
+ border-radius: 20px;
374
+ }
375
+
376
+ .lab-credit a:hover {
377
+ color: #fff;
378
+ border-color: #00ff88;
379
+ box-shadow: 0 0 10px rgba(0,255,136,0.5);
380
+ text-shadow: 0 0 5px rgba(0,255,136,0.5);
381
+ }
382
+
383
+ @keyframes pulse {
384
+ 0% { opacity: 0.5; }
385
+ 50% { opacity: 1; }
386
+ 100% { opacity: 0.5; }
387
+ }
388
+
389
+ .spinning {
390
+ animation: pulse 0.5s infinite;
391
+ }
392
+
393
+ .winning-flash {
394
+ animation: winFlash 1s ease-out;
395
+ }
396
+
397
+ @keyframes winFlash {
398
+ 0%, 100% { background-color: transparent; }
399
+ 50% { background-color: rgba(0,255,136,0.2); }
400
+ }
401
+
402
+ .lever-container {
403
+ position: absolute;
404
+ right: -70px;
405
+ top: 50%;
406
+ transform: translateY(-50%);
407
+ z-index: 3;
408
+ width: 60px;
409
+ height: 200px;
410
+ }
411
+
412
+ .lever {
413
+ width: 100%;
414
+ height: 100%;
415
+ position: relative;
416
+ cursor: pointer;
417
+ }
418
+
419
+ .lever-mount {
420
+ position: absolute;
421
+ top: 90px;
422
+ left: -10px;
423
+ width: 40px;
424
+ height: 60px;
425
+ background: linear-gradient(180deg, #555, #333);
426
+ border-radius: 5px 0 0 5px;
427
+ box-shadow:
428
+ 0 3px 10px rgba(0,0,0,0.3),
429
+ inset 0 1px 2px rgba(255,255,255,0.1);
430
+ }
431
+
432
+ .lever-pivot {
433
+ position: absolute;
434
+ bottom: 30px;
435
+ left: 50%;
436
+ transform: translateX(-50%);
437
+ width: 30px;
438
+ height: 8px;
439
+ background: #888;
440
+ border-radius: 4px;
441
+ box-shadow: 0 2px 4px rgba(0,0,0,0.3);
442
+ }
443
+
444
+ .lever-arm {
445
+ position: absolute;
446
+ top: 40px;
447
+ left: 5px;
448
+ width: 10px;
449
+ height: 80px;
450
+ background: linear-gradient(180deg, #d0d0d0, #a0a0a0);
451
+ border-radius: 5px;
452
+ box-shadow: 0 2px 5px rgba(0,0,0,0.3);
453
+ transition: all 0.6s cubic-bezier(0.68, -0.55, 0.265, 1.55);
454
+ }
455
+
456
+ .lever-ball {
457
+ position: absolute;
458
+ top: -30px;
459
+ left: 50%;
460
+ transform: translateX(-50%);
461
+ width: 60px;
462
+ height: 60px;
463
+ background: radial-gradient(circle at 35% 35%, #ff8888, #ff4444, #cc0000);
464
+ border-radius: 50%;
465
+ box-shadow:
466
+ 0 5px 15px rgba(0,0,0,0.4),
467
+ inset -5px -5px 10px rgba(0,0,0,0.3),
468
+ inset 3px 3px 5px rgba(255,255,255,0.5);
469
+ }
470
+
471
+ .lever.pulled .lever-arm {
472
+ transform: translateY(80px);
473
+ height: 10px;
474
+ }
475
+
476
+ /* Continuous spinning animation for loading */
477
+ @keyframes continuousSpin {
478
+ from { transform: translateY(0); }
479
+ to { transform: translateY(-160px); }
480
+ }
481
+
482
+ .reel-strip.loading {
483
+ animation: continuousSpin 1s linear infinite;
484
+ }
485
+ </style>
486
+ </head>
487
+ <body>
488
+ <div class="machine-container">
489
+ <h1 class="title">DNA-DIFFUSION</h1>
490
+
491
+ <div class="cell-type-selector">
492
+ <label class="cell-type-label">Cell Type-Specific Generation:</label>
493
+ <div class="radio-group">
494
+ <label class="radio-label">
495
+ <input type="radio" name="cellType" value="K562" checked>
496
+ <span>K562</span>
497
+ </label>
498
+ <label class="radio-label">
499
+ <input type="radio" name="cellType" value="GM12878">
500
+ <span>GM12878</span>
501
+ </label>
502
+ <label class="radio-label">
503
+ <input type="radio" name="cellType" value="HepG2">
504
+ <span>HepG2</span>
505
+ </label>
506
+ </div>
507
+ </div>
508
+
509
+ <div class="reels-container" id="reelsContainer">
510
+ <div class="reels-wrapper" id="reelsWrapper"></div>
511
+ <div class="lever-container">
512
+ <div class="lever" id="lever">
513
+ <div class="lever-mount">
514
+ <div class="lever-pivot"></div>
515
+ </div>
516
+ <div class="lever-arm">
517
+ <div class="lever-ball"></div>
518
+ </div>
519
+ </div>
520
+ </div>
521
+ </div>
522
+
523
+ <div class="controls">
524
+ <button class="spin-button" id="spinButton"><span>GENERATE</span></button>
525
+ <div class="sequence-display" id="sequenceDisplay">
526
+ Press GENERATE to create sequence
527
+ </div>
528
+ </div>
529
+
530
+ <div class="info">
531
+ 200bp Regulatory Elements · Cell Type-Specific · Synthetic Biology
532
+ </div>
533
+
534
+ <div class="lab-credit">
535
+ <a href="https://pinellolab.org" target="_blank" rel="noopener noreferrer">
536
+ Pinello Lab
537
+ </a>
538
+ </div>
539
+ </div>
540
+
541
+ <script>
542
+ const NUCLEOTIDES = ['A', 'T', 'C', 'G'];
543
+ const REEL_COUNT = 200;
544
+ let TARGET_SEQUENCE = '';
545
+
546
+ let reels = [];
547
+ let isSpinning = false;
548
+
549
+ function generateRandomSequence() {
550
+ let sequence = '';
551
+ for (let i = 0; i < REEL_COUNT; i++) {
552
+ sequence += NUCLEOTIDES[Math.floor(Math.random() * 4)];
553
+ }
554
+ return sequence;
555
+ }
556
+
557
+ function createReel(index) {
558
+ const reel = document.createElement('div');
559
+ reel.className = 'reel';
560
+
561
+ const strip = document.createElement('div');
562
+ strip.className = 'reel-strip';
563
+
564
+ // Create multiple nucleotides for smooth spinning effect
565
+ for (let i = 0; i < 10; i++) {
566
+ NUCLEOTIDES.forEach(n => {
567
+ const nucleotide = document.createElement('div');
568
+ nucleotide.className = `nucleotide ${n}`;
569
+ nucleotide.textContent = n;
570
+ strip.appendChild(nucleotide);
571
+ });
572
+ }
573
+
574
+ reel.appendChild(strip);
575
+ return { element: reel, strip: strip, currentPosition: 0 };
576
+ }
577
+
578
+ function initializeReels() {
579
+ const wrapper = document.getElementById('reelsWrapper');
580
+ wrapper.innerHTML = '';
581
+ reels = [];
582
+
583
+ for (let i = 0; i < REEL_COUNT; i++) {
584
+ const reel = createReel(i);
585
+ reels.push(reel);
586
+ wrapper.appendChild(reel.element);
587
+
588
+ // Set initial position to show a random nucleotide
589
+ const randomIndex = Math.floor(Math.random() * 4);
590
+ const initialOffset = -randomIndex * 40;
591
+ reel.strip.style.transform = `translateY(${initialOffset}px)`;
592
+ reel.currentPosition = randomIndex * 40;
593
+ }
594
+ }
595
+
596
+ function startContinuousSpinning() {
597
+ reels.forEach((reel, index) => {
598
+ // Add continuous spinning animation
599
+ reel.strip.style.transition = 'none';
600
+ reel.strip.classList.add('loading');
601
+
602
+ // Add slight delay variation for visual effect
603
+ const delay = (index % 10) * 0.1;
604
+ reel.strip.style.animationDelay = `${delay}s`;
605
+ });
606
+ }
607
+
608
+ function stopAndShowSequence(sequence) {
609
+ TARGET_SEQUENCE = sequence;
610
+
611
+ reels.forEach((reel, index) => {
612
+ // Remove continuous spinning
613
+ reel.strip.classList.remove('loading');
614
+
615
+ // Calculate target position
616
+ const targetNucleotide = TARGET_SEQUENCE[index];
617
+ const targetIndex = NUCLEOTIDES.indexOf(targetNucleotide);
618
+ const finalPosition = targetIndex * 40;
619
+
620
+ // Set up the final positioning animation
621
+ setTimeout(() => {
622
+ reel.strip.style.transition = `transform ${1000 + index * 5}ms cubic-bezier(0.17, 0.67, 0.12, 0.99)`;
623
+ reel.strip.style.transform = `translateY(${-finalPosition}px)`;
624
+ reel.currentPosition = finalPosition;
625
+ }, index * 2);
626
+ });
627
+
628
+ // Show the complete sequence after animation
629
+ setTimeout(() => {
630
+ const container = document.getElementById('reelsContainer');
631
+ const display = document.getElementById('sequenceDisplay');
632
+ const button = document.getElementById('spinButton');
633
+ const lever = document.getElementById('lever');
634
+
635
+ container.classList.remove('spinning');
636
+ container.classList.add('winning-flash');
637
+
638
+ display.innerHTML = `<strong>Generated Sequence:</strong><br>${TARGET_SEQUENCE}`;
639
+ button.disabled = false;
640
+ isSpinning = false;
641
+
642
+ // Release the lever
643
+ lever.classList.remove('pulled');
644
+
645
+ setTimeout(() => {
646
+ container.classList.remove('winning-flash');
647
+ }, 1000);
648
+ }, 1500);
649
+ }
650
+
651
+ function startGeneration() {
652
+ if (isSpinning) return;
653
+
654
+ isSpinning = true;
655
+ const button = document.getElementById('spinButton');
656
+ const display = document.getElementById('sequenceDisplay');
657
+ const container = document.getElementById('reelsContainer');
658
+ const lever = document.getElementById('lever');
659
+
660
+ // Pull the lever
661
+ lever.classList.add('pulled');
662
+
663
+ button.disabled = true;
664
+ display.textContent = 'Generating cell type-specific regulatory element...';
665
+ container.classList.add('spinning');
666
+
667
+ // Start continuous spinning immediately
668
+ startContinuousSpinning();
669
+
670
+ // Get selected cell type
671
+ const cellType = document.querySelector('input[name="cellType"]:checked').value;
672
+
673
+ // Send request to parent window
674
+ window.parent.postMessage({
675
+ type: 'generate_request',
676
+ cellType: cellType
677
+ }, '*');
678
+ }
679
+
680
+ // Initialize
681
+ initializeReels();
682
+
683
+ // Event listeners
684
+ document.getElementById('spinButton').addEventListener('click', startGeneration);
685
+
686
+ // Lever click functionality
687
+ document.getElementById('lever').addEventListener('click', function() {
688
+ if (!isSpinning) {
689
+ startGeneration();
690
+ }
691
+ });
692
+
693
+ // Listen for messages from parent window
694
+ window.addEventListener('message', (event) => {
695
+ if (event.data.type === 'sequence_generated') {
696
+ // Stop spinning and show the actual sequence
697
+ stopAndShowSequence(event.data.sequence);
698
+ } else if (event.data.type === 'generation_error') {
699
+ // Stop spinning and show error
700
+ reels.forEach(reel => {
701
+ reel.strip.classList.remove('loading');
702
+ });
703
+
704
+ const container = document.getElementById('reelsContainer');
705
+ const display = document.getElementById('sequenceDisplay');
706
+ const button = document.getElementById('spinButton');
707
+ const lever = document.getElementById('lever');
708
+
709
+ container.classList.remove('spinning');
710
+ display.innerHTML = '<strong style="color: #F44336;">Error:</strong> ' + event.data.error;
711
+ button.disabled = false;
712
+ isSpinning = false;
713
+ lever.classList.remove('pulled');
714
+ }
715
+ });
716
+
717
+ // Keyboard support
718
+ document.addEventListener('keydown', (e) => {
719
+ if (e.code === 'Space' && !isSpinning) {
720
+ e.preventDefault();
721
+ startGeneration();
722
+ }
723
+ });
724
+ </script>
725
+ </body>
726
+ </html>
dna_diffusion_model.py ADDED
@@ -0,0 +1,283 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ DNA-Diffusion Model Wrapper
3
+ Singleton class to handle model loading and sequence generation
4
+ """
5
+
6
import logging
import os
import sys
import time
from typing import Any, Dict, List, Optional

import numpy as np
import torch
13
+
14
+ logger = logging.getLogger(__name__)
15
+
16
class DNADiffusionModel:
    """Singleton wrapper for DNA-Diffusion model.

    The first construction loads the pretrained UNet, builds the diffusion
    schedule, loads the sampling dataset metadata and runs one warm-up
    generation. Later constructions return the same instance without
    reloading anything.
    """
    _instance = None
    _initialized = False

    # Class-level defaults so readiness checks are safe on an instance whose
    # initialization has not run yet or has failed part-way.
    model = None
    diffusion = None
    train_data = None

    # Cell type mapping from simple names to dataset identifiers
    CELL_TYPE_MAPPING = {
        'K562': 'K562_ENCLB843GMH',
        'GM12878': 'GM12878_ENCLB441ZZZ',
        'HepG2': 'HepG2_ENCLB029COU',
        'hESCT0': 'hESCT0_ENCLB449ZZZ'
    }

    def __new__(cls):
        """Return the shared instance, creating it on first use."""
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self):
        """Initialize the model (only runs once due to singleton pattern)."""
        if not self._initialized:
            self._initialize()
            # Only marked after success, so a failed load is retried on the
            # next construction.
            self._initialized = True

    def _initialize(self):
        """Load model and setup components.

        Raises:
            Exception: whatever the import, download or dataset loading step
                raised; model state is reset to ``None`` before re-raising.
        """
        try:
            logger.info("Initializing DNA-Diffusion model...")

            # Add DNA-Diffusion to path
            dna_diffusion_path = os.path.join(os.path.dirname(__file__), 'DNA-Diffusion')
            if os.path.exists(dna_diffusion_path):
                sys.path.insert(0, os.path.join(dna_diffusion_path, 'src'))

            # Import DNA-Diffusion components (deferred until the path is set up)
            from dnadiffusion.models.pretrained_unet import PretrainedUNet
            from dnadiffusion.models.diffusion import Diffusion
            from dnadiffusion.data.dataloader import get_dataset_for_sampling

            # Load pretrained model from HuggingFace
            logger.info("Loading pretrained model from HuggingFace...")
            self.model = PretrainedUNet.from_pretrained("ssenan/DNA-Diffusion")

            # Move to GPU if available
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            logger.info(f"Using device: {self.device}")
            self.model = self.model.to(self.device)
            self.model.eval()

            # Initialize diffusion sampler with the model
            self.diffusion = Diffusion(
                model=self.model,
                timesteps=50,
                beta_start=0.0001,
                beta_end=0.2
            )

            # Ensure output_attention is set to False initially. The wrapped
            # inner model is looked up defensively so a wrapper without a
            # `.model` attribute does not abort initialization.
            if hasattr(self.model, 'output_attention'):
                self.model.output_attention = False
            inner_model = getattr(self.model, 'model', None)
            if inner_model is not None and hasattr(inner_model, 'output_attention'):
                inner_model.output_attention = False

            # Setup dataset for sampling
            data_path = os.path.join(dna_diffusion_path, "data/K562_hESCT0_HepG2_GM12878_12k_sequences_per_group.txt")
            saved_data_path = os.path.join(dna_diffusion_path, "data/encode_data.pkl")

            # Get dataset info
            train_data, val_data, cell_num_list, numeric_to_tag_dict = get_dataset_for_sampling(
                data_path=data_path,
                saved_data_path=saved_data_path,
                load_saved_data=True,
                debug=False,
                cell_types=None  # Load all cell types
            )

            # Store dataset info
            self.train_data = train_data
            self.val_data = val_data
            self.cell_num_list = cell_num_list
            self.numeric_to_tag_dict = numeric_to_tag_dict

            # Get available cell types
            self.available_cell_types = [numeric_to_tag_dict[num] for num in cell_num_list]
            logger.info(f"Available cell types: {self.available_cell_types}")

            # Warm up the model with a test generation
            logger.info("Warming up model...")
            self._warmup()

            logger.info("Model initialization complete!")

        except Exception as e:
            # logger.exception keeps the traceback alongside the message.
            logger.exception(f"Failed to initialize model: {str(e)}")
            # Reset everything is_ready() inspects so a half-built instance
            # never reports itself as usable.
            self.model = None
            self.diffusion = None
            self.train_data = None
            self.dataset = None
            raise

    def _warmup(self):
        """Warm up the model with a test generation (best effort)."""
        try:
            # Generate one sequence for the first mapped cell type
            if self.available_cell_types:
                cell_type = list(self.CELL_TYPE_MAPPING.keys())[0]
                self.generate(cell_type, guidance_scale=1.0)
        except Exception as e:
            # A failed warm-up must not prevent the model from being used.
            logger.warning(f"Warmup generation failed: {str(e)}")

    def is_ready(self) -> bool:
        """Check if model is loaded and ready."""
        return self.model is not None and self.diffusion is not None and self.train_data is not None

    def generate(self, cell_type: str, guidance_scale: float = 1.0) -> Dict[str, Any]:
        """
        Generate a DNA sequence for the specified cell type

        Args:
            cell_type: Simple cell type name (K562, GM12878, HepG2, hESCT0)
            guidance_scale: Guidance scale for generation (1.0-10.0)

        Returns:
            Dict with 'sequence' (200bp string) and 'metadata'

        Raises:
            RuntimeError: if the model is not initialized.
            ValueError: if the cell type is unknown, the guidance scale is
                out of range, or the cell type is absent from the dataset.
        """
        if not self.is_ready():
            raise RuntimeError("Model is not initialized")

        # Validate inputs
        if cell_type not in self.CELL_TYPE_MAPPING:
            raise ValueError(f"Invalid cell type: {cell_type}. Must be one of {list(self.CELL_TYPE_MAPPING.keys())}")

        if not 1.0 <= guidance_scale <= 10.0:
            raise ValueError(f"Guidance scale must be between 1.0 and 10.0, got {guidance_scale}")

        # Map to full cell type identifier
        full_cell_type = self.CELL_TYPE_MAPPING[cell_type]

        # Find the numeric index for this cell type
        tag_to_numeric = {tag: num for num, tag in self.numeric_to_tag_dict.items()}

        # Find matching cell type (case-insensitive partial match)
        wanted = full_cell_type.lower()
        cell_type_numeric = None
        for tag, num in tag_to_numeric.items():
            candidate = tag.lower()
            if wanted in candidate or candidate in wanted:
                cell_type_numeric = num
                logger.info(f"Matched '{full_cell_type}' to '{tag}'")
                break

        if cell_type_numeric is None:
            raise ValueError(f"Cell type {full_cell_type} not found in dataset. Available: {list(self.numeric_to_tag_dict.values())}")

        try:
            logger.info(f"Generating sequence for {cell_type} (guidance={guidance_scale})...")
            start_time = time.time()

            # For now, use simple generation without classifier-free guidance
            # TODO: Fix classifier-free guidance implementation
            sequence = self._generate_simple(cell_type_numeric, guidance_scale)

            generation_time = time.time() - start_time
            logger.info(f"Generated sequence in {generation_time:.2f}s")

            return {
                'sequence': sequence,
                'metadata': {
                    'cell_type': cell_type,
                    'full_cell_type': full_cell_type,
                    'guidance_scale': guidance_scale,
                    'generation_time': generation_time,
                    'sequence_length': len(sequence)
                }
            }

        except Exception as e:
            logger.exception(f"Generation failed: {str(e)}")
            raise

    def _generate_simple(self, cell_type_idx: int, guidance_scale: float) -> str:
        """Run the class-conditioned denoising loop and decode the result.

        Classifier-free guidance is not implemented: any ``guidance_scale``
        other than 1.0 logs a warning and the same unguided loop is used.
        """
        if guidance_scale != 1.0:
            # The guided sampler needs context-mask handling that is not
            # wired up yet, so fall back to unguided sampling.
            logger.warning(f"Guidance scale {guidance_scale} not fully implemented, using simple generation")

        with torch.no_grad():
            # Create initial noise
            img = torch.randn((1, 1, 4, 200), device=self.device)

            # The class label is the same at every timestep; build it once.
            classes = torch.tensor([cell_type_idx], device=self.device, dtype=torch.long)

            # Simple denoising loop without guidance
            for i in reversed(range(self.diffusion.timesteps)):
                t = torch.full((1,), i, device=self.device, dtype=torch.long)

                # Get model prediction with classes
                noise_pred = self.model(img, time=t, classes=classes)

                # Schedule coefficients for this timestep
                betas_t = self.diffusion.betas[i]
                sqrt_one_minus_alphas_cumprod_t = self.diffusion.sqrt_one_minus_alphas_cumprod[i]
                sqrt_recip_alphas_t = self.diffusion.sqrt_recip_alphas[i]

                # Mean of the reverse step given the predicted noise
                model_mean = sqrt_recip_alphas_t * (img - betas_t * noise_pred / sqrt_one_minus_alphas_cumprod_t)

                if i == 0:
                    # Final step: no noise is added
                    img = model_mean
                else:
                    posterior_variance_t = self.diffusion.posterior_variance[i]
                    noise = torch.randn_like(img)
                    img = model_mean + torch.sqrt(posterior_variance_t) * noise

            final_image = img[0]  # Remove batch dimension

        # Convert to sequence
        final_array = final_image.cpu().numpy()
        return self._array_to_sequence(final_array)

    def _array_to_sequence(self, array: np.ndarray) -> str:
        """Convert model output array to DNA sequence string.

        Args:
            array: Array whose leading axis has size 1 and whose next axis
                indexes the nucleotides in A, C, G, T order.

        Returns:
            One nucleotide letter per position, chosen by argmax.
        """
        # Get nucleotide mapping
        nucleotides = ['A', 'C', 'G', 'T']

        # array shape is (1, 4, 200) - channels, nucleotides, sequence_length
        # Reshape to (4, 200) and get argmax along nucleotide dimension
        array = array.squeeze(0)  # Remove channel dimension -> (4, 200)
        indices = np.argmax(array, axis=0)  # Get max nucleotide for each position

        # Convert indices to nucleotides
        return ''.join(nucleotides[int(idx)] for idx in indices)

    def get_model_info(self) -> Dict[str, Any]:
        """Get information about the loaded model."""
        if not self.is_ready():
            return {'status': 'not_initialized'}

        return {
            'status': 'ready',
            'device': str(self.device),
            'cell_types': list(self.CELL_TYPE_MAPPING.keys()),
            'full_cell_types': self.available_cell_types,
            'model_name': 'ssenan/DNA-Diffusion',
            'sequence_length': 200,
            'guidance_scale_range': [1.0, 10.0]
        }
267
+
268
+
269
+ # Convenience functions for direct usage
270
+ _model_instance = None
271
+
272
def get_model() -> DNADiffusionModel:
    """Return the module-level model instance, building it on first call."""
    global _model_instance
    if _model_instance is not None:
        return _model_instance
    # First call: construct the wrapper and cache it for later callers.
    _model_instance = DNADiffusionModel()
    return _model_instance
278
+
279
def generate_sequence(cell_type: str, guidance_scale: float = 1.0) -> str:
    """Generate a DNA sequence and return only the sequence string.

    Thin convenience wrapper: fetches the shared model, runs ``generate``
    and drops the metadata part of the result.
    """
    return get_model().generate(cell_type, guidance_scale)['sequence']
requirements.txt ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Core dependencies for Gradio DNA-Diffusion app
2
+ gradio>=4.44.0
3
+ torch>=2.0.0
4
+ transformers>=4.30.0
5
+ numpy>=1.24.0
6
+
7
+ # Model loading and utilities
8
+ huggingface-hub>=0.16.0
9
+ safetensors>=0.3.0
10
+ accelerate>=0.20.0
11
+
12
+ # Note: DNA-Diffusion itself should be installed separately using uv:
13
+ # git clone https://github.com/pinellolab/DNA-Diffusion.git
14
+ # cd DNA-Diffusion && uv sync