Spaces:
Runtime error
Runtime error
Commit
·
db037f3
1
Parent(s):
41b76bb
Update main.py
Browse files
main.py
CHANGED
@@ -67,7 +67,7 @@ def create_position_map(pos, image_size=64, original_width=1024, original_height
|
|
67 |
pos_map = torch.zeros((1, image_size, image_size))
|
68 |
pos_map[0, y_scaled, x_scaled] = 1.0
|
69 |
|
70 |
-
return pos_map
|
71 |
|
72 |
# Serve the index.html file at the root URL
|
73 |
@app.get("/")
|
@@ -77,7 +77,7 @@ async def get():
|
|
77 |
def generate_random_image(width: int, height: int) -> np.ndarray:
|
78 |
return np.random.randint(0, 256, (height, width, 3), dtype=np.uint8)
|
79 |
|
80 |
-
def draw_trace(image: np.ndarray, previous_actions: List[Tuple[str, List[int]]]) -> np.ndarray:
|
81 |
pil_image = Image.fromarray(image)
|
82 |
#pil_image = Image.open('image_3.png')
|
83 |
draw = ImageDraw.Draw(pil_image)
|
@@ -95,10 +95,12 @@ def draw_trace(image: np.ndarray, previous_actions: List[Tuple[str, List[int]]])
|
|
95 |
y = y * 256 / 640
|
96 |
draw.ellipse([x-2, y-2, x+2, y+2], fill=color)
|
97 |
|
|
|
98 |
if prev_x is not None:
|
99 |
#prev_x, prev_y = previous_actions[i-1][1]
|
100 |
draw.line([prev_x, prev_y, x, y], fill=color, width=1)
|
101 |
prev_x, prev_y = x, y
|
|
|
102 |
|
103 |
return np.array(pil_image)
|
104 |
|
@@ -204,7 +206,7 @@ def predict_next_frame(previous_frames: List[np.ndarray], previous_actions: List
|
|
204 |
|
205 |
prompt = " ".join(action_descriptions[-8:])
|
206 |
|
207 |
-
pos_map = create_position_map(parse_action_string(action_descriptions[-1]))
|
208 |
|
209 |
|
210 |
#prompt = ''
|
@@ -220,7 +222,7 @@ def predict_next_frame(previous_frames: List[np.ndarray], previous_actions: List
|
|
220 |
new_frame_denormalized = denormalize_image(new_frame, source_range=(-1, 1))
|
221 |
|
222 |
# Draw the trace of previous actions
|
223 |
-
new_frame_with_trace = draw_trace(new_frame_denormalized, previous_actions)
|
224 |
|
225 |
return new_frame_with_trace, new_frame_denormalized
|
226 |
|
|
|
67 |
pos_map = torch.zeros((1, image_size, image_size))
|
68 |
pos_map[0, y_scaled, x_scaled] = 1.0
|
69 |
|
70 |
+
return pos_map, x_scaled, y_scaled
|
71 |
|
72 |
# Serve the index.html file at the root URL
|
73 |
@app.get("/")
|
|
|
77 |
def generate_random_image(width: int, height: int) -> np.ndarray:
|
78 |
return np.random.randint(0, 256, (height, width, 3), dtype=np.uint8)
|
79 |
|
80 |
+
def draw_trace(image: np.ndarray, previous_actions: List[Tuple[str, List[int]]], x_scaled=-1, y_scaled=-1) -> np.ndarray:
|
81 |
pil_image = Image.fromarray(image)
|
82 |
#pil_image = Image.open('image_3.png')
|
83 |
draw = ImageDraw.Draw(pil_image)
|
|
|
95 |
y = y * 256 / 640
|
96 |
draw.ellipse([x-2, y-2, x+2, y+2], fill=color)
|
97 |
|
98 |
+
|
99 |
if prev_x is not None:
|
100 |
#prev_x, prev_y = previous_actions[i-1][1]
|
101 |
draw.line([prev_x, prev_y, x, y], fill=color, width=1)
|
102 |
prev_x, prev_y = x, y
|
103 |
+
draw.ellipse([x_scaled*4-2, y_scaled*4-2, x_scaled*4+2, y_scaled*4+2], fill=(0, 255, 0))
|
104 |
|
105 |
return np.array(pil_image)
|
106 |
|
|
|
206 |
|
207 |
prompt = " ".join(action_descriptions[-8:])
|
208 |
|
209 |
+
pos_map, x_scaled, y_scaled = create_position_map(parse_action_string(action_descriptions[-1]))
|
210 |
|
211 |
|
212 |
#prompt = ''
|
|
|
222 |
new_frame_denormalized = denormalize_image(new_frame, source_range=(-1, 1))
|
223 |
|
224 |
# Draw the trace of previous actions
|
225 |
+
new_frame_with_trace = draw_trace(new_frame_denormalized, previous_actions, x_scaled, y_scaled)
|
226 |
|
227 |
return new_frame_with_trace, new_frame_denormalized
|
228 |
|