yuntian-deng committed on
Commit
5a39c97
·
1 Parent(s): eea5d6f

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +16 -4
main.py CHANGED
@@ -9,6 +9,7 @@ import io
9
  import asyncio
10
  from utils import initialize_model, sample_frame
11
  import torch
 
12
 
13
  app = FastAPI()
14
 
@@ -64,6 +65,14 @@ def normalize_images(images, target_range=(-1, 1)):
64
  else:
65
  raise ValueError(f"Unsupported target range: {target_range}")
66
 
 
 
 
 
 
 
 
 
67
  def predict_next_frame(previous_frames: List[np.ndarray], previous_actions: List[Tuple[str, List[int]]]) -> np.ndarray:
68
  width, height = 256, 256
69
  initial_images = load_initial_images(width, height)
@@ -107,14 +116,17 @@ def predict_next_frame(previous_frames: List[np.ndarray], previous_actions: List
107
  new_frame = sample_frame(model, prompt, image_sequence_tensor)
108
 
109
  # Convert the generated frame to the correct format
110
- new_frame = (new_frame * 255).astype(np.uint8).transpose(1, 2, 0)
111
 
112
  # Resize the frame to 256x256 if necessary
113
- if new_frame.shape[:2] != (height, width):
114
- new_frame = np.array(Image.fromarray(new_frame).resize((width, height)))
 
 
 
115
 
116
  # Draw the trace of previous actions
117
- new_frame_with_trace = draw_trace(new_frame, previous_actions)
118
 
119
  return new_frame_with_trace
120
 
 
9
  import asyncio
10
  from utils import initialize_model, sample_frame
11
  import torch
12
+ import os
13
 
14
  app = FastAPI()
15
 
 
65
  else:
66
  raise ValueError(f"Unsupported target range: {target_range}")
67
 
68
def denormalize_image(image, source_range=(-1, 1)):
    """Map a normalized image back to 8-bit pixel values.

    Args:
        image: Array-like of pixel values normalized to ``source_range``.
        source_range: Range the input is normalized to, either ``(-1, 1)``
            or ``(0, 1)``. Any two-element sequence (tuple or list) works.

    Returns:
        A ``np.uint8`` array with the same shape as ``image``. Values are
        scaled to [0, 255], clipped to that interval, and the fractional
        part is dropped by the integer cast.

    Raises:
        ValueError: If ``source_range`` is not one of the supported ranges.
    """
    # Compare as a tuple so an equivalent list such as [-1, 1] is accepted
    # instead of being rejected by a tuple-vs-list equality mismatch.
    try:
        bounds = tuple(source_range)
    except TypeError:
        # Non-iterable ranges (e.g. None) fall through to the ValueError.
        bounds = None

    if bounds == (-1, 1):
        # Shift [-1, 1] to [0, 2], then scale to [0, 255].
        scaled = (np.asarray(image) + 1) * 127.5
    elif bounds == (0, 1):
        scaled = np.asarray(image) * 255
    else:
        raise ValueError(f"Unsupported source range: {source_range}")

    # Clip before casting so out-of-range values saturate rather than wrap.
    return scaled.clip(0, 255).astype(np.uint8)
75
+
76
  def predict_next_frame(previous_frames: List[np.ndarray], previous_actions: List[Tuple[str, List[int]]]) -> np.ndarray:
77
  width, height = 256, 256
78
  initial_images = load_initial_images(width, height)
 
116
  new_frame = sample_frame(model, prompt, image_sequence_tensor)
117
 
118
  # Convert the generated frame to the correct format
119
+ #new_frame = (new_frame * 255).astype(np.uint8).transpose(1, 2, 0)
120
 
121
  # Resize the frame to 256x256 if necessary
122
+ #if new_frame.shape[:2] != (height, width):
123
+ # new_frame = np.array(Image.fromarray(new_frame).resize((width, height)))
124
+
125
+ new_frame_denormalized = denormalize_image(new_frame.cpu().numpy(), source_range=(-1, 1))
126
+
127
 
128
  # Draw the trace of previous actions
129
+ new_frame_with_trace = draw_trace(new_frame_denormalized, previous_actions)
130
 
131
  return new_frame_with_trace
132