da03 commited on
Commit
9efab58
·
1 Parent(s): 678650b
Files changed (2) hide show
  1. main.py +13 -0
  2. online_data_generation.py +58 -2
main.py CHANGED
@@ -118,6 +118,17 @@ KEYS = ['\t', '\n', '\r', ' ', '!', '"', '#', '$', '%', '&', "'", '(',
118
  'shift', 'shiftleft', 'shiftright', 'sleep', 'space', 'stop', 'subtract', 'tab',
119
  'up', 'volumedown', 'volumemute', 'volumeup', 'win', 'winleft', 'winright', 'yen',
120
  'command', 'option', 'optionleft', 'optionright']
 
 
 
 
 
 
 
 
 
 
 
121
  INVALID_KEYS = ['f13', 'f14', 'f15', 'f16', 'f17', 'f18', 'f19', 'f20',
122
  'f21', 'f22', 'f23', 'f24', 'select', 'separator', 'execute']
123
  VALID_KEYS = [key for key in KEYS if key not in INVALID_KEYS]
@@ -174,6 +185,8 @@ def prepare_model_inputs(
174
  }
175
  for key in keys_down:
176
  key = key.lower()
 
 
177
  if key in stoi:
178
  inputs['key_events'][stoi[key]] = 1
179
  else:
 
118
  'shift', 'shiftleft', 'shiftright', 'sleep', 'space', 'stop', 'subtract', 'tab',
119
  'up', 'volumedown', 'volumemute', 'volumeup', 'win', 'winleft', 'winright', 'yen',
120
  'command', 'option', 'optionleft', 'optionright']
121
+
122
+ KEYMAPPING = {
123
+ 'arrowup': 'up',
124
+ 'arrowdown': 'down',
125
+ 'arrowleft': 'left',
126
+ 'arrowright': 'right',
127
+ 'meta': 'command',
128
+ 'contextmenu': 'apps',
129
+ 'control': 'ctrl',
130
+ }
131
+
132
  INVALID_KEYS = ['f13', 'f14', 'f15', 'f16', 'f17', 'f18', 'f19', 'f20',
133
  'f21', 'f22', 'f23', 'f24', 'select', 'separator', 'execute']
134
  VALID_KEYS = [key for key in KEYS if key not in INVALID_KEYS]
 
185
  }
186
  for key in keys_down:
187
  key = key.lower()
188
+ if key in KEYMAPPING:
189
+ key = KEYMAPPING[key]
190
  if key in stoi:
191
  inputs['key_events'][stoi[key]] = 1
192
  else:
online_data_generation.py CHANGED
@@ -56,6 +56,45 @@ autoencoder = autoencoder.to(device)
56
  # Global flag for graceful shutdown
57
  running = True
58
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
  def signal_handler(sig, frame):
60
  """Handle Ctrl+C and other termination signals"""
61
  global running
@@ -459,10 +498,27 @@ def format_trajectory_for_processing(trajectory):
459
  # Extract input data
460
  inputs = entry.get("inputs", {})
461
  key_events = []
 
462
  for key in inputs.get("keys_down", []):
463
- key_events.append(("keydown", key))
 
 
 
 
 
 
 
 
464
  for key in inputs.get("keys_up", []):
465
- key_events.append(("keyup", key))
 
 
 
 
 
 
 
 
466
  event = {
467
  "pos": (inputs.get("x"), inputs.get("y")),
468
  "left_click": inputs.get("is_left_click", False),
 
56
  # Global flag for graceful shutdown
57
  running = True
58
 
59
+
60
+ KEYMAPPING = {
61
+ 'arrowup': 'up',
62
+ 'arrowdown': 'down',
63
+ 'arrowleft': 'left',
64
+ 'arrowright': 'right',
65
+ 'meta': 'command',
66
+ 'contextmenu': 'apps',
67
+ 'control': 'ctrl',
68
+ }
69
+
70
+ KEYS = ['\t', '\n', '\r', ' ', '!', '"', '#', '$', '%', '&', "'", '(',
71
+ ')', '*', '+', ',', '-', '.', '/', '0', '1', '2', '3', '4', '5', '6', '7',
72
+ '8', '9', ':', ';', '<', '=', '>', '?', '@', '[', '\\', ']', '^', '_', '`',
73
+ 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o',
74
+ 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', '{', '|', '}', '~',
75
+ 'accept', 'add', 'alt', 'altleft', 'altright', 'apps', 'backspace',
76
+ 'browserback', 'browserfavorites', 'browserforward', 'browserhome',
77
+ 'browserrefresh', 'browsersearch', 'browserstop', 'capslock', 'clear',
78
+ 'convert', 'ctrl', 'ctrlleft', 'ctrlright', 'decimal', 'del', 'delete',
79
+ 'divide', 'down', 'end', 'enter', 'esc', 'escape', 'execute', 'f1', 'f10',
80
+ 'f11', 'f12', 'f13', 'f14', 'f15', 'f16', 'f17', 'f18', 'f19', 'f2', 'f20',
81
+ 'f21', 'f22', 'f23', 'f24', 'f3', 'f4', 'f5', 'f6', 'f7', 'f8', 'f9',
82
+ 'final', 'fn', 'hanguel', 'hangul', 'hanja', 'help', 'home', 'insert', 'junja',
83
+ 'kana', 'kanji', 'launchapp1', 'launchapp2', 'launchmail',
84
+ 'launchmediaselect', 'left', 'modechange', 'multiply', 'nexttrack',
85
+ 'nonconvert', 'num0', 'num1', 'num2', 'num3', 'num4', 'num5', 'num6',
86
+ 'num7', 'num8', 'num9', 'numlock', 'pagedown', 'pageup', 'pause', 'pgdn',
87
+ 'pgup', 'playpause', 'prevtrack', 'print', 'printscreen', 'prntscrn',
88
+ 'prtsc', 'prtscr', 'return', 'right', 'scrolllock', 'select', 'separator',
89
+ 'shift', 'shiftleft', 'shiftright', 'sleep', 'space', 'stop', 'subtract', 'tab',
90
+ 'up', 'volumedown', 'volumemute', 'volumeup', 'win', 'winleft', 'winright', 'yen',
91
+ 'command', 'option', 'optionleft', 'optionright']
92
+ INVALID_KEYS = ['f13', 'f14', 'f15', 'f16', 'f17', 'f18', 'f19', 'f20',
93
+ 'f21', 'f22', 'f23', 'f24', 'select', 'separator', 'execute']
94
+ VALID_KEYS = [key for key in KEYS if key not in INVALID_KEYS]
95
+ itos = VALID_KEYS
96
+ stoi = {key: i for i, key in enumerate(itos)}
97
+
98
  def signal_handler(sig, frame):
99
  """Handle Ctrl+C and other termination signals"""
100
  global running
 
498
  # Extract input data
499
  inputs = entry.get("inputs", {})
500
  key_events = []
501
+ down_keys = set([])
502
  for key in inputs.get("keys_down", []):
503
+ key = key.lower()
504
+ if key in KEYMAPPING:
505
+ print (f"Key {key} mapped to {KEYMAPPING[key]}")
506
+ key = KEYMAPPING[key]
507
+ if key not in stoi:
508
+ print (f"Key {key} not found in stoi")
509
+ if key not in down_keys and key in stoi:
510
+ down_keys.add(key)
511
+ key_events.append(("keydown", key))
512
  for key in inputs.get("keys_up", []):
513
+ key = key.lower()
514
+ if key in KEYMAPPING:
515
+ print (f"Key {key} mapped to {KEYMAPPING[key]}")
516
+ key = KEYMAPPING[key]
517
+ if key not in stoi:
518
+ print (f"Key {key} not found in stoi")
519
+ if key in down_keys and key in stoi:
520
+ down_keys.remove(key)
521
+ key_events.append(("keyup", key))
522
  event = {
523
  "pos": (inputs.get("x"), inputs.get("y")),
524
  "left_click": inputs.get("is_left_click", False),