yuntian-deng commited on
Commit
db037f3
·
1 Parent(s): 41b76bb

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +6 -4
main.py CHANGED
@@ -67,7 +67,7 @@ def create_position_map(pos, image_size=64, original_width=1024, original_height
67
  pos_map = torch.zeros((1, image_size, image_size))
68
  pos_map[0, y_scaled, x_scaled] = 1.0
69
 
70
- return pos_map
71
 
72
  # Serve the index.html file at the root URL
73
  @app.get("/")
@@ -77,7 +77,7 @@ async def get():
77
  def generate_random_image(width: int, height: int) -> np.ndarray:
78
  return np.random.randint(0, 256, (height, width, 3), dtype=np.uint8)
79
 
80
- def draw_trace(image: np.ndarray, previous_actions: List[Tuple[str, List[int]]]) -> np.ndarray:
81
  pil_image = Image.fromarray(image)
82
  #pil_image = Image.open('image_3.png')
83
  draw = ImageDraw.Draw(pil_image)
@@ -95,10 +95,12 @@ def draw_trace(image: np.ndarray, previous_actions: List[Tuple[str, List[int]]])
95
  y = y * 256 / 640
96
  draw.ellipse([x-2, y-2, x+2, y+2], fill=color)
97
 
 
98
  if prev_x is not None:
99
  #prev_x, prev_y = previous_actions[i-1][1]
100
  draw.line([prev_x, prev_y, x, y], fill=color, width=1)
101
  prev_x, prev_y = x, y
 
102
 
103
  return np.array(pil_image)
104
 
@@ -204,7 +206,7 @@ def predict_next_frame(previous_frames: List[np.ndarray], previous_actions: List
204
 
205
  prompt = " ".join(action_descriptions[-8:])
206
 
207
- pos_map = create_position_map(parse_action_string(action_descriptions[-1]))
208
 
209
 
210
  #prompt = ''
@@ -220,7 +222,7 @@ def predict_next_frame(previous_frames: List[np.ndarray], previous_actions: List
220
  new_frame_denormalized = denormalize_image(new_frame, source_range=(-1, 1))
221
 
222
  # Draw the trace of previous actions
223
- new_frame_with_trace = draw_trace(new_frame_denormalized, previous_actions)
224
 
225
  return new_frame_with_trace, new_frame_denormalized
226
 
 
67
  pos_map = torch.zeros((1, image_size, image_size))
68
  pos_map[0, y_scaled, x_scaled] = 1.0
69
 
70
+ return pos_map, x_scaled, y_scaled
71
 
72
  # Serve the index.html file at the root URL
73
  @app.get("/")
 
77
  def generate_random_image(width: int, height: int) -> np.ndarray:
78
  return np.random.randint(0, 256, (height, width, 3), dtype=np.uint8)
79
 
80
+ def draw_trace(image: np.ndarray, previous_actions: List[Tuple[str, List[int]]], x_scaled=-1, y_scaled=-1) -> np.ndarray:
81
  pil_image = Image.fromarray(image)
82
  #pil_image = Image.open('image_3.png')
83
  draw = ImageDraw.Draw(pil_image)
 
95
  y = y * 256 / 640
96
  draw.ellipse([x-2, y-2, x+2, y+2], fill=color)
97
 
98
+
99
  if prev_x is not None:
100
  #prev_x, prev_y = previous_actions[i-1][1]
101
  draw.line([prev_x, prev_y, x, y], fill=color, width=1)
102
  prev_x, prev_y = x, y
103
+ draw.ellipse([x_scaled*4-2, y_scaled*4-2, x_scaled*4+2, y_scaled*4+2], fill=(0, 255, 0))
104
 
105
  return np.array(pil_image)
106
 
 
206
 
207
  prompt = " ".join(action_descriptions[-8:])
208
 
209
+ pos_map, x_scaled, y_scaled = create_position_map(parse_action_string(action_descriptions[-1]))
210
 
211
 
212
  #prompt = ''
 
222
  new_frame_denormalized = denormalize_image(new_frame, source_range=(-1, 1))
223
 
224
  # Draw the trace of previous actions
225
+ new_frame_with_trace = draw_trace(new_frame_denormalized, previous_actions, x_scaled, y_scaled)
226
 
227
  return new_frame_with_trace, new_frame_denormalized
228