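"""FastAPI WebSocket server for a neural "computer" demo.

Streams mouse/keyboard events from a browser canvas into a latent-diffusion
world model and returns each generated screen frame as a base64 PNG. Supports
per-client sampling settings, inactivity timeouts, input coalescing, and
interaction logging.
"""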
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles
from typing import Any, Dict, List, Tuple
import numpy as np
from PIL import Image
import base64
import io
import json
import asyncio
from utils import initialize_model
import torch
import os
import time
from ldm.models.diffusion.ddpm import LatentDiffusion, DDIMSampler
import concurrent.futures

# Allow TF32 matmul/convolution kernels for faster GPU inference (minor precision trade-off)
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


DEBUG_MODE = False
DEBUG_MODE_2 = False
NUM_MAX_FRAMES = 1
TIMESTEPS = 1000
SCREEN_WIDTH = 512
SCREEN_HEIGHT = 384
NUM_SAMPLING_STEPS = 32
USE_RNN = False

MODEL_NAME = "yuntian-deng/computer-model-ss005-cont-lr2e5-384k"

MODEL_NAME = "yuntian-deng/computer-model-noss-forsure"

MODEL_NAME = "yuntian-deng/computer-model-ss005-cont-lr2e5-computecanada-2k"

MODEL_NAME = "yuntian-deng/computer-model-ss005-cont-lr2e5-computecanada-10k"

MODEL_NAME = "yuntian-deng/computer-model-ss005-cont-lr2e5-computecanada-54k"



MODEL_NAME = "yuntian-deng/computer-model-ss005-cont-lr2e5-computecanada-newnewd-unfreezernn-160k"
MODEL_NAME = "yuntian-deng/computer-model-ss005-cont-lr2e5-computecanada-newnewd-freezernn-origunet-nospatial-368k"
MODEL_NAME = "yuntian-deng/computer-model-ss005-cont-lr2e5-computecanada-newnewd-unfreezernn-198k"
MODEL_NAME = "yuntian-deng/computer-model-ss005-cont-lr2e5-computecanada-newnewd-freezernn-origunet-nospatial-674k"
MODEL_NAME = "yuntian-deng/computer-model-ss005-cont-lr2e5-computecanada-newnewd-freezernn-origunet-nospatial-online-74k"
MODEL_NAME = "yuntian-deng/computer-model-s-newnewd-freezernn-origunet-nospatial-online-online-70k"
MODEL_NAME = "yuntian-deng/computer-model-ss005-cont-lr2e5-computecanada-newnewd-freezernn-origunet-nospatial-online-x0-46k"
MODEL_NAME = "yuntian-deng/computer-model-s-newnewd-freezernn-origunet-nospatial-online-x0-142k"
MODEL_NAME = "yuntian-deng/computer-model-s-newnewd-freezernn-origunet-nospatial-online-x0-338k"
MODEL_NAME = "yuntian-deng/computer-model-s-newnewd-freezernn-origunet-nospatial-online-ddpm32-x0-140k"
MODEL_NAME = "yuntian-deng/computer-model-s-newnewd-freezernn-origunet-nospatial-online-ddpm32-eps-144k"
MODEL_NAME = "yuntian-deng/computer-model-s-newnewd-freezernn-origunet-nospatial-online-x0-joint-onlineonly-70k"
MODEL_NAME = "yuntian-deng/computer-model-s-newnewd-freezernn-origunet-nospatial-online-joint-onlineonly-eps22-40k"
MODEL_NAME = "yuntian-deng/computer-model-s-newnewd-freezernn-origunet-nospatial-online-x0-joint-onlineonly-22-38k"
MODEL_NAME = "yuntian-deng/computer-model-s-newnewd-freezernn-origunet-nospatial-online-x0-joint-onlineonly-222-42k"
MODEL_NAME = "yuntian-deng/computer-model-s-newnewd-freezernn-origunet-nospatial-online-x0-joint-onlineonly-2222-70k"
MODEL_NAME = "yuntian-deng/computer-model-s-newnewd-freezernn-origunet-nospatial-online-x0-joint-onlineonly-222222-48k"
MODEL_NAME = "yuntian-deng/computer-model-s-newnewd-freezernn-origunet-nospatial-online-x0-joint-onlineonly-222222k7-06k"
MODEL_NAME = "yuntian-deng/computer-model-s-newnewd-freezernn-origunet-nospatial-online-x0-joint-onlineonly-222222k7-114k"
MODEL_NAME = "yuntian-deng/computer-model-s-newnewd-freezernn-origunet-nospatial-online-x0-joint-onlineonly-222222k7-136k"
#MODEL_NAME = "yuntian-deng/computer-model-s-newnewd-freezernn-origunet-nospatial-online-x0-joint-onlineonly-222222k7-184k"
#MODEL_NAME = "yuntian-deng/computer-model-s-newnewd-freezernn-origunet-nospatial-online-x0-joint-onlineonly-222222k7-272k"
#MODEL_NAME = "yuntian-deng/computer-model-s-newnewd-freezernn-origunet-nospatial-online-x0-joint-onlineonly-222222k7-272k"
#MODEL_NAME = "yuntian-deng/computer-model-s-newnewd-freezernn-origunet-nospatial-online-x0-joint-oo-eps222222k72-270k"
MODEL_NAME = "yuntian-deng/computer-model-s-newnewd-freezernn-origunet-nospatial-online-x0-joint-onlineonly-222222k72-108k"


print(f'Settings: DEBUG_MODE: {DEBUG_MODE}, DEBUG_MODE_2: {DEBUG_MODE_2}, NUM_MAX_FRAMES: {NUM_MAX_FRAMES}, NUM_SAMPLING_STEPS: {NUM_SAMPLING_STEPS}, MODEL_NAME: {MODEL_NAME}')

with open('latent_stats.json', 'r') as f:
    latent_stats = json.load(f)
DATA_NORMALIZATION = {'mean': torch.tensor(latent_stats['mean']).to(device), 'std': torch.tensor(latent_stats['std']).to(device)}
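# Latents are 16 channels at 1/8 screen resolution (the first-stage autoencoder downsamples by 8x)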
LATENT_DIMS = (16, SCREEN_HEIGHT // 8, SCREEN_WIDTH // 8)

# Initialize the model at application startup (earlier configurations kept below for reference)
#model = initialize_model("config_csllm.yaml", "yuntian-deng/computer-model")
#model = initialize_model("config_rnn.yaml", "yuntian-deng/computer-model")
#model = initialize_model("config_final_model.yaml", "yuntian-deng/computer-model-noss")
#model = initialize_model("config_final_model.yaml", "yuntian-deng/computer-model")

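# Select the model config from checkpoint-name conventions: 'origunet'/'nospatial'
# variants, x0 vs eps parameterization, and 'ddpm32' checkpoints (which also
# shorten the diffusion schedule to TIMESTEPS = 32).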
if 'origunet' in MODEL_NAME:
    if 'x0' in MODEL_NAME and 'eps' not in MODEL_NAME:
        if 'ddpm32' in MODEL_NAME:
            TIMESTEPS = 32
            model = initialize_model("config_final_model_origunet_nospatial_x0_ddpm32.yaml", MODEL_NAME)
        else:
            model = initialize_model("config_final_model_origunet_nospatial_x0.yaml", MODEL_NAME)
    else:
        if 'ddpm32' in MODEL_NAME:
            TIMESTEPS = 32
            model = initialize_model("config_final_model_origunet_nospatial_ddpm32.yaml", MODEL_NAME)
        else:
            model = initialize_model("config_final_model_origunet_nospatial.yaml", MODEL_NAME)
else:
    model = initialize_model("config_final_model.yaml", MODEL_NAME)


model = model.to(device)
#model = torch.compile(model)
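# All-zero latent standing in for the "previous frame" before any real frame
# exists, normalized the same way as real latents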
padding_image = torch.zeros(*LATENT_DIMS).unsqueeze(0).to(device)
padding_image = (padding_image - DATA_NORMALIZATION['mean'].view(1, -1, 1, 1)) / DATA_NORMALIZATION['std'].view(1, -1, 1, 1)

# Valid keyboard inputs
KEYS = ['\t', '\n', '\r', ' ', '!', '"', '#', '$', '%', '&', "'", '(',
        ')', '*', '+', ',', '-', '.', '/', '0', '1', '2', '3', '4', '5', '6', '7',
        '8', '9', ':', ';', '<', '=', '>', '?', '@', '[', '\\', ']', '^', '_', '`',
        'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o',
        'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', '{', '|', '}', '~',
        'accept', 'add', 'alt', 'altleft', 'altright', 'apps', 'backspace',
        'browserback', 'browserfavorites', 'browserforward', 'browserhome',
        'browserrefresh', 'browsersearch', 'browserstop', 'capslock', 'clear',
        'convert', 'ctrl', 'ctrlleft', 'ctrlright', 'decimal', 'del', 'delete',
        'divide', 'down', 'end', 'enter', 'esc', 'escape', 'execute', 'f1', 'f10',
        'f11', 'f12', 'f13', 'f14', 'f15', 'f16', 'f17', 'f18', 'f19', 'f2', 'f20',
        'f21', 'f22', 'f23', 'f24', 'f3', 'f4', 'f5', 'f6', 'f7', 'f8', 'f9',
        'final', 'fn', 'hanguel', 'hangul', 'hanja', 'help', 'home', 'insert', 'junja',
        'kana', 'kanji', 'launchapp1', 'launchapp2', 'launchmail',
        'launchmediaselect', 'left', 'modechange', 'multiply', 'nexttrack',
        'nonconvert', 'num0', 'num1', 'num2', 'num3', 'num4', 'num5', 'num6',
        'num7', 'num8', 'num9', 'numlock', 'pagedown', 'pageup', 'pause', 'pgdn',
        'pgup', 'playpause', 'prevtrack', 'print', 'printscreen', 'prntscrn',
        'prtsc', 'prtscr', 'return', 'right', 'scrolllock', 'select', 'separator',
        'shift', 'shiftleft', 'shiftright', 'sleep', 'space', 'stop', 'subtract', 'tab',
        'up', 'volumedown', 'volumemute', 'volumeup', 'win', 'winleft', 'winright', 'yen',
        'command', 'option', 'optionleft', 'optionright']

KEYMAPPING = {
    'arrowup': 'up',
    'arrowdown': 'down',
    'arrowleft': 'left',
    'arrowright': 'right',
    'meta': 'command',
    'contextmenu': 'apps',
    'control': 'ctrl',
}

INVALID_KEYS = ['f13', 'f14', 'f15', 'f16', 'f17', 'f18', 'f19', 'f20',
                'f21', 'f22', 'f23', 'f24', 'select', 'separator', 'execute']
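# Bidirectional mapping between key names and indices in the model's key-event vector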
VALID_KEYS = [key for key in KEYS if key not in INVALID_KEYS]
itos = VALID_KEYS
stoi = {key: i for i, key in enumerate(itos)}

app = FastAPI()

# Mount the static directory to serve HTML, JavaScript, and CSS files
app.mount("/static", StaticFiles(directory="static"), name="static")

# Global counter used to build unique per-connection client IDs
connection_counter = 0

# Connection timeout settings
CONNECTION_TIMEOUT = 20 + 1  # 20 seconds timeout plus 1 second grace period
WARNING_TIME = 10 + 1 # 10 seconds warning before timeout plus 1 second grace period

# Single-worker thread pool: runs model inference off the event loop while
# serializing GPU access across connections
thread_executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)

def prepare_model_inputs(
    previous_frame: torch.Tensor,
    hidden_states: Any,
    x: int,
    y: int,
    right_click: bool,
    left_click: bool,
    keys_down: List[str],
    stoi: Dict[str, int],
    itos: List[str],
    time_step: int
) -> Dict[str, torch.Tensor]:
    """Prepare inputs for the model."""
    # Clamp coordinates to valid ranges
    x = min(max(0, x), SCREEN_WIDTH - 1) if x is not None else 0
    y = min(max(0, y), SCREEN_HEIGHT - 1) if y is not None else 0
    if DEBUG_MODE:
        print ('DEBUG MODE, SETTING TIME STEP TO 0')
        time_step = 0
    if DEBUG_MODE_2:
        if time_step > NUM_MAX_FRAMES-1:
            print ('DEBUG MODE_2, SETTING TIME STEP TO 0')
            time_step = 0
    
    inputs = {
        'image_features': previous_frame.to(device),
        'is_padding': torch.BoolTensor([time_step == 0]).to(device),
        'x': torch.LongTensor([x]).unsqueeze(0).to(device),
        'y': torch.LongTensor([y]).unsqueeze(0).to(device),
        'is_leftclick': torch.BoolTensor([left_click]).unsqueeze(0).to(device),
        'is_rightclick': torch.BoolTensor([right_click]).unsqueeze(0).to(device),
        'key_events': torch.zeros(len(itos), dtype=torch.long).to(device)
    }
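    # Mark currently-held keys in the multi-hot key_events vector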
    for key in keys_down:
        key = key.lower()
        if key in KEYMAPPING:
            key = KEYMAPPING[key]
        if key in stoi:
            inputs['key_events'][stoi[key]] = 1
        else:
            print (f'Key {key} not found in stoi')
    
    if hidden_states is not None:
        inputs['hidden_states'] = hidden_states
    if DEBUG_MODE:
        print ('DEBUG MODE, REMOVING INPUTS')
        if 'hidden_states' in inputs:
            del inputs['hidden_states']
    if DEBUG_MODE_2:
        if time_step > NUM_MAX_FRAMES-1:
            print ('DEBUG MODE_2, REMOVING HIDDEN STATES')
            if 'hidden_states' in inputs:
                del inputs['hidden_states']
    print (f'Time step: {time_step}')
    return inputs

@torch.no_grad()
async def process_frame(
    model: LatentDiffusion,
    inputs: Dict[str, torch.Tensor],
    use_rnn: bool = False,
    num_sampling_steps: int = 32
) -> Tuple[torch.Tensor, np.ndarray, Any, Dict[str, float]]:
    """Process a single frame through the model."""
    # Run the heavy computation in a separate thread
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(
        thread_executor,
        lambda: _process_frame_sync(model, inputs, use_rnn, num_sampling_steps)
    )

def _process_frame_sync(model, inputs, use_rnn, num_sampling_steps):
    """Synchronous version of process_frame that runs in a thread"""
    timing = {}
    # Temporal encoding
    start = time.perf_counter()
    output_from_rnn, hidden_states = model.temporal_encoder.forward_step(inputs)
    timing['temporal_encoder'] = time.perf_counter() - start
    
    # UNet sampling
    start = time.perf_counter()
    print (f"model.clip_denoised: {model.clip_denoised}")
    model.clip_denoised = False
    print (f"USE_RNN: {use_rnn}, NUM_SAMPLING_STEPS: {num_sampling_steps}")
    if use_rnn:
        sample_latent = output_from_rnn[:, :16]
    else:
        #NUM_SAMPLING_STEPS = 8
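        # Three sampling regimes: the full DDPM chain when enough steps are
        # requested, a single UNet call at the last timestep for one-step
        # prediction, and DDIM for intermediate step counts.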
        if num_sampling_steps >= TIMESTEPS:
            sample_latent = model.p_sample_loop(cond={'c_concat': output_from_rnn}, shape=[1, *LATENT_DIMS], return_intermediates=False, verbose=True)
        else:
            if num_sampling_steps == 1:
                x = torch.randn([1, *LATENT_DIMS], device=device)
                t = torch.full((1,), TIMESTEPS-1, device=device, dtype=torch.long)
                sample_latent = model.apply_model(x, t, {'c_concat': output_from_rnn})
            else:
                sampler = DDIMSampler(model)
                sample_latent, _ = sampler.sample(
                    S=num_sampling_steps,
                    conditioning={'c_concat': output_from_rnn},
                    batch_size=1,
                    shape=LATENT_DIMS,
                    verbose=False
                )
    timing['unet'] = time.perf_counter() - start
    
    # Decoding
    start = time.perf_counter()
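    # Undo latent normalization before decoding with the first-stage autoencoder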
    sample = sample_latent * DATA_NORMALIZATION['std'].view(1, -1, 1, 1) + DATA_NORMALIZATION['mean'].view(1, -1, 1, 1)
    
    # Debug aid: a blocking time.sleep() here would not stall the event loop,
    # since this function runs in a worker thread
    #time.sleep(10)
    
    sample = model.decode_first_stage(sample)
    sample = sample.squeeze(0).clamp(-1, 1)
    timing['decode'] = time.perf_counter() - start
    
    # Convert to image
    sample_img = ((sample[:3].transpose(0,1).transpose(1,2).cpu().float().numpy() + 1) * 127.5).astype(np.uint8)
    
    timing['total'] = sum(timing.values())
    
    return sample_latent, sample_img, hidden_states, timing

def print_timing_stats(timing_info: Dict[str, float], frame_num: int):
    """Print timing statistics for a frame."""
    print(f"\nFrame {frame_num} timing (seconds):")
    for key, value in timing_info.items():
        print(f"  {key.title()}: {value:.4f}")
    print(f"  FPS: {1.0/timing_info['full_frame']:.2f}")

# Serve the index.html file at the root URL
@app.get("/")
async def get():
    with open("static/index.html") as f:
        return HTMLResponse(f.read())

# WebSocket endpoint for continuous user interaction
@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):    
    global connection_counter
    connection_counter += 1
    client_id = f"{int(time.time())}_{connection_counter}"
    print(f"New WebSocket connection: {client_id}")
    await websocket.accept()

    try:
        previous_frame = padding_image
        hidden_states = None
        keys_down = set()  # Initialize as an empty set
        frame_num = -1
        
        # Client-specific settings
        client_settings = {
            "use_rnn": USE_RNN,  # Start with default global value
            "sampling_steps": NUM_SAMPLING_STEPS  # Start with default global value
        }
        
        # Connection timeout tracking
        last_user_activity_time = time.perf_counter()
        timeout_warning_sent = False
        timeout_task = None
        connection_active = True  # Flag to track if connection is still active
        user_has_interacted = False  # Flag to track if user has started interacting
        
        # Start timing for global FPS calculation
        connection_start_time = time.perf_counter()
        frame_count = 0
        
        # Input queue management - use asyncio.Queue instead of a list
        input_queue = asyncio.Queue()
        is_processing = False
        
        # Add a function to reset the simulation
        async def reset_simulation():
            nonlocal previous_frame, hidden_states, keys_down, frame_num, is_processing, input_queue, user_has_interacted
            # Keep the client settings during reset
            temp_client_settings = client_settings.copy()
            
            # Log the reset action
            log_interaction(
                client_id, 
                {"type": "reset"}, 
                is_end_of_session=False,
                is_reset=True  # mark this log entry as a reset
            )
            
            # Clear the input queue
            while not input_queue.empty():
                try:
                    input_queue.get_nowait()
                    input_queue.task_done()
                except asyncio.QueueEmpty:
                    break
            
            # Reset all state variables
            previous_frame = padding_image
            hidden_states = None
            keys_down = set()
            frame_num = -1
            is_processing = False
            user_has_interacted = False  # Reset user interaction state
            
            # Restore client settings
            client_settings.update(temp_client_settings)
            
            print(f"[{time.perf_counter():.3f}] Simulation reset to initial state (preserved settings: USE_RNN={client_settings['use_rnn']}, SAMPLING_STEPS={client_settings['sampling_steps']})")
            print(f"[{time.perf_counter():.3f}] User interaction state reset - waiting for user to interact again")
            
            # Send confirmation to client
            await websocket.send_json({"type": "reset_confirmed"})
            
            # Also send the current settings to update the UI
            await websocket.send_json({
                "type": "settings",
                "sampling_steps": client_settings["sampling_steps"],
                "use_rnn": client_settings["use_rnn"]
            })
        
        # Add a function to update sampling steps
        async def update_sampling_steps(steps):
            nonlocal client_settings
            
            # Validate the input
            if steps < 1:
                print(f"[{time.perf_counter():.3f}] Invalid sampling steps value: {steps}")
                await websocket.send_json({"type": "error", "message": "Invalid sampling steps value"})
                return
                
            # Update the client-specific setting
            old_steps = client_settings["sampling_steps"]
            client_settings["sampling_steps"] = steps
            
            print(f"[{time.perf_counter():.3f}] Updated sampling steps for client {client_id} from {old_steps} to {steps}")
            
            # Send confirmation to client
            await websocket.send_json({"type": "steps_updated", "steps": steps})
        
        # Add a function to update USE_RNN setting
        async def update_use_rnn(use_rnn):
            nonlocal client_settings
            
            # Update the client-specific setting
            old_setting = client_settings["use_rnn"]
            client_settings["use_rnn"] = use_rnn
            
            print(f"[{time.perf_counter():.3f}] Updated USE_RNN for client {client_id} from {old_setting} to {use_rnn}")
            
            # Send confirmation to client
            await websocket.send_json({"type": "rnn_updated", "use_rnn": use_rnn})
        
        # Add timeout checking function
        async def check_timeout():
            nonlocal timeout_warning_sent, timeout_task, connection_active, user_has_interacted
            
            while True:
                try:
                    # Check if WebSocket is still connected and connection is still active
                    if not connection_active or websocket.client_state.value >= 2:  # CLOSING or CLOSED
                        print(f"[{time.perf_counter():.3f}] Connection inactive or WebSocket closed, stopping timeout check for client {client_id}")
                        return
                    
                    # Don't start timeout tracking until user has actually interacted
                    if not user_has_interacted:
                        print(f"[{time.perf_counter():.3f}] User hasn't interacted yet, skipping timeout check for client {client_id}")
                        await asyncio.sleep(1)  # Check every second
                        continue
                    
                    current_time = time.perf_counter()
                    time_since_activity = current_time - last_user_activity_time
                    
                    print(f"[{current_time:.3f}] Timeout check - time_since_activity: {time_since_activity:.1f}s, WARNING_TIME: {WARNING_TIME}s, CONNECTION_TIMEOUT: {CONNECTION_TIMEOUT}s")
                    
                    # Send a warning once WARNING_TIME has elapsed without activity
                    if time_since_activity >= WARNING_TIME and not timeout_warning_sent:
                        print(f"[{current_time:.3f}] Sending timeout warning to client {client_id}")
                        await websocket.send_json({
                            "type": "timeout_warning",
                            "timeout_in": CONNECTION_TIMEOUT - WARNING_TIME
                        })
                        timeout_warning_sent = True
                        print(f"[{current_time:.3f}] Timeout warning sent, timeout_warning_sent: {timeout_warning_sent}")
                    
                    # Close the connection once CONNECTION_TIMEOUT has elapsed without activity
                    if time_since_activity >= CONNECTION_TIMEOUT:
                        print(f"[{current_time:.3f}] TIMEOUT REACHED! Closing connection {client_id} due to timeout")
                        print(f"[{current_time:.3f}] time_since_activity: {time_since_activity:.1f}s >= CONNECTION_TIMEOUT: {CONNECTION_TIMEOUT}s")
                        
                        # Clear the input queue before closing
                        queue_size_before = input_queue.qsize()
                        print(f"[{current_time:.3f}] Clearing input queue, size before: {queue_size_before}")
                        while not input_queue.empty():
                            try:
                                input_queue.get_nowait()
                                input_queue.task_done()
                            except asyncio.QueueEmpty:
                                break
                        print(f"[{current_time:.3f}] Input queue cleared, size after: {input_queue.qsize()}")
                        
                        print(f"[{current_time:.3f}] About to close WebSocket connection...")
                        await websocket.close(code=1000, reason="User inactivity timeout")
                        print(f"[{current_time:.3f}] WebSocket.close() called, returning from check_timeout")
                        return
                    
                    await asyncio.sleep(1)  # Check every second
                    
                except Exception as e:
                    print(f"[{time.perf_counter():.3f}] Error in timeout check for client {client_id}: {e}")
                    import traceback
                    traceback.print_exc()
                    break
        
        # Function to update user activity
        def update_user_activity():
            nonlocal last_user_activity_time, timeout_warning_sent
            old_time = last_user_activity_time
            last_user_activity_time = time.perf_counter()
            print(f"[{time.perf_counter():.3f}] User activity detected for client {client_id}")
            print(f"[{time.perf_counter():.3f}] last_user_activity_time updated: {old_time:.3f} -> {last_user_activity_time:.3f}")
            
            if timeout_warning_sent:
                print(f"[{time.perf_counter():.3f}] User activity detected, resetting timeout warning for client {client_id}")
                timeout_warning_sent = False
                print(f"[{time.perf_counter():.3f}] timeout_warning_sent reset to: {timeout_warning_sent}")
                # Send activity reset notification to client
                asyncio.create_task(websocket.send_json({"type": "activity_reset"}))
                print(f"[{time.perf_counter():.3f}] Activity reset message sent to client")
        
        # Start timeout checking
        timeout_task = asyncio.create_task(check_timeout())
        print(f"[{time.perf_counter():.3f}] Timeout task started for client {client_id} (waiting for user interaction)")
        
        async def process_input(data):
            nonlocal previous_frame, hidden_states, keys_down, frame_num, frame_count, is_processing, user_has_interacted
            
            try:
                process_start_time = time.perf_counter()
                queue_size = input_queue.qsize()
                print(f"[{process_start_time:.3f}] Starting to process input. Queue size before: {queue_size}")
                frame_num += 1
                frame_count += 1  # Increment total frame counter
                
                # Calculate global FPS
                total_elapsed = process_start_time - connection_start_time
                global_fps = frame_count / total_elapsed if total_elapsed > 0 else 0
                
                # Clamp x and y to the screen bounds, defaulting missing coordinates to 0
                data['x'] = max(0, min(data.get('x') or 0, SCREEN_WIDTH - 1))
                data['y'] = max(0, min(data.get('y') or 0, SCREEN_HEIGHT - 1))
                x = data.get("x")
                y = data.get("y")
                assert 0 <= x < SCREEN_WIDTH, f"x: {x} is out of range"
                assert 0 <= y < SCREEN_HEIGHT, f"y: {y} is out of range"
                is_left_click = data.get("is_left_click")
                is_right_click = data.get("is_right_click")
                keys_down_list = data.get("keys_down", [])  # Get as list
                keys_up_list = data.get("keys_up", [])
                is_auto_input = data.get("is_auto_input", False)
                if is_auto_input:
                    print (f'[{time.perf_counter():.3f}] Auto-input detected')
                else:
                    # Update user activity for non-auto inputs
                    update_user_activity()
                    # Mark that user has started interacting
                    if not user_has_interacted:
                        user_has_interacted = True
                        print(f"[{time.perf_counter():.3f}] User has started interacting with canvas for client {client_id}")
                wheel_delta_x = data.get("wheel_delta_x", 0)
                wheel_delta_y = data.get("wheel_delta_y", 0)
                print(f'[{time.perf_counter():.3f}] Processing: x: {x}, y: {y}, is_left_click: {is_left_click}, is_right_click: {is_right_click}, keys_down_list: {keys_down_list}, keys_up_list: {keys_up_list}, wheel: ({wheel_delta_x},{wheel_delta_y}), time_since_activity: {time.perf_counter() - last_user_activity_time:.3f}')
                
                # Update the set based on the received data
                for key in keys_down_list:
                    key = key.lower()
                    if key in KEYMAPPING:
                        key = KEYMAPPING[key]
                    keys_down.add(key)
                for key in keys_up_list:
                    key = key.lower()
                    if key in KEYMAPPING:
                        key = KEYMAPPING[key]
                    if key in keys_down:  # Check if key exists to avoid KeyError
                        keys_down.remove(key)
                if DEBUG_MODE:
                    print (f"DEBUG MODE, REMOVING HIDDEN STATES")
                    previous_frame = padding_image

                if DEBUG_MODE_2:
                    print(f'DEBUG MODE_2 frame_num: {frame_num}')
                    if frame_num > NUM_MAX_FRAMES-1:
                        print (f"DEBUG MODE_2, REMOVING HIDDEN STATES")
                        previous_frame = padding_image
                        frame_num = 0
                inputs = prepare_model_inputs(previous_frame, hidden_states, x, y, is_right_click, is_left_click, list(keys_down), stoi, itos, frame_num)
                
                # Use client-specific settings
                client_use_rnn = client_settings["use_rnn"]
                client_sampling_steps = client_settings["sampling_steps"]
                
                print(f"[{time.perf_counter():.3f}] Starting model inference with client settings - USE_RNN: {client_use_rnn}, SAMPLING_STEPS: {client_sampling_steps}...")
                
                # Pass client-specific settings to process_frame
                previous_frame, sample_img, hidden_states, timing_info = await process_frame(
                    model, 
                    inputs, 
                    use_rnn=client_use_rnn, 
                    num_sampling_steps=client_sampling_steps
                )
                
                print (f'Client {client_id} settings: USE_RNN: {client_use_rnn}, SAMPLING_STEPS: {client_sampling_steps}')

                
                timing_info['full_frame'] = time.perf_counter() - process_start_time
                
                print(f"[{time.perf_counter():.3f}] Model inference complete. Queue size now: {input_queue.qsize()}")
                # Use the provided function to print timing statistics
                print_timing_stats(timing_info, frame_num)
                
                # Print global FPS measurement
                print(f"  Global FPS: {global_fps:.2f} (total: {frame_count} frames in {total_elapsed:.2f}s)")
                
                
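                # Encode the generated frame as a base64 PNG for the JSON payload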
                img = Image.fromarray(sample_img)
                buffered = io.BytesIO()
                img.save(buffered, format="PNG")
                img_str = base64.b64encode(buffered.getvalue()).decode()
                
                # Send the generated frame back to the client
                print(f"[{time.perf_counter():.3f}] Sending image to client...")
                try:
                    await websocket.send_json({"image": img_str})
                    print(f"[{time.perf_counter():.3f}] Image sent. Queue size before next_input: {input_queue.qsize()}")
                except RuntimeError as e:
                    if "Cannot call 'send' once a close message has been sent" in str(e):
                        print(f"[{time.perf_counter():.3f}] WebSocket closed, skipping image send")
                    else:
                        raise e
                except Exception as e:
                    print(f"[{time.perf_counter():.3f}] Error sending image: {e}")

                # Log the input
                log_interaction(client_id, data, generated_frame=sample_img)
            finally:
                is_processing = False
                print(f"[{time.perf_counter():.3f}] Processing complete. Queue size before checking next input: {input_queue.qsize()}")
                # Check if we have more inputs to process after this one
                if not input_queue.empty():
                    print(f"[{time.perf_counter():.3f}] Queue not empty, processing next input")
                    asyncio.create_task(process_next_input())
        
        async def process_next_input():
            nonlocal is_processing
            
            current_time = time.perf_counter()
            if input_queue.empty():
                print(f"[{current_time:.3f}] No inputs to process. Queue is empty.")
                is_processing = False
                return
            
            # Check if WebSocket is still open by checking if it's in a closed state
            if websocket.client_state.value >= 2:  # CLOSING or CLOSED
                print(f"[{current_time:.3f}] WebSocket in closed state ({websocket.client_state.value}), stopping processing")
                is_processing = False
                return
            
            #if is_processing:
            #    print(f"[{current_time:.3f}] Already processing an input. Will check again later.")
            #    return
            
            # Set is_processing to True before proceeding
            is_processing = True
            
            queue_size = input_queue.qsize()
            print(f"[{current_time:.3f}] Processing next input. Queue size: {queue_size}")
            
            try:
                # Initialize variables to track progress
                skipped = 0
                latest_input = None
                
                # Process the queue one item at a time
                while not input_queue.empty():
                    current_input = await input_queue.get()
                    input_queue.task_done()
                    
                    # Always update the latest input
                    latest_input = current_input
                    
                    # Check if this is an interesting event
                    is_interesting = (current_input.get("is_left_click") or 
                                      current_input.get("is_right_click") or 
                                      (current_input.get("keys_down") and len(current_input.get("keys_down")) > 0) or 
                                      (current_input.get("keys_up") and len(current_input.get("keys_up")) > 0) or
                                      current_input.get("wheel_delta_x", 0) != 0 or
                                      current_input.get("wheel_delta_y", 0) != 0)
                    
                    # Process immediately if interesting
                    if is_interesting:
                        print(f"[{current_time:.3f}] Found interesting input (skipped {skipped} events)")
                        await process_input(current_input)  # AWAIT here instead of creating a task
                        is_processing = False
                        return
                    
                    # Otherwise, continue to the next item
                    skipped += 1
                    
                    # If this is the last item and no interesting inputs were found
                    if input_queue.empty():
                        print(f"[{current_time:.3f}] No interesting inputs, processing latest movement (skipped {skipped-1} events)")
                        await process_input(latest_input)  # AWAIT here instead of creating a task
                        is_processing = False
                        return
            except Exception as e:
                print(f"[{current_time:.3f}] Error in process_next_input: {e}")
                import traceback
                traceback.print_exc()
                is_processing = False  # Make sure to reset on error
        
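        # Main receive loop: handle control messages inline and queue canvas
        # input events for (possibly coalesced) processing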
        while True:
            try:
                # Receive user input
                print(f"[{time.perf_counter():.3f}] Waiting for input... Queue size: {input_queue.qsize()}, is_processing: {is_processing}")
                data = await websocket.receive_json()
                receive_time = time.perf_counter()
                
                if data.get("type") == "heartbeat":
                    await websocket.send_json({"type": "heartbeat_response"})
                    continue
                
                # Handle reset command
                if data.get("type") == "reset":
                    print(f"[{receive_time:.3f}] Received reset command")
                    update_user_activity()  # Reset activity timer
                    await reset_simulation()
                    continue
                
                # Handle sampling steps update
                if data.get("type") == "update_sampling_steps":
                    print(f"[{receive_time:.3f}] Received request to update sampling steps")
                    update_user_activity()  # Reset activity timer
                    await update_sampling_steps(data.get("steps", 32))
                    continue
                
                # Handle USE_RNN update
                if data.get("type") == "update_use_rnn":
                    print(f"[{receive_time:.3f}] Received request to update USE_RNN")
                    update_user_activity()  # Reset activity timer
                    await update_use_rnn(data.get("use_rnn", False))
                    continue
                
                # Handle settings request
                if data.get("type") == "get_settings":
                    print(f"[{receive_time:.3f}] Received request for current settings")
                    update_user_activity()  # Reset activity timer
                    await websocket.send_json({
                        "type": "settings",
                        "sampling_steps": client_settings["sampling_steps"],
                        "use_rnn": client_settings["use_rnn"]
                    })
                    continue
                
                # Add the input to our queue
                await input_queue.put(data)
                print(f"[{receive_time:.3f}] Received input. Queue size now: {input_queue.qsize()}")
                
                # Check if WebSocket is still open before processing
                if websocket.client_state.value >= 2:  # CLOSING or CLOSED
                    print(f"[{receive_time:.3f}] WebSocket closed, skipping processing")
                    continue
                
                # If we're not currently processing, start processing this input
                if not is_processing:
                    print(f"[{receive_time:.3f}] Not currently processing, will call process_next_input()")
                    is_processing = True
                    asyncio.create_task(process_next_input())  # Create task but don't await it
                else:
                    print(f"[{receive_time:.3f}] Currently processing, new input queued for later")
            
            except asyncio.TimeoutError:
                print("WebSocket connection timed out")
            
            except WebSocketDisconnect:
                # Log final EOS entry
                log_interaction(client_id, {}, is_end_of_session=True)
                print(f"[{time.perf_counter():.3f}] WebSocket disconnected: {client_id}")
                print(f"[{time.perf_counter():.3f}] WebSocketDisconnect exception caught")
                break

    except Exception as e:
        print(f"Error in WebSocket connection {client_id}: {e}")
        import traceback
        traceback.print_exc()
    
    finally:
        # Clean up timeout task
        print(f"[{time.perf_counter():.3f}] Cleaning up connection {client_id}")
        connection_active = False  # Signal that connection is being cleaned up
        if timeout_task and not timeout_task.done():
            print(f"[{time.perf_counter():.3f}] Cancelling timeout task for client {client_id}")
            timeout_task.cancel()
            try:
                await timeout_task
                print(f"[{time.perf_counter():.3f}] Timeout task cancelled successfully for client {client_id}")
            except asyncio.CancelledError:
                print(f"[{time.perf_counter():.3f}] Timeout task cancelled with CancelledError for client {client_id}")
                pass
        else:
            print(f"[{time.perf_counter():.3f}] Timeout task already done or doesn't exist for client {client_id}")
        
        # Print final FPS statistics when connection ends
        if frame_num >= 0:  # Only if we processed at least one frame
            total_time = time.perf_counter() - connection_start_time
            print(f"\nConnection {client_id} summary:")
            print(f"  Total frames processed: {frame_count}")
            print(f"  Total elapsed time: {total_time:.2f} seconds")
            print(f"  Average FPS: {frame_count/total_time:.2f}")
        
        print(f"WebSocket connection closed: {client_id}")

def log_interaction(client_id, data, generated_frame=None, is_end_of_session=False, is_reset=False):
    """Log user interaction and optionally the generated frame."""
    timestamp = time.time()
    
    # Create directory structure if it doesn't exist
    os.makedirs("interaction_logs", exist_ok=True)
    
    # Structure the log entry
    log_entry = {
        "timestamp": timestamp,
        "client_id": client_id,
        "is_eos": is_end_of_session,
        "is_reset": is_reset
    }
    
    # Include type if present (for reset, etc.)
    if data.get("type"):
        log_entry["type"] = data.get("type")
    
    # Only include input data if this isn't just a control message
    if not is_end_of_session and not is_reset:
        log_entry["inputs"] = {
            "x": data.get("x"),
            "y": data.get("y"),
            "is_left_click": data.get("is_left_click"),
            "is_right_click": data.get("is_right_click"),
            "keys_down": data.get("keys_down", []),
            "keys_up": data.get("keys_up", []),
            "wheel_delta_x": data.get("wheel_delta_x", 0),
            "wheel_delta_y": data.get("wheel_delta_y", 0),
            "is_auto_input": data.get("is_auto_input", False)
        }
    else:
        # For EOS/reset records, just include minimal info
        log_entry["inputs"] = None
    
    # Save to a file (one file per session)
    session_file = f"interaction_logs/session_{client_id}.jsonl"
    with open(session_file, "a") as f:
        f.write(json.dumps(log_entry) + "\n")
    
    # Optionally save the frame if provided
    if generated_frame is not None and not is_end_of_session and not is_reset:
        frame_dir = f"interaction_logs/frames_{client_id}"
        os.makedirs(frame_dir, exist_ok=True)
        frame_file = f"{frame_dir}/{timestamp:.6f}.png"
        # Save the frame as PNG
        Image.fromarray(generated_frame).save(frame_file)
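
# Example local launch (assumed module name; adjust to this file's actual name):
#   uvicorn server:app --host 0.0.0.0 --port 7860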