yuntian-deng commited on
Commit
c7a7c4e
·
1 Parent(s): 44eb452

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +3 -4
main.py CHANGED
@@ -79,13 +79,12 @@ def generate_random_image(width: int, height: int) -> np.ndarray:
79
 
80
  def draw_trace(image: np.ndarray, previous_actions: List[Tuple[str, List[int]]], x_scaled=-1, y_scaled=-1) -> np.ndarray:
81
  pil_image = Image.fromarray(image)
82
- pil_image = pil_image.convert("RGBA")
83
  #pil_image = Image.open('image_3.png')
84
  draw = ImageDraw.Draw(pil_image)
85
  flag = True
86
  prev_x, prev_y = None, None
87
  for i, (action_type, position) in enumerate(previous_actions):
88
- color = (255, 0, 0, 255) if action_type == "move" else (0, 255, 0, 255)
89
  x, y = position
90
  if x == 0 and y == 0 and flag:
91
  continue
@@ -101,8 +100,8 @@ def draw_trace(image: np.ndarray, previous_actions: List[Tuple[str, List[int]]],
101
  #prev_x, prev_y = previous_actions[i-1][1]
102
  draw.line([prev_x, prev_y, x, y], fill=color, width=1)
103
  prev_x, prev_y = x, y
104
- draw.ellipse([x_scaled*4-2, y_scaled*4-2, x_scaled*4+2, y_scaled*4+2], fill=(0, 255, 0, 255))
105
- pil_image = pil_image.convert("RGB")
106
 
107
  return np.array(pil_image)
108
 
 
79
 
80
  def draw_trace(image: np.ndarray, previous_actions: List[Tuple[str, List[int]]], x_scaled=-1, y_scaled=-1) -> np.ndarray:
81
  pil_image = Image.fromarray(image)
 
82
  #pil_image = Image.open('image_3.png')
83
  draw = ImageDraw.Draw(pil_image)
84
  flag = True
85
  prev_x, prev_y = None, None
86
  for i, (action_type, position) in enumerate(previous_actions):
87
+ color = (255, 0, 0) if action_type == "move" else (0, 255, 0)
88
  x, y = position
89
  if x == 0 and y == 0 and flag:
90
  continue
 
100
  #prev_x, prev_y = previous_actions[i-1][1]
101
  draw.line([prev_x, prev_y, x, y], fill=color, width=1)
102
  prev_x, prev_y = x, y
103
+ draw.ellipse([x_scaled*4-2, y_scaled*4-2, x_scaled*4+2, y_scaled*4+2], fill=(0, 255, 0))
104
+ #pil_image = pil_image.convert("RGB")
105
 
106
  return np.array(pil_image)
107