da03 committed
Commit ab7919c
1 Parent(s): 9612f89
Files changed (1)
  1. online_data_generation.py +270 -115
online_data_generation.py CHANGED
@@ -10,6 +10,17 @@ import numpy as np
 import subprocess
 from datetime import datetime
 from typing import List, Dict, Any, Tuple
+from omegaconf import OmegaConf
+from computer.util import load_model_from_config
+from PIL import Image
+import io
+import torch
+from einops import rearrange
+import webdataset as wds
+import pandas as pd
+import ast
+import pickle
+from moviepy.editor import VideoFileClip
 
 # Import the existing functions
 from data.data_collection.synthetic_script_compute_canada import process_trajectory, initialize_clean_state
@@ -28,10 +39,18 @@ logger = logging.getLogger(__name__)
 # Define constants
 DB_FILE = "trajectory_processor.db"
 FRAMES_DIR = "interaction_logs"
+OUTPUT_DIR = 'train_dataset_encoded_online'
+os.makedirs(OUTPUT_DIR, exist_ok=True)
 SCREEN_WIDTH = 512
 SCREEN_HEIGHT = 384
 MEMORY_LIMIT = "2g"
 
+# load autoencoder
+config = OmegaConf.load('../computer/autoencoder/config_kl4_lr4.5e6_load_acc1_512_384_mar10_keyboard_init_16_contmar15_acc1.yaml')
+autoencoder = load_model_from_config(config, '../computer/autoencoder/saved_kl4_bsz8_acc8_lr4.5e6_load_acc1_512_384_mar10_keyboard_init_16_cont_mar15_acc1_cont_1e6_cont_2e7_cont/model-2076000.ckpt')
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+autoencoder = autoencoder.to(device)
+
 def initialize_database():
     """Initialize the SQLite database if it doesn't exist."""
     conn = sqlite3.connect(DB_FILE)
@@ -142,133 +161,269 @@ def load_trajectory(log_file):
 
 
 def process_session_file(log_file, clean_state):
-    """
-    Process a session file, splitting into multiple trajectories at reset points.
-    Returns a list of successfully processed trajectory IDs.
-    """
-    conn = sqlite3.connect(DB_FILE)
-    cursor = conn.cursor()
-
-    # Ensure output directory exists
-    os.makedirs("generated_videos", exist_ok=True)
-
-    # Get session details
-    trajectory = load_trajectory(log_file)
-    if not trajectory:
-        logger.error(f"Empty trajectory for {log_file}, skipping")
-        conn.close()
-        return []
-
-    client_id = trajectory[0].get("client_id", "unknown")
-
-    # Find all reset points and EOS
-    reset_indices = []
-    has_eos = False
-
-    for i, entry in enumerate(trajectory):
-        if entry.get("is_reset", False):
-            reset_indices.append(i)
-        if entry.get("is_eos", False):
-            has_eos = True
-
-    # If no resets and no EOS, this is incomplete - skip
-    if not reset_indices and not has_eos:
-        logger.warning(f"Session {log_file} has no resets and no EOS, may be incomplete")
-        conn.close()
-        return []
-
-    # Split trajectory at reset points
-    sub_trajectories = []
-    start_idx = 0
-
-    # Add all segments between resets
-    for reset_idx in reset_indices:
-        if reset_idx > start_idx:  # Only add non-empty segments
-            sub_trajectories.append(trajectory[start_idx:reset_idx])
-        start_idx = reset_idx + 1  # Start new segment after the reset
-
-    # Add the final segment if it's not empty
-    if start_idx < len(trajectory):
-        sub_trajectories.append(trajectory[start_idx:])
-
-    # Process each sub-trajectory
-    processed_ids = []
-
-    for i, sub_traj in enumerate(sub_trajectories):
-        # Skip segments with no interaction data (just control messages)
-        if not any(not entry.get("is_reset", False) and not entry.get("is_eos", False) for entry in sub_traj):
-            continue
-
-        # Get the next ID for this sub-trajectory
-        cursor.execute("SELECT value FROM config WHERE key = 'next_id'")
-        next_id = int(cursor.fetchone()[0])
-
-        # Find timestamps for this segment
-        start_time = sub_traj[0]["timestamp"]
-        end_time = sub_traj[-1]["timestamp"]
-
-        # STEP 1: Generate a video from the original frames
-        segment_label = f"segment_{i+1}_of_{len(sub_trajectories)}"
-        video_path = os.path.join("generated_videos", f"trajectory_{next_id:06d}_{segment_label}.mp4")
-
-        # Generate video from original frames for comparison
-        success, frame_count = generate_comparison_video(
-            client_id,
-            sub_traj,
-            video_path,
-            start_time,
-            end_time
-        )
-
-        if not success:
-            logger.warning(f"Failed to generate comparison video for segment {i+1}, but continuing with processing")
-
-        # STEP 2: Process with Docker for training data generation
-        try:
-            logger.info(f"Processing segment {i+1}/{len(sub_trajectories)} from {log_file} as trajectory {next_id}")
-
-            # Format the trajectory as needed by process_trajectory function
-            formatted_trajectory = format_trajectory_for_processing(sub_traj)
-
-            # Call the external process_trajectory function
-            args = (next_id, formatted_trajectory)
-            process_trajectory(args, SCREEN_WIDTH, SCREEN_HEIGHT, clean_state, MEMORY_LIMIT)
-
-            # Mark this segment as processed
-            cursor.execute(
-                """INSERT INTO processed_segments
-                   (log_file, client_id, segment_index, start_time, end_time,
-                    processed_time, trajectory_id)
-                   VALUES (?, ?, ?, ?, ?, ?, ?)""",
-                (log_file, client_id, i, start_time, end_time,
-                 datetime.now().isoformat(), next_id)
-            )
-
-            # Increment the next ID
-            cursor.execute("UPDATE config SET value = ? WHERE key = 'next_id'", (str(next_id + 1),))
-            conn.commit()
-
-            processed_ids.append(next_id)
-            logger.info(f"Successfully processed segment {i+1}/{len(sub_trajectories)} from {log_file}")
-
-        except Exception as e:
-            logger.error(f"Failed to process segment {i+1}/{len(sub_trajectories)} from {log_file}: {e}")
-            continue
-
-    # Mark the entire session as processed only if at least one segment succeeded
-    if processed_ids:
-        try:
-            cursor.execute(
-                "INSERT INTO processed_sessions (log_file, client_id, processed_time) VALUES (?, ?, ?)",
-                (log_file, client_id, datetime.now().isoformat())
-            )
-            conn.commit()
-        except sqlite3.IntegrityError:
-            # This can happen if we're re-processing a file that had some segments fail
-            pass
-
-    conn.close()
-    return processed_ids
+    """Process a session file, splitting into multiple trajectories at reset points."""
+    conn = None
+    try:
+        conn = sqlite3.connect(DB_FILE)
+        conn.execute("BEGIN TRANSACTION")  # Explicit transaction
+        cursor = conn.cursor()
+
+        # Ensure output directory exists
+        os.makedirs("generated_videos", exist_ok=True)
+
+        # Get session details
+        trajectory = load_trajectory(log_file)
+        if not trajectory:
+            logger.error(f"Empty trajectory for {log_file}, skipping")
+            return []
+
+        client_id = trajectory[0].get("client_id", "unknown")
+
+        # Find all reset points and EOS
+        reset_indices = []
+        has_eos = False
+
+        for i, entry in enumerate(trajectory):
+            if entry.get("is_reset", False):
+                reset_indices.append(i)
+            if entry.get("is_eos", False):
+                has_eos = True
+
+        # If no resets and no EOS, this is incomplete - skip
+        if not reset_indices and not has_eos:
+            logger.warning(f"Session {log_file} has no resets and no EOS, may be incomplete")
+            return []
+
+        # Split trajectory at reset points
+        sub_trajectories = []
+        start_idx = 0
+
+        # Add all segments between resets
+        for reset_idx in reset_indices:
+            if reset_idx > start_idx:  # Only add non-empty segments
+                sub_trajectories.append(trajectory[start_idx:reset_idx])
+            start_idx = reset_idx + 1  # Start new segment after the reset
+
+        # Add the final segment if it's not empty
+        if start_idx < len(trajectory):
+            sub_trajectories.append(trajectory[start_idx:])
+
+        # Process each sub-trajectory
+        processed_ids = []
+
+        for i, sub_traj in enumerate(sub_trajectories):
+            # Skip segments with no interaction data (just control messages)
+            if not any(not entry.get("is_reset", False) and not entry.get("is_eos", False) for entry in sub_traj):
+                continue
+
+            # Get the next ID for this sub-trajectory
+            cursor.execute("SELECT value FROM config WHERE key = 'next_id'")
+            next_id = int(cursor.fetchone()[0])
+
+            # Find timestamps for this segment
+            start_time = sub_traj[0]["timestamp"]
+            end_time = sub_traj[-1]["timestamp"]
+
+            # STEP 1: Generate a video from the original frames
+            segment_label = f"segment_{i+1}_of_{len(sub_trajectories)}"
+            video_path = os.path.join("generated_videos", f"trajectory_{next_id}_{segment_label}.mp4")
+
+            # Generate video from original frames for comparison
+            success, frame_count = generate_comparison_video(
+                client_id,
+                sub_traj,
+                video_path,
+                start_time,
+                end_time
+            )
+
+            if not success:
+                logger.warning(f"Failed to generate comparison video for segment {i+1}, but continuing with processing")
+
+            # STEP 2: Process with Docker for training data generation
+            try:
+                logger.info(f"Processing segment {i+1}/{len(sub_trajectories)} from {log_file} as trajectory {next_id}")
+
+                # Format the trajectory as needed by process_trajectory function
+                formatted_trajectory = format_trajectory_for_processing(sub_traj)
+                record_num = next_id
+
+                # Call the external process_trajectory function
+                args = (record_num, formatted_trajectory)
+                process_trajectory(args, SCREEN_WIDTH, SCREEN_HEIGHT, clean_state, MEMORY_LIMIT)
+
+                # Prepare training data format
+                video_file = f'raw_data/raw_data/videos/record_{record_num}.mp4'
+                action_file = f'raw_data/raw_data/actions/record_{record_num}.csv'
+                mouse_data = pd.read_csv(action_file)
+                mapping_dict = {}
+                target_data = []
+                # remove the existing tar file if exists
+                if os.path.exists(os.path.join(OUTPUT_DIR, f'record_{record_num}.tar')):
+                    logger.info(f"Removing existing tar file {os.path.join(OUTPUT_DIR, f'record_{record_num}.tar')}")
+                    os.remove(os.path.join(OUTPUT_DIR, f'record_{record_num}.tar'))
+                sink = wds.TarWriter(os.path.join(OUTPUT_DIR, f'record_{record_num}.tar'))
+                with VideoFileClip(video_file) as video:
+                    fps = video.fps
+                    assert fps == 15, f"Expected 15 FPS, got {fps}"
+                    duration = video.duration
+                    down_keys = set([])
+                    for image_num in range(int(fps*duration)):
+                        action_row = mouse_data.iloc[image_num]
+                        x = int(action_row['X'])
+                        y = int(action_row['Y'])
+                        left_click = True if action_row['Left Click'] == 1 else False
+                        right_click = True if action_row['Right Click'] == 1 else False
+                        key_events = ast.literal_eval(action_row['Key Events'])
+                        for key_state, key in key_events:
+                            if key_state == "keydown":
+                                down_keys.add(key)
+                            elif key_state == "keyup":
+                                down_keys.remove(key)
+                            else:
+                                raise ValueError(f"Unknown key event type: {key_state}")
+                        mapping_dict[(record_num, image_num)] = (x, y, left_click, right_click, list(down_keys))
+                        target_data.append((record_num, image_num))
+                        frame = video.get_frame(image_num / fps)
+
+                        # Normalize to [-1, 1]
+                        image_array = (frame / 127.5 - 1.0).astype(np.float32)
+
+                        # Convert to torch tensor
+                        images_tensor = torch.tensor(image_array).unsqueeze(0)
+                        images_tensor = rearrange(images_tensor, 'b h w c -> b c h w')
+
+                        # Move to device for inference
+                        images_tensor = images_tensor.to(device)
+
+                        # Encode images
+                        posterior = autoencoder.encode(images_tensor)
+                        latents = posterior.sample()  # Sample from the posterior
+
+                        # Move back to CPU for saving
+                        latents = latents.cpu()
+
+                        # Save each latent to the tar file
+                        latent = latents[0]
+                        key = str(image_num)
+
+                        # Convert latent to bytes
+                        latent_bytes = io.BytesIO()
+                        np.save(latent_bytes, latent.numpy())
+                        latent_bytes.seek(0)
+
+                        # Write to tar
+                        sample = {
+                            "__key__": key,
+                            "npy": latent_bytes.getvalue(),
+                        }
+                        sink.write(sample)
+                        debug = True
+                        # Debug: save decoded reconstructions for visual inspection
+                        if debug:
+                            debug_dir = os.path.join(OUTPUT_DIR, 'debug')
+                            os.makedirs(debug_dir, exist_ok=True)
+
+                            # Decode latents back to images
+                            reconstructions = autoencoder.decode(latents.to(device))
+
+                            # Save original and reconstructed images side by side
+                            for idx, (orig, recon) in enumerate(zip(images_tensor, reconstructions)):
+                                # Convert to numpy and move to CPU
+                                orig = orig.cpu().numpy()
+                                recon = recon.cpu().numpy()
+
+                                # Denormalize from [-1,1] to [0,255]
+                                orig = (orig + 1.0) * 127.5
+                                recon = (recon + 1.0) * 127.5
+
+                                # Clip values to valid range
+                                orig = np.clip(orig, 0, 255).astype(np.uint8)
+                                recon = np.clip(recon, 0, 255).astype(np.uint8)
+
+                                # Rearrange from CHW to HWC
+                                orig = np.transpose(orig, (1, 2, 0))
+                                recon = np.transpose(recon, (1, 2, 0))
+
+                                # Create side-by-side comparison
+                                comparison = np.concatenate([orig, recon], axis=1)
+
+                                # Save comparison image
+                                Image.fromarray(comparison).save(
+                                    os.path.join(debug_dir, f'debug_record_{record_num}_{idx}_{key}.png')
+                                )
+                            print(f"\nDebug visualizations saved to {debug_dir}")
+                sink.close()
+                # merge with existing mapping_dict if exists, otherwise create new one
+                if os.path.exists(os.path.join(OUTPUT_DIR, 'image_action_mapping_with_key_states.pkl')):
+                    with open(os.path.join(OUTPUT_DIR, 'image_action_mapping_with_key_states.pkl'), 'rb') as f:
+                        existing_mapping_dict = pickle.load(f)
+                    for key, value in existing_mapping_dict.items():
+                        if key not in mapping_dict:
+                            mapping_dict[key] = value
+                # save the mapping_dict in an atomic way
+                temp_path = os.path.join(OUTPUT_DIR, 'image_action_mapping_with_key_states.pkl.temp')
+                with open(temp_path, 'wb') as f:
+                    pickle.dump(mapping_dict, f)
+                os.rename(temp_path, os.path.join(OUTPUT_DIR, 'image_action_mapping_with_key_states.pkl'))
+
+                # merge with existing target_data if exists, otherwise create new one
+                target_data = pd.DataFrame(target_data, columns=['record_num', 'image_num'])
+                if os.path.exists(os.path.join(OUTPUT_DIR, 'train_dataset.target_frames.csv')):
+                    existing_target_data = pd.read_csv(os.path.join(OUTPUT_DIR, 'train_dataset.target_frames.csv'))
+                    target_data = pd.concat([existing_target_data, target_data])
+                    # deduplicate
+                    target_data = target_data.drop_duplicates()
+                # save the target_data in an atomic way
+                temp_path = os.path.join(OUTPUT_DIR, 'train_dataset.target_frames.csv.temp')
+                target_data.to_csv(temp_path, index=False)
+                os.rename(temp_path, os.path.join(OUTPUT_DIR, 'train_dataset.target_frames.csv'))
+
+                # Mark this segment as processed
+                cursor.execute(
+                    """INSERT INTO processed_segments
+                       (log_file, client_id, segment_index, start_time, end_time,
+                        processed_time, trajectory_id)
+                       VALUES (?, ?, ?, ?, ?, ?, ?)""",
+                    (log_file, client_id, i, start_time, end_time,
+                     datetime.now().isoformat(), next_id)
+                )
+
+                # Increment the next ID
+                cursor.execute("UPDATE config SET value = ? WHERE key = 'next_id'", (str(next_id + 1),))
+                conn.commit()
+
+                processed_ids.append(next_id)
+                logger.info(f"Successfully processed segment {i+1}/{len(sub_trajectories)} from {log_file}")
+
+            except Exception as e:
+                logger.error(f"Failed to process segment {i+1}/{len(sub_trajectories)} from {log_file}: {e}")
+                continue
+
+        # Mark the entire session as processed only if at least one segment succeeded
+        if processed_ids:
+            try:
+                cursor.execute(
+                    "INSERT INTO processed_sessions (log_file, client_id, processed_time) VALUES (?, ?, ?)",
+                    (log_file, client_id, datetime.now().isoformat())
+                )
+                conn.commit()
+            except sqlite3.IntegrityError:
+                # This can happen if we're re-processing a file that had some segments fail
+                pass
+
+        # Commit only at the end if everything succeeds
+        conn.commit()
+        return processed_ids
+    except Exception as e:
+        logger.error(f"Error processing session {log_file}: {e}")
+        if conn:
+            conn.rollback()  # Roll back on error
+        return []
+    finally:
+        if conn:
+            conn.close()  # Always close connection
 
 
 def format_trajectory_for_processing(trajectory):
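
Note (not part of the commit): a minimal sketch of how the artifacts written above could be read back for training, assuming the layout this commit produces: each record_{N}.tar shard stores np.save'd latents keyed by str(image_num), image_action_mapping_with_key_states.pkl maps (record_num, image_num) to (x, y, left_click, right_click, down_keys), and train_dataset.target_frames.csv lists the target frames. The record number and relative paths here are illustrative assumptions.

import io
import pickle
import numpy as np
import pandas as pd
import webdataset as wds

OUTPUT_DIR = 'train_dataset_encoded_online'
record_num = 1  # hypothetical record id, for illustration only

# Load the action mapping and the target-frame index produced by process_session_file
with open(f'{OUTPUT_DIR}/image_action_mapping_with_key_states.pkl', 'rb') as f:
    mapping = pickle.load(f)
target_frames = pd.read_csv(f'{OUTPUT_DIR}/train_dataset.target_frames.csv')
n_frames = len(target_frames[target_frames['record_num'] == record_num])
print(f"record {record_num}: {n_frames} target frames listed")

# Iterate over one shard; without a decoder, webdataset yields raw bytes per extension
for sample in wds.WebDataset(f'{OUTPUT_DIR}/record_{record_num}.tar'):
    image_num = int(sample['__key__'])                  # key was written as str(image_num)
    latent = np.load(io.BytesIO(sample['npy']))          # latent was serialized with np.save
    x, y, left_click, right_click, down_keys = mapping[(record_num, image_num)]
    print(image_num, latent.shape, (x, y), left_click, right_click, down_keys)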