da03 committed on
Commit 9d0127f · 1 Parent(s): 487aaae
main.py
CHANGED
@@ -168,6 +168,15 @@ def normalize_images(images, target_range=(-1, 1)):
         return images / 255.0
     else:
         raise ValueError(f"Unsupported target range: {target_range}")
+
+def normalize_image(image, target_range=(-1, 1)):
+    image = image.astype(np.float32)
+    if target_range == (-1, 1):
+        return image / 127.5 - 1
+    elif target_range == (0, 1):
+        return image / 255.0
+    else:
+        raise ValueError(f"Unsupported target range: {target_range}")
 
 def denormalize_image(image, source_range=(-1, 1)):
     if source_range == (-1, 1):
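The new normalize_image mirrors the existing normalize_images but works on a single frame and casts to float32 first. A minimal standalone sketch of what the (-1, 1) path computes and its assumed inverse (the body of denormalize_image is only partly visible in this diff, so the round-trip here is an assumption):

    import numpy as np

    frame = np.random.randint(0, 256, (384, 512, 3), dtype=np.uint8)   # H x W x C uint8 pixels
    norm = frame.astype(np.float32) / 127.5 - 1                        # what normalize_image(frame) returns for (-1, 1)
    assert norm.min() >= -1.0 and norm.max() <= 1.0
    restored = (norm + 1.0) * 127.5                                    # assumed inverse, per denormalize_image's contract
    assert np.abs(restored - frame.astype(np.float32)).max() < 1e-3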
@@ -195,28 +204,27 @@ def format_action(action_str, is_padding=False, is_leftclick=False):
     # Format with sign and proper spacing
     return prefix + " " + f"{'+ ' if x >= 0 else '- '}{x_spaced} : {'+ ' if y >= 0 else '- '}{y_spaced}"
 
-def predict_next_frame(previous_frames: List[np.ndarray], previous_actions: List[Tuple[str, List[int]]]) -> np.ndarray:
+def predict_next_frame(previous_frames: List[np.ndarray, Tuple[str, np.ndarray]], previous_actions: List[Tuple[str, List[int]]]) -> np.ndarray:
     width, height = 512, 384
     all_click_positions = []
     initial_images = load_initial_images(width, height)
     print ('length of previous_frames', len(previous_frames))
+    padding_image = torch.zeros((height//8, width//8, 4)).to(device)
 
     # Prepare the image sequence for the model
     assert len(initial_images) == 32
     image_sequence = previous_frames[-32:] # Take the last 7 frames
     i = 1
     while len(image_sequence) < 32:
-        image_sequence.insert(0,
+        image_sequence.insert(0, padding_image)
         i += 1
         #image_sequence.append(initial_images[len(image_sequence)])
-
+
     # Convert the image sequence to a tensor and concatenate in the channel dimension
-    image_sequence_tensor = torch.from_numpy(normalize_images(
-    image_sequence_tensor = image_sequence_tensor.to(device)
-
-
-    data_min = -27.681446075439453
-    data_max = 30.854148864746094
+    #image_sequence_tensor = torch.from_numpy(normalize_images(image_sequence_list, target_range=(-1, 1)))
+    #image_sequence_tensor = image_sequence_tensor.to(device)
+    image_sequence_tensor = torch.cat(image_sequence, dim=1)
+
     #image_sequence_tensor = (image_sequence_tensor - data_mean) / data_std
 
     # Prepare the prompt based on the previous actions
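Two things change in this hunk: the history buffer is now padded with zero tensors of shape (height//8, width//8, 4), which looks like the 4-channel VAE latent grid at 1/8 resolution rather than a pixel image, and the conditioning tensor is built by concatenating the 32 history entries directly with torch.cat instead of normalizing uint8 frames. A small sketch of that padding-and-concatenation step, with standalone names and CPU tensors as assumptions:

    import torch

    height, width, context_len = 384, 512, 32
    padding_latent = torch.zeros((height // 8, width // 8, 4))               # zero "latent" stands in for missing history

    history = [torch.randn(height // 8, width // 8, 4) for _ in range(5)]    # pretend only 5 real entries exist so far
    sequence = history[-context_len:]
    while len(sequence) < context_len:
        sequence.insert(0, padding_latent)                                   # left-pad with zeros, as the new code does

    conditioning = torch.cat(sequence, dim=1)                                # concatenate along dim=1, as in the diff
    print(conditioning.shape)                                                # torch.Size([48, 2048, 4]) with these sizes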
@@ -318,7 +326,7 @@ def predict_next_frame(previous_frames: List[np.ndarray], previous_actions: List
     #print ('changing L to N')
 
     # Generate the next frame
-    new_frame = sample_frame(model, prompt, image_sequence_tensor, pos_maps=pos_maps, leftclick_maps=leftclick_maps)
+    new_frame, new_frame_feedback = sample_frame(model, prompt, image_sequence_tensor, pos_maps=pos_maps, leftclick_maps=leftclick_maps)
 
     # Convert the generated frame to the correct format
     new_frame = new_frame.transpose(1, 2, 0)
@@ -333,7 +341,7 @@ def predict_next_frame(previous_frames: List[np.ndarray], previous_actions: List
     #x, y, action_type = parse_action_string(action_descriptions[-1])
 
 
-    return new_frame_with_trace, new_frame_denormalized
+    return new_frame_with_trace, new_frame_denormalized, new_frame_feedback
 
 # WebSocket endpoint for continuous user interaction
 @app.websocket("/ws")
@@ -513,10 +521,10 @@ async def websocket_endpoint(websocket: WebSocket):
     #if DEBUG_TEACHER_FORCING:
     # print ('predicting', f"record_10003/image_{117+len(previous_frames)}.png")
     print ('previous_actions', previous_actions)
-    next_frame, next_frame_append = predict_next_frame(previous_frames, previous_actions)
+    next_frame, next_frame_append, next_frame_feedback = predict_next_frame(previous_frames, previous_actions)
     feedback = True
     if feedback:
-        previous_frames.append(
+        previous_frames.append(next_frame_feedback)
     else:
         #previous_frames = []
         previous_actions = []
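Taken together, the remaining main.py hunks close a feedback loop: sample_frame now also hands back the pre-decode latent, predict_next_frame forwards it as a third return value, and the websocket handler appends that latent (instead of a re-encoded pixel frame) to previous_frames whenever feedback is enabled. A rough, runnable sketch of the control flow with dummy tensors standing in for the real model calls:

    import torch

    def model_step(history, actions):
        # Stand-in for sample_frame(...): returns (decoded frame, pre-decode latent), as in the new utils.py.
        latent = torch.zeros(48, 64, 4)
        decoded = torch.zeros(3, 384, 512)
        return decoded, latent

    feedback = True
    previous_frames, previous_actions = [], []
    for action in [("move", [10, 20]), ("leftclick", [100, 50])]:   # stand-in for incoming websocket messages
        previous_actions.append(action)
        decoded, latent_feedback = model_step(previous_frames, previous_actions)
        if feedback:
            previous_frames.append(latent_feedback)                 # condition future steps on the model's own latent
        else:
            previous_actions = []                                   # mirrors the reset branch in the handler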
utils.py
CHANGED
@@ -45,13 +45,13 @@ def sample_frame(model: LatentDiffusion, prompt: str, image_sequence: torch.Tens
     #print (c['c_crossattn'][0])
     print (prompt)
     c = {}
-    c = model.enc_concat_seq(c, c_dict, 'c_concat')
+    #c = model.enc_concat_seq(c, c_dict, 'c_concat')
     # Zero out the corresponding subtensors in c_concat for padding images
-    padding_mask = torch.isclose(image_sequence, torch.tensor(-1.0), rtol=1e-5, atol=1e-5).all(dim=(1, 2, 3)).unsqueeze(0)
-    print (padding_mask)
-    padding_mask = padding_mask.repeat(1, 4) # Repeat mask 4 times for each projected channel
-    print (image_sequence.shape, padding_mask.shape, c['c_concat'].shape)
-    c['c_concat'] = c['c_concat'] * (~padding_mask.unsqueeze(-1).unsqueeze(-1)) # Zero out the corresponding features
+    #padding_mask = torch.isclose(image_sequence, torch.tensor(-1.0), rtol=1e-5, atol=1e-5).all(dim=(1, 2, 3)).unsqueeze(0)
+    #print (padding_mask)
+    #padding_mask = padding_mask.repeat(1, 4) # Repeat mask 4 times for each projected channel
+    #print (image_sequence.shape, padding_mask.shape, c['c_concat'].shape)
+    #c['c_concat'] = c['c_concat'] * (~padding_mask.unsqueeze(-1).unsqueeze(-1)) # Zero out the corresponding features
     data_mean = -0.54
     data_std = 6.78
     data_min = -27.681446075439453
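This hunk disables the earlier padding handling on the utils.py side: previously, frames that were all -1 in the input sequence were treated as padding and their slice of c_concat was zeroed out; with zero latents now supplied directly by predict_next_frame, that masking is commented out. For reference, the idea behind the old mask, in a simplified standalone form (the exact layout and channel ordering of c_concat are assumptions):

    import torch

    image_sequence = torch.stack([torch.full((3, 48, 64), -1.0), torch.randn(3, 48, 64)])   # frame 0 is padding
    features = torch.randn(1, image_sequence.shape[0] * 4, 48, 64)                          # assume 4 feature channels per frame

    is_padding = (image_sequence.reshape(image_sequence.shape[0], -1) == -1.0).all(dim=1)   # per-frame padding flag
    per_channel = is_padding.repeat_interleave(4).view(1, -1, 1, 1)                         # expand the flag to each channel
    features = features * (~per_channel)                                                    # zero out the padded frames' features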
@@ -108,6 +108,7 @@ def sample_frame(model: LatentDiffusion, prompt: str, image_sequence: torch.Tens
     data_max = 30.854148864746094
     x_samples_ddim = samples_ddim
     x_samples_ddim = x_samples_ddim * data_std + data_mean
+    x_samples_ddim_feedback = x_samples_ddim
     x_samples_ddim = model.decode_first_stage(x_samples_ddim)
     print ('dfsf3')
     #x_samples_ddim = pos_map.to(c['c_concat'].device).unsqueeze(0).expand(-1, 3, -1, -1)
@@ -115,7 +116,7 @@ def sample_frame(model: LatentDiffusion, prompt: str, image_sequence: torch.Tens
     #x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
     x_samples_ddim = torch.clamp(x_samples_ddim, min=-1.0, max=1.0)
 
-    return x_samples_ddim.squeeze(0).cpu().numpy()
+    return x_samples_ddim.squeeze(0).cpu().numpy(), x_samples_ddim_feedback.squeeze(0)
 
 # Global variables for model and device
 #model = None
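On the return side, sample_frame keeps the un-decoded latent (x_samples_ddim_feedback) captured just before decode_first_stage and returns it next to the decoded, clamped frame, so the caller can push it straight back into the conditioning history without re-encoding pixels. A small standalone sketch of that return pattern (dummy tensors replace the real sampler and first-stage decoder):

    import torch

    def sample_frame_sketch():
        data_mean, data_std = -0.54, 6.78
        samples = torch.randn(1, 4, 48, 64)                            # stand-in for the DDIM sampler output
        latent = samples * data_std + data_mean                        # undo latent normalization, as in utils.py
        latent_feedback = latent                                       # keep the pre-decode latent for the feedback path
        decoded = latent.mean(dim=1, keepdim=True).repeat(1, 3, 8, 8)  # stand-in for model.decode_first_stage
        decoded = torch.clamp(decoded, min=-1.0, max=1.0)
        return decoded.squeeze(0).cpu().numpy(), latent_feedback.squeeze(0)

    frame, latent = sample_frame_sketch()
    print(frame.shape, latent.shape)                                   # (3, 384, 512) and torch.Size([4, 48, 64])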