yuntian-deng commited on
Commit
3208a74
·
1 Parent(s): aca606e

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +31 -4
main.py CHANGED
@@ -42,16 +42,43 @@ def draw_trace(image: np.ndarray, previous_actions: List[Tuple[str, List[int]]])
42
  model = initialize_model("config_csllm.yaml", "yuntian-deng/computer-model")
43
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
44
  model = model.to(device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
  def predict_next_frame(previous_frames: List[np.ndarray], previous_actions: List[Tuple[str, List[int]]]) -> np.ndarray:
46
  width, height = 256, 256
47
-
 
48
  # Prepare the image sequence for the model
49
  image_sequence = previous_frames[-7:] # Take the last 7 frames
50
  while len(image_sequence) < 7:
51
- image_sequence.insert(0, np.zeros((height, width, 3), dtype=np.uint8))
52
-
 
 
53
  # Convert the image sequence to a tensor and concatenate in the channel dimension
54
- image_sequence_tensor = torch.from_numpy(np.stack(image_sequence)).float() / 127.5 - 1
 
 
55
  image_sequence_tensor = image_sequence_tensor.to(device)
56
 
57
 
 
42
  model = initialize_model("config_csllm.yaml", "yuntian-deng/computer-model")
43
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
44
  model = model.to(device)
45
+
46
+ def load_initial_images(width, height):
47
+ initial_images = []
48
+ for i in range(7):
49
+ image_path = f"image_{i}.png"
50
+ if os.path.exists(image_path):
51
+ img = Image.open(image_path).resize((width, height))
52
+ initial_images.append(np.array(img))
53
+ else:
54
+ print(f"Warning: {image_path} not found. Using blank image instead.")
55
+ initial_images.append(np.zeros((height, width, 3), dtype=np.uint8))
56
+ return initial_images
57
+
58
+ def normalize_images(images, target_range=(-1, 1)):
59
+ images = np.stack(images).astype(np.float32)
60
+ if target_range == (-1, 1):
61
+ return images / 127.5 - 1
62
+ elif target_range == (0, 1):
63
+ return images / 255.0
64
+ else:
65
+ raise ValueError(f"Unsupported target range: {target_range}")
66
+
67
  def predict_next_frame(previous_frames: List[np.ndarray], previous_actions: List[Tuple[str, List[int]]]) -> np.ndarray:
68
  width, height = 256, 256
69
+ initial_images = load_initial_images(width, height)
70
+
71
  # Prepare the image sequence for the model
72
  image_sequence = previous_frames[-7:] # Take the last 7 frames
73
  while len(image_sequence) < 7:
74
+ #image_sequence.insert(0, np.zeros((height, width, 3), dtype=np.uint8))
75
+ image_sequence.insert(0, initial_images[len(image_sequence)])
76
+
77
+
78
  # Convert the image sequence to a tensor and concatenate in the channel dimension
79
+ image_sequence_tensor = torch.from_numpy(normalize_images(image_sequence, target_range=(-1, 1)))
80
+
81
+ #image_sequence_tensor = torch.from_numpy(np.stack(image_sequence)).float() / 127.5 - 1
82
  image_sequence_tensor = image_sequence_tensor.to(device)
83
 
84