da03 commited on
Commit
c815936
·
1 Parent(s): 21ceee1
Files changed (1) hide show
  1. main.py +3 -0
main.py CHANGED
@@ -130,6 +130,9 @@ def draw_trace(image: np.ndarray, previous_actions: List[Tuple[str, List[int]]],
130
  if x_scaled >= 0 and y_scaled >= 0:
131
  x_current = x_scaled * 8
132
  y_current = y_scaled * 8
 
 
 
133
  draw.ellipse([x_current-3, y_current-3, x_current+3, y_current+3], fill=(0, 255, 0))
134
 
135
  return np.array(pil_image)
 
130
  if x_scaled >= 0 and y_scaled >= 0:
131
  x_current = x_scaled * 8
132
  y_current = y_scaled * 8
133
+ if not DEBUG_TEACHER_FORCING:
134
+ x_current = x_current *8
135
+ y_current = y_current *8
136
  draw.ellipse([x_current-3, y_current-3, x_current+3, y_current+3], fill=(0, 255, 0))
137
 
138
  return np.array(pil_image)