Spaces:
Runtime error
Runtime error
da03
commited on
Commit
·
09eba18
1
Parent(s):
b403864
main.py
CHANGED
@@ -13,7 +13,7 @@ import os
|
|
13 |
import time
|
14 |
|
15 |
DEBUG = False
|
16 |
-
DEBUG_TEACHER_FORCING =
|
17 |
app = FastAPI()
|
18 |
|
19 |
# Mount the static directory to serve HTML, JavaScript, and CSS files
|
@@ -407,6 +407,8 @@ async def websocket_endpoint(websocket: WebSocket):
|
|
407 |
mouse_position = (x, y)
|
408 |
|
409 |
previous_actions.append((action_type, mouse_position))
|
|
|
|
|
410 |
try:
|
411 |
while True:
|
412 |
try:
|
@@ -437,7 +439,8 @@ async def websocket_endpoint(websocket: WebSocket):
|
|
437 |
if False:
|
438 |
previous_actions.append((action_type, mouse_position))
|
439 |
#previous_actions = [(action_type, mouse_position)]
|
440 |
-
|
|
|
441 |
# Log the start time
|
442 |
start_time = time.time()
|
443 |
|
|
|
13 |
import time
|
14 |
|
15 |
DEBUG = False
|
16 |
+
DEBUG_TEACHER_FORCING = False
|
17 |
app = FastAPI()
|
18 |
|
19 |
# Mount the static directory to serve HTML, JavaScript, and CSS files
|
|
|
407 |
mouse_position = (x, y)
|
408 |
|
409 |
previous_actions.append((action_type, mouse_position))
|
410 |
+
if not DEBUG_TEACHER_FORCING:
|
411 |
+
previous_actions = []
|
412 |
try:
|
413 |
while True:
|
414 |
try:
|
|
|
439 |
if False:
|
440 |
previous_actions.append((action_type, mouse_position))
|
441 |
#previous_actions = [(action_type, mouse_position)]
|
442 |
+
if not DEBUG_TEACHER_FORCING:
|
443 |
+
previous_actions.append((action_type, mouse_position))
|
444 |
# Log the start time
|
445 |
start_time = time.time()
|
446 |
|
utils.py
CHANGED
@@ -57,6 +57,12 @@ def sample_frame(model: LatentDiffusion, prompt: str, image_sequence: torch.Tens
|
|
57 |
print ('finished sleeping')
|
58 |
DDPM = False
|
59 |
DDPM = True
|
|
|
|
|
|
|
|
|
|
|
|
|
60 |
if DDPM:
|
61 |
samples_ddim = model.p_sample_loop(cond=c, shape=[1, 4, 48, 64], return_intermediates=False, verbose=True)
|
62 |
else:
|
@@ -68,8 +74,13 @@ def sample_frame(model: LatentDiffusion, prompt: str, image_sequence: torch.Tens
|
|
68 |
# unconditional_guidance_scale=5.0,
|
69 |
# unconditional_conditioning=uc,
|
70 |
# eta=0)
|
71 |
-
|
72 |
-
|
|
|
|
|
|
|
|
|
|
|
73 |
#x_samples_ddim = pos_map.to(c['c_concat'].device).unsqueeze(0).expand(-1, 3, -1, -1)
|
74 |
#x_samples_ddim = model.decode_first_stage(x_samples_ddim)
|
75 |
#x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
|
|
|
57 |
print ('finished sleeping')
|
58 |
DDPM = False
|
59 |
DDPM = True
|
60 |
+
|
61 |
+
DEBUG = True
|
62 |
+
if DEBUG:
|
63 |
+
c['c_concat'] = c['c_concat']*0
|
64 |
+
print ('utils prompt', prompt)
|
65 |
+
|
66 |
if DDPM:
|
67 |
samples_ddim = model.p_sample_loop(cond=c, shape=[1, 4, 48, 64], return_intermediates=False, verbose=True)
|
68 |
else:
|
|
|
74 |
# unconditional_guidance_scale=5.0,
|
75 |
# unconditional_conditioning=uc,
|
76 |
# eta=0)
|
77 |
+
if DEBUG:
|
78 |
+
print ('samples_ddim.shape', samples_ddim.shape)
|
79 |
+
x_samples_ddim = samples_ddim[:, :3]
|
80 |
+
# upsample to 512 x 368
|
81 |
+
x_samples_ddim = torch.nn.functional.interpolate(x_samples_ddim, size=(512, 368), mode='bilinear')
|
82 |
+
else:
|
83 |
+
x_samples_ddim = model.decode_first_stage(samples_ddim)
|
84 |
#x_samples_ddim = pos_map.to(c['c_concat'].device).unsqueeze(0).expand(-1, 3, -1, -1)
|
85 |
#x_samples_ddim = model.decode_first_stage(x_samples_ddim)
|
86 |
#x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
|