da03 committed on
Commit
50eea75
·
1 Parent(s): a10a91e
Files changed (1) hide show
  1. main.py +39 -11
main.py CHANGED
@@ -13,6 +13,7 @@ import os
13
  import time
14
 
15
  DEBUG = True
 
16
  app = FastAPI()
17
 
18
  # Mount the static directory to serve HTML, JavaScript, and CSS files
@@ -128,15 +129,14 @@ model = model.to(device)
128
 
129
  def load_initial_images(width, height):
130
  initial_images = []
131
- for i in range(7):
132
- initial_images.append(np.zeros((height, width, 3), dtype=np.uint8))
133
- #image_path = f"image_{i}.png"
134
- #if os.path.exists(image_path):
135
- # img = Image.open(image_path).resize((width, height))
136
- # initial_images.append(np.array(img))
137
- #else:
138
- # print(f"Warning: {image_path} not found. Using blank image instead.")
139
- # initial_images.append(np.zeros((height, width, 3), dtype=np.uint8))
140
  return initial_images
141
 
142
  def normalize_images(images, target_range=(-1, 1)):
@@ -156,13 +156,15 @@ def denormalize_image(image, source_range=(-1, 1)):
156
  else:
157
  raise ValueError(f"Unsupported source range: {source_range}")
158
 
159
- def format_action(action_str, is_padding=False):
160
  if is_padding:
161
  return "N N N N N N : N N N N N"
162
 
163
  # Split the x~y coordinates
164
  x, y = map(int, action_str.split('~'))
165
  prefix = 'N'
 
 
166
  # Convert numbers to padded strings and add spaces between digits
167
  x_str = f"{abs(x):04d}"
168
  y_str = f"{abs(y):04d}"
@@ -200,6 +202,22 @@ def predict_next_frame(previous_frames: List[np.ndarray], previous_actions: List
200
  prev_x = 0
201
  prev_y = 0
202
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
203
 
204
  for action_type, pos in previous_actions: #[-8:]:
205
  if action_type == "move":
@@ -217,7 +235,17 @@ def predict_next_frame(previous_frames: List[np.ndarray], previous_actions: List
217
  prev_x = norm_x
218
  prev_y = norm_y
219
  elif action_type == "left_click":
220
- action_descriptions.append("left_click")
 
 
 
 
 
 
 
 
 
 
221
  elif action_type == "right_click":
222
  action_descriptions.append("right_click")
223
 
 
13
  import time
14
 
15
  DEBUG = True
16
+ DEBUG_TEACHER_FORCING = True
17
  app = FastAPI()
18
 
19
  # Mount the static directory to serve HTML, JavaScript, and CSS files
 
129
 
130
  def load_initial_images(width, height):
131
  initial_images = []
132
+ if DEBUG_TEACHER_FORCING:
133
+ # Load the previous 7 frames for image_81
134
+ for i in range(74, 81): # Load images 74-80
135
+ img = Image.open(f"record_100/image_{i}.png").resize((width, height))
136
+ initial_images.append(np.array(img))
137
+ else:
138
+ for i in range(7):
139
+ initial_images.append(np.zeros((height, width, 3), dtype=np.uint8))
 
140
  return initial_images
141
 
142
  def normalize_images(images, target_range=(-1, 1)):
 
156
  else:
157
  raise ValueError(f"Unsupported source range: {source_range}")
158
 
159
+ def format_action(action_str, is_padding=False, is_leftclick=False):
160
  if is_padding:
161
  return "N N N N N N : N N N N N"
162
 
163
  # Split the x~y coordinates
164
  x, y = map(int, action_str.split('~'))
165
  prefix = 'N'
166
+ if is_leftclick:
167
+ prefix = 'L'
168
  # Convert numbers to padded strings and add spaces between digits
169
  x_str = f"{abs(x):04d}"
170
  y_str = f"{abs(y):04d}"
 
202
  prev_x = 0
203
  prev_y = 0
204
 
205
+ if DEBUG_TEACHER_FORCING:
206
+ # Use the predefined actions for image_81
207
+ debug_actions = [
208
+ 'N + 0 8 5 3 : + 0 4 5 0', 'N + 0 8 7 1 : + 0 4 6 3',
209
+ 'N + 0 8 9 0 : + 0 4 7 5', 'N + 0 9 0 8 : + 0 4 8 8',
210
+ 'N + 0 9 2 7 : + 0 5 0 1', 'N + 0 9 2 7 : + 0 5 0 1',
211
+ 'N + 0 9 2 7 : + 0 5 0 1', 'N + 0 9 2 7 : + 0 5 0 1',
212
+ 'N + 0 9 2 7 : + 0 5 0 1', 'N + 0 9 2 7 : + 0 5 0 1',
213
+ 'L + 0 9 2 7 : + 0 5 0 1', 'N + 0 9 2 7 : + 0 5 0 1',
214
+ 'L + 0 9 2 7 : + 0 5 0 1', 'N + 0 9 2 7 : + 0 5 0 1',
215
+ 'N + 0 9 2 7 : + 0 5 0 1'
216
+ ]
217
+ previous_actions = []
218
+ for action in debug_actions:
219
+ x, y, action_type = parse_action_string(action)
220
+ previous_actions.append((action_type, (x, y)))
221
 
222
  for action_type, pos in previous_actions: #[-8:]:
223
  if action_type == "move":
 
235
  prev_x = norm_x
236
  prev_y = norm_y
237
  elif action_type == "left_click":
238
+ x, y = pos
239
+ #norm_x = int(round(x / 256 * 1024)) #x + (1920 - 256) / 2
240
+ #norm_y = int(round(y / 256 * 640)) #y + (1080 - 256) / 2
241
+ norm_x = x + (1920 - 512) / 2
242
+ norm_y = y + (1080 - 512) / 2
243
+ #if DEBUG:
244
+ # norm_x = x
245
+ # norm_y = y
246
+ #action_descriptions.append(f"{(norm_x-prev_x):.0f}~{(norm_y-prev_y):.0f}")
247
+ #action_descriptions.append(format_action(f'{norm_x-prev_x:.0f}~{norm_y-prev_y:.0f}', x==0 and y==0))
248
+ action_descriptions.append(format_action(f'{norm_x:.0f}~{norm_y:.0f}', x==0 and y==0, True))
249
  elif action_type == "right_click":
250
  action_descriptions.append("right_click")
251