Spaces:
Runtime error
Runtime error
Commit
·
3208a74
1
Parent(s):
aca606e
Update main.py
Browse files
main.py
CHANGED
@@ -42,16 +42,43 @@ def draw_trace(image: np.ndarray, previous_actions: List[Tuple[str, List[int]]])
|
|
42 |
model = initialize_model("config_csllm.yaml", "yuntian-deng/computer-model")
|
43 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
44 |
model = model.to(device)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
45 |
def predict_next_frame(previous_frames: List[np.ndarray], previous_actions: List[Tuple[str, List[int]]]) -> np.ndarray:
|
46 |
width, height = 256, 256
|
47 |
-
|
|
|
48 |
# Prepare the image sequence for the model
|
49 |
image_sequence = previous_frames[-7:] # Take the last 7 frames
|
50 |
while len(image_sequence) < 7:
|
51 |
-
image_sequence.insert(0, np.zeros((height, width, 3), dtype=np.uint8))
|
52 |
-
|
|
|
|
|
53 |
# Convert the image sequence to a tensor and concatenate in the channel dimension
|
54 |
-
image_sequence_tensor = torch.from_numpy(
|
|
|
|
|
55 |
image_sequence_tensor = image_sequence_tensor.to(device)
|
56 |
|
57 |
|
|
|
42 |
model = initialize_model("config_csllm.yaml", "yuntian-deng/computer-model")
|
43 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
44 |
model = model.to(device)
|
45 |
+
|
46 |
+
def load_initial_images(width, height, count=7, path_template="image_{}.png"):
    """Load the seed frames used to pad the model's frame-history window.

    Reads ``count`` images from the working directory (filenames produced by
    formatting ``path_template`` with the frame index), resizing each to
    ``(width, height)``.  Any missing file is replaced by an all-black frame,
    so the returned list always has exactly ``count`` entries.

    Args:
        width: Target frame width in pixels.
        height: Target frame height in pixels.
        count: Number of seed frames to load (default 7 — presumably the
            model's context length; confirm against the caller).
        path_template: Filename pattern with one ``{}`` slot for the index.
            The default reproduces the original hard-coded ``image_{i}.png``.

    Returns:
        List of ``count`` uint8 arrays of shape ``(height, width, 3)``.
    """
    initial_images = []
    for i in range(count):
        image_path = path_template.format(i)
        if os.path.exists(image_path):
            # Force RGB so the array always has exactly 3 channels; an RGBA
            # or palette/grayscale PNG would otherwise produce a 4-channel
            # or 2-D array that mismatches the blank fallback below and
            # breaks np.stack downstream.
            img = Image.open(image_path).convert("RGB").resize((width, height))
            initial_images.append(np.array(img))
        else:
            print(f"Warning: {image_path} not found. Using blank image instead.")
            initial_images.append(np.zeros((height, width, 3), dtype=np.uint8))
    return initial_images
|
57 |
+
|
58 |
+
def normalize_images(images, target_range=(-1, 1)):
    """Stack a sequence of images into one float32 array scaled to a range.

    Pixel values are assumed to be in 0..255.  ``(-1, 1)`` maps them to
    [-1, 1] (the model's expected input range); ``(0, 1)`` maps them to
    [0, 1].  Any other range raises ``ValueError``.
    """
    stacked = np.stack(images).astype(np.float32)
    if target_range == (-1, 1):
        scaled = stacked / 127.5 - 1
    elif target_range == (0, 1):
        scaled = stacked / 255.0
    else:
        raise ValueError(f"Unsupported target range: {target_range}")
    return scaled
|
66 |
+
|
67 |
def predict_next_frame(previous_frames: List[np.ndarray], previous_actions: List[Tuple[str, List[int]]]) -> np.ndarray:
|
68 |
width, height = 256, 256
|
69 |
+
initial_images = load_initial_images(width, height)
|
70 |
+
|
71 |
# Prepare the image sequence for the model
|
72 |
image_sequence = previous_frames[-7:] # Take the last 7 frames
|
73 |
while len(image_sequence) < 7:
|
74 |
+
#image_sequence.insert(0, np.zeros((height, width, 3), dtype=np.uint8))
|
75 |
+
image_sequence.insert(0, initial_images[len(image_sequence)])
|
76 |
+
|
77 |
+
|
78 |
# Convert the image sequence to a tensor and concatenate in the channel dimension
|
79 |
+
image_sequence_tensor = torch.from_numpy(normalize_images(image_sequence, target_range=(-1, 1)))
|
80 |
+
|
81 |
+
#image_sequence_tensor = torch.from_numpy(np.stack(image_sequence)).float() / 127.5 - 1
|
82 |
image_sequence_tensor = image_sequence_tensor.to(device)
|
83 |
|
84 |
|