File size: 16,402 Bytes
c54f817
8435838
c575e18
89b9813
8435838
89b9813
 
 
e62ac65
e8de28c
a677593
5a39c97
a2d3df0
b2e55f9
 
fc0bb07
b916cdf
e858976
 
b2e55f9
 
 
 
 
 
 
5a8e3de
89b9813
b2e55f9
a677593
0b874ec
e77b83d
0b874ec
e8de28c
 
497c0a8
3208a74
b2e55f9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3208a74
b2e55f9
3208a74
b2e55f9
 
6ee36ca
b2e55f9
6ee36ca
fc0bb07
 
 
b2e55f9
 
 
 
 
 
 
 
 
 
 
 
 
6466ec0
 
 
 
b2e55f9
 
 
6466ec0
 
b2e55f9
 
 
 
 
f2aec00
b2e55f9
 
 
 
 
 
 
 
fc0bb07
b2e55f9
 
 
 
fc0bb07
 
 
 
 
 
 
 
 
b2e55f9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fc0bb07
 
9ced953
fc0bb07
b2e55f9
 
 
 
 
 
 
 
 
 
5754a1c
eb2c07e
 
 
 
 
c20bd13
eb2c07e
b2e55f9
 
 
 
c575e18
8435838
 
b2e55f9
1679b8f
 
8435838
313bb52
8435838
b2e55f9
 
 
 
e858976
 
 
 
 
2da2765
 
6454884
 
 
 
 
e62ac65
0adb69d
2da2765
 
b2e55f9
e858976
 
 
0adb69d
e858976
 
b2e55f9
 
 
 
 
 
0adb69d
6454884
b2e55f9
 
 
 
 
 
6454884
b2e55f9
0adb69d
fc0bb07
0adb69d
eb2c07e
2da2765
eb2c07e
 
 
e858976
 
9ced953
e858976
b2e55f9
e62ac65
 
 
 
 
0adb69d
e62ac65
2da2765
6454884
 
2da2765
016f9e1
 
 
 
6454884
2da2765
 
6454884
0adb69d
2da2765
0adb69d
d20b0fd
6454884
 
b115af2
 
 
0adb69d
a2d7c6c
6466ec0
 
2da2765
 
0adb69d
d20b0fd
 
 
 
a2d7c6c
d20b0fd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6454884
 
 
 
2da2765
6454884
0adb69d
6454884
 
 
 
 
 
2da2765
 
6454884
 
 
0adb69d
b953cf8
d20b0fd
0adb69d
 
8435838
e62ac65
 
8435838
e62ac65
 
0adb69d
c575e18
8435838
1679b8f
0adb69d
 
e62ac65
 
e858976
 
 
 
 
 
 
 
1679b8f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles
from typing import List, Tuple
import numpy as np
from PIL import Image, ImageDraw
import base64
import io
import asyncio
from utils import initialize_model, sample_frame
import torch
import os
import time
from typing import Any, Dict
from ldm.models.diffusion.ddpm import LatentDiffusion, DDIMSampler
import concurrent.futures

# Permit TensorFloat-32 math for CUDA matrix multiplies and cuDNN convolutions.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

# Resolution of generated frames, in pixels.
SCREEN_WIDTH, SCREEN_HEIGHT = 512, 384
# DDIM steps run per generated frame.
NUM_SAMPLING_STEPS = 8
# Latent normalization statistics: normalized = (raw - mean) / std.
DATA_NORMALIZATION = dict(mean=-0.54, std=6.78)
# Per-sample latent shape handed to the DDIM sampler: 4 channels at 1/8 resolution.
LATENT_DIMS = (4, SCREEN_HEIGHT // 8, SCREEN_WIDTH // 8)

# Device used for the model and for every tensor built in this module.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load the model once at import time; all WebSocket connections share it.
# Alternative configuration:
#   model = initialize_model("config_csllm.yaml", "yuntian-deng/computer-model")
model = initialize_model("config_rnn.yaml", "yuntian-deng/computer-model")
model = model.to(device)
# Optional: model = torch.compile(model)

# Latent used as `image_features` for the first frame of a connection
# (time_step 0, which is also flagged `is_padding`): an all-zero tensor
# normalized with DATA_NORMALIZATION.
# NOTE(review): this is laid out (1, H/8, W/8, 4), i.e. channels-last, whereas
# LATENT_DIMS is (4, H/8, W/8) — confirm the temporal encoder expects this
# layout for padding frames.
padding_image = torch.zeros(1, SCREEN_HEIGHT // 8, SCREEN_WIDTH // 8, 4)
padding_image = (padding_image - DATA_NORMALIZATION['mean']) / DATA_NORMALIZATION['std']
padding_image = padding_image.to(device)

# Keyboard vocabulary.
# The single-character keys are tab, newline and carriage return, followed by
# every printable ASCII character from ' ' to '~' in code-point order, minus
# the uppercase letters A-Z. The named keys follow.
KEYS = (
    ['\t', '\n', '\r']
    + [chr(code) for code in range(ord(' '), ord('~') + 1)
       if not 'A' <= chr(code) <= 'Z']
    + ['accept', 'add', 'alt', 'altleft', 'altright', 'apps', 'backspace',
       'browserback', 'browserfavorites', 'browserforward', 'browserhome',
       'browserrefresh', 'browsersearch', 'browserstop', 'capslock', 'clear',
       'convert', 'ctrl', 'ctrlleft', 'ctrlright', 'decimal', 'del', 'delete',
       'divide', 'down', 'end', 'enter', 'esc', 'escape', 'execute', 'f1', 'f10',
       'f11', 'f12', 'f13', 'f14', 'f15', 'f16', 'f17', 'f18', 'f19', 'f2', 'f20',
       'f21', 'f22', 'f23', 'f24', 'f3', 'f4', 'f5', 'f6', 'f7', 'f8', 'f9',
       'final', 'fn', 'hanguel', 'hangul', 'hanja', 'help', 'home', 'insert', 'junja',
       'kana', 'kanji', 'launchapp1', 'launchapp2', 'launchmail',
       'launchmediaselect', 'left', 'modechange', 'multiply', 'nexttrack',
       'nonconvert', 'num0', 'num1', 'num2', 'num3', 'num4', 'num5', 'num6',
       'num7', 'num8', 'num9', 'numlock', 'pagedown', 'pageup', 'pause', 'pgdn',
       'pgup', 'playpause', 'prevtrack', 'print', 'printscreen', 'prntscrn',
       'prtsc', 'prtscr', 'return', 'right', 'scrolllock', 'select', 'separator',
       'shift', 'shiftleft', 'shiftright', 'sleep', 'space', 'stop', 'subtract', 'tab',
       'up', 'volumedown', 'volumemute', 'volumeup', 'win', 'winleft', 'winright', 'yen',
       'command', 'option', 'optionleft', 'optionright']
)
# Keys that appear in KEYS but are left out of the model's vocabulary.
INVALID_KEYS = ['f13', 'f14', 'f15', 'f16', 'f17', 'f18', 'f19', 'f20',
                'f21', 'f22', 'f23', 'f24', 'select', 'separator', 'execute']
VALID_KEYS = list(filter(lambda key: key not in INVALID_KEYS, KEYS))
# index -> key name, and key name -> index, for the key-event vector.
itos = VALID_KEYS
stoi = dict(zip(itos, range(len(itos))))

# The ASGI application exposed by this module.
app = FastAPI()

# HTML, JavaScript and CSS assets under ./static are served at /static.
app.mount(path="/static", app=StaticFiles(directory="static"), name="static")

# Pool that runs blocking model inference off the event loop. With a single
# worker, at most one submitted inference call executes at any moment.
thread_executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)

def prepare_model_inputs(
    previous_frame: torch.Tensor,
    hidden_states: Any,
    x: int,
    y: int,
    right_click: bool,
    left_click: bool,
    keys_down: List[str],
    stoi: Dict[str, int],
    itos: List[str],
    time_step: int
) -> Dict[str, torch.Tensor]:
    """Build the input dict for one temporal-encoder step.

    Args:
        previous_frame: Latent of the previous frame (or the padding latent at
            the first step). It is moved to ``device``.
        hidden_states: Recurrent state from the previous step, or ``None``.
            When ``None``, no ``'hidden_states'`` entry is added.
        x, y: Cursor position in pixels. Clamped to the screen and truncated
            to an int; ``None`` becomes 0.
        right_click, left_click: Mouse-button flags. Any falsy value,
            including ``None`` for a field the client omitted, counts as
            not pressed.
        keys_down: Names of the keys currently held. They are matched
            case-insensitively against ``stoi``. Names outside the vocabulary
            are skipped with a warning instead of raising ``KeyError``.
        stoi: Mapping from key name to index in ``'key_events'``.
        itos: Key vocabulary; its length is the size of ``'key_events'``.
        time_step: Frame index. Step 0 is flagged via ``'is_padding'``.

    Returns:
        Dict of tensors on ``device``:
        - ``'x'``, ``'y'``: int64 of shape (1, 1).
        - ``'is_leftclick'``, ``'is_rightclick'``: bool of shape (1, 1).
        - ``'is_padding'``: bool of shape (1,).
        - ``'key_events'``: a 0/1 int64 vector of length ``len(itos)``.
        - ``'image_features'``: the previous frame.
        - ``'hidden_states'``: present only when given.
    """
    # Clamp coordinates to valid pixel ranges; a missing coordinate becomes 0.
    x = int(min(max(0, x), SCREEN_WIDTH - 1)) if x is not None else 0
    y = int(min(max(0, y), SCREEN_HEIGHT - 1)) if y is not None else 0

    key_events = torch.zeros(len(itos), dtype=torch.long, device=device)
    for key in keys_down:
        index = stoi.get(str(key).lower())
        if index is None:
            # Skip rather than raise: the caller keeps held keys across
            # messages, so a KeyError here would break every later frame
            # until the key was released.
            print(f"Ignoring key not in vocabulary: {key!r}")
            continue
        key_events[index] = 1

    inputs = {
        'image_features': previous_frame.to(device),
        'is_padding': torch.tensor([time_step == 0], dtype=torch.bool, device=device),
        'x': torch.tensor([[x]], dtype=torch.long, device=device),
        'y': torch.tensor([[y]], dtype=torch.long, device=device),
        'is_leftclick': torch.tensor([[bool(left_click)]], dtype=torch.bool, device=device),
        'is_rightclick': torch.tensor([[bool(right_click)]], dtype=torch.bool, device=device),
        'key_events': key_events,
    }
    if hidden_states is not None:
        inputs['hidden_states'] = hidden_states

    return inputs

async def process_frame(
    model: LatentDiffusion,
    inputs: Dict[str, torch.Tensor]
) -> Tuple[torch.Tensor, np.ndarray, Any, Dict[str, float]]:
    """Run one frame of inference without blocking the event loop.

    The work is handed to ``_process_frame_sync`` on ``thread_executor``, and
    its return value is passed through unchanged:
    ``(sample_latent, sample_img, hidden_states, timing)``.

    This coroutine is deliberately not decorated with ``torch.no_grad()``.
    Applied to an ``async def``, that decorator only wraps creation of the
    coroutine object. Grad mode is also thread-local, so it could never
    reach the executor thread. Disabling autograd therefore has to happen
    inside the function that runs on the worker thread.
    """
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(thread_executor, _process_frame_sync, model, inputs)

@torch.no_grad()
def _process_frame_sync(model, inputs):
    """Synchronously generate one frame; intended to run on a worker thread.

    Autograd is disabled here, inside the worker thread, because grad mode is
    thread-local. Otherwise the recurrent ``hidden_states`` returned and fed
    back on the next frame could carry an ever-growing autograd graph.

    Args:
        model: The latent diffusion model, which exposes ``temporal_encoder``
            and ``decode_first_stage``.
        inputs: Dict produced by ``prepare_model_inputs``.

    Returns:
        A tuple ``(sample_latent, sample_img, hidden_states, timing)``:
        - ``sample_latent``: the normalized latent from the DDIM sampler.
        - ``sample_img``: a uint8 array built from the first three channels of
          the decoded frame, moved channel-axis-last. This presumably gives
          (H, W, 3), assuming the decoder returns (1, C, H, W).
        - ``hidden_states``: the updated recurrent state.
        - ``timing``: seconds spent in ``'temporal_encoder'``, ``'unet'`` and
          ``'decode'``, plus their sum as ``'total'``.
    """
    timing = {}

    # Temporal encoding: advance the recurrent state by one step.
    start = time.perf_counter()
    output_from_rnn, hidden_states = model.temporal_encoder.forward_step(inputs)
    timing['temporal_encoder'] = time.perf_counter() - start

    # UNet sampling, conditioned on the temporal encoder output.
    start = time.perf_counter()
    sampler = DDIMSampler(model)
    sample_latent, _ = sampler.sample(
        S=NUM_SAMPLING_STEPS,
        conditioning={'c_concat': output_from_rnn},
        batch_size=1,
        shape=LATENT_DIMS,
        verbose=False
    )
    timing['unet'] = time.perf_counter() - start

    # Decoding: undo the latent normalization, then decode to pixel space.
    start = time.perf_counter()
    sample = sample_latent * DATA_NORMALIZATION['std'] + DATA_NORMALIZATION['mean']
    sample = model.decode_first_stage(sample)
    sample = sample.squeeze(0).clamp(-1, 1)
    timing['decode'] = time.perf_counter() - start

    # Map [-1, 1] to [0, 255] and move the channel axis last for PIL.
    sample_img = ((sample[:3].permute(1, 2, 0).cpu().float().numpy() + 1) * 127.5).astype(np.uint8)

    timing['total'] = sum(timing.values())

    return sample_latent, sample_img, hidden_states, timing

def print_timing_stats(timing_info: Dict[str, float], frame_num: int):
    """Print every timing entry for one frame, followed by the frame's FPS.

    Args:
        timing_info: Mapping from stage name to duration in seconds. Entries
            are printed in insertion order. FPS is computed from
            ``'full_frame'`` (end-to-end time), falling back to ``'total'``.
            The FPS line is omitted when neither holds a positive duration,
            instead of raising ``KeyError`` or ``ZeroDivisionError``.
        frame_num: Frame index shown in the header line.
    """
    print(f"\nFrame {frame_num} timing (seconds):")
    for key, value in timing_info.items():
        print(f"  {key.title()}: {value:.4f}")
    frame_seconds = timing_info.get('full_frame', timing_info.get('total'))
    if frame_seconds is not None and frame_seconds > 0:
        print(f"  FPS: {1.0 / frame_seconds:.2f}")

# Serve the index.html file at the root URL
@app.get("/")
async def get():
    """Return ``static/index.html`` as the landing page.

    The file is re-read on every request, so edits take effect without a
    restart. It is opened in a ``with`` block so the handle is closed
    promptly, and read as UTF-8 rather than the platform default encoding.
    """
    with open("static/index.html", encoding="utf-8") as f:
        return HTMLResponse(f.read())

def _is_interesting_input(data: Dict[str, Any]) -> bool:
    """Return True if a client input carries a click or a key press/release.

    ``websocket_endpoint`` never coalesces such inputs away. Inputs for which
    this returns False (pure cursor moves) may be dropped in favour of a
    newer one when frame generation falls behind.
    """
    return bool(
        data.get("is_left_click")
        or data.get("is_right_click")
        or data.get("keys_down")
        or data.get("keys_up")
    )


# WebSocket endpoint for continuous user interaction
@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """Stream generated frames to one client in response to its input events.

    Each JSON message is one of two things:
    - A heartbeat (``{"type": "heartbeat"}``), answered immediately.
    - An input event with optional ``x``, ``y``, ``is_left_click``,
      ``is_right_click``, ``keys_down`` and ``keys_up`` fields.

    Input events are queued and consumed by a single worker task per
    connection. At most one frame is therefore generated at a time, and the
    recurrent state (``previous_frame`` / ``hidden_states``) advances
    strictly in order.

    When the worker falls behind, queued cursor moves are coalesced:
    - If a click/key event is pending, the first one is processed and the
      movements queued before it are dropped.
    - Otherwise only the most recent movement is processed.

    Each generated frame is sent back as ``{"image": <base64 PNG>}``. The
    worker is cancelled when the connection ends.
    """
    client_id = id(websocket)  # Use a unique identifier for each connection
    print(f"New WebSocket connection: {client_id}")
    await websocket.accept()

    # Per-connection recurrent state; only the worker task mutates it.
    previous_frame = padding_image
    hidden_states = None
    keys_down = set()  # Keys currently held, accumulated across messages
    frame_num = -1
    frame_count = 0
    connection_start_time = time.perf_counter()  # For global FPS calculation

    input_queue = asyncio.Queue()
    is_processing = False  # Informational only: reported in log output

    async def process_input(data):
        """Generate one frame from ``data`` and send it to the client."""
        nonlocal previous_frame, hidden_states, frame_num, frame_count

        process_start_time = time.perf_counter()
        print(f"[{process_start_time:.3f}] Starting to process input. Queue size before: {input_queue.qsize()}")
        frame_num += 1
        frame_count += 1  # Increment total frame counter

        # Calculate global FPS
        total_elapsed = process_start_time - connection_start_time
        global_fps = frame_count / total_elapsed if total_elapsed > 0 else 0

        x = data.get("x")
        y = data.get("y")
        is_left_click = data.get("is_left_click")
        is_right_click = data.get("is_right_click")
        # A missing or null key list counts as empty.
        keys_down_list = data.get("keys_down") or []
        keys_up_list = data.get("keys_up") or []
        print(f'[{time.perf_counter():.3f}] Processing: x: {x}, y: {y}, is_left_click: {is_left_click}, is_right_click: {is_right_click}, keys_down_list: {keys_down_list}, keys_up_list: {keys_up_list}')

        # Apply presses before releases, so a key that is both pressed and
        # released in the same message ends up released.
        keys_down.update(keys_down_list)
        keys_down.difference_update(keys_up_list)

        inputs = prepare_model_inputs(previous_frame, hidden_states, x, y, is_right_click, is_left_click, list(keys_down), stoi, itos, frame_num)
        print(f"[{time.perf_counter():.3f}] Starting model inference...")
        previous_frame, sample_img, hidden_states, timing_info = await process_frame(model, inputs)
        timing_info['full_frame'] = time.perf_counter() - process_start_time

        print(f"[{time.perf_counter():.3f}] Model inference complete. Queue size now: {input_queue.qsize()}")
        print_timing_stats(timing_info, frame_num)
        print(f"  Global FPS: {global_fps:.2f} (total: {frame_count} frames in {total_elapsed:.2f}s)")

        img = Image.fromarray(sample_img)
        buffered = io.BytesIO()
        img.save(buffered, format="PNG")
        img_str = base64.b64encode(buffered.getvalue()).decode()

        # Send the generated frame back to the client
        print(f"[{time.perf_counter():.3f}] Sending image to client...")
        await websocket.send_json({"image": img_str})
        print(f"[{time.perf_counter():.3f}] Image sent. Queue size: {input_queue.qsize()}")

    def select_next_input(first):
        """Choose the input to process next, starting from ``first``.

        Non-interesting inputs are discarded while more are queued. The scan
        stops at the first interesting input or at the last queued one; the
        remainder of the queue is left untouched.

        Returns:
            ``(chosen_input, number_of_inputs_discarded)``.
        """
        chosen = first
        skipped = 0
        while not _is_interesting_input(chosen) and not input_queue.empty():
            chosen = input_queue.get_nowait()
            input_queue.task_done()
            skipped += 1
        return chosen, skipped

    async def run_worker():
        """Consume queued inputs one frame at a time until cancelled."""
        nonlocal is_processing
        while True:
            first = await input_queue.get()
            input_queue.task_done()
            is_processing = True
            current_time = time.perf_counter()
            try:
                chosen, skipped = select_next_input(first)
                if _is_interesting_input(chosen):
                    print(f"[{current_time:.3f}] Found interesting input (skipped {skipped} events)")
                else:
                    print(f"[{current_time:.3f}] No interesting inputs, processing latest movement (skipped {skipped} events)")
                await process_input(chosen)
            except Exception as e:
                # Log and keep serving later inputs; one bad frame must not
                # stop the connection's worker.
                print(f"[{current_time:.3f}] Error while processing input: {e}")
                import traceback
                traceback.print_exc()
            finally:
                is_processing = False
                print(f"[{time.perf_counter():.3f}] Processing complete. Queue size: {input_queue.qsize()}")

    # Keep a reference to the task: the event loop holds only weak references.
    worker_task = asyncio.create_task(run_worker())

    try:
        while True:
            try:
                # Receive user input
                print(f"[{time.perf_counter():.3f}] Waiting for input... Queue size: {input_queue.qsize()}, is_processing: {is_processing}")
                data = await websocket.receive_json()
                receive_time = time.perf_counter()

                if data.get("type") == "heartbeat":
                    await websocket.send_json({"type": "heartbeat_response"})
                    continue

                # Hand the input to the worker task.
                await input_queue.put(data)
                print(f"[{receive_time:.3f}] Received input. Queue size now: {input_queue.qsize()}")
                if is_processing:
                    print(f"[{receive_time:.3f}] Currently processing, new input queued for later")
                else:
                    print(f"[{receive_time:.3f}] Not currently processing, worker will pick up the input")

            except asyncio.TimeoutError:
                print("WebSocket connection timed out")

            except WebSocketDisconnect:
                print("WebSocket disconnected")
                break

    except Exception as e:
        print(f"Error in WebSocket connection {client_id}: {e}")
        import traceback
        traceback.print_exc()

    finally:
        # Stop the worker so no further frames are generated or sent to a
        # closed socket.
        worker_task.cancel()
        try:
            await worker_task
        except asyncio.CancelledError:
            pass

        # Print final FPS statistics when connection ends
        if frame_num >= 0:  # Only if we processed at least one frame
            total_time = time.perf_counter() - connection_start_time
            print(f"\nConnection {client_id} summary:")
            print(f"  Total frames processed: {frame_count}")
            print(f"  Total elapsed time: {total_time:.2f} seconds")
            print(f"  Average FPS: {frame_count/total_time:.2f}")

        print(f"WebSocket connection closed: {client_id}")