yuntian-deng committed
Commit a677593 · 1 Parent(s): 1679b8f

Update main.py
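Replaces the placeholder frame generator in predict_next_frame (random 800×600 images) with model-based prediction: load yuntian-deng/computer-model via utils.initialize_model, stack the last 7 frames into a normalized tensor, build a text prompt from the recent actions, and sample the next 256×256 frame with sample_frame.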

Files changed (1)
  1. main.py +28 -7
main.py CHANGED
@@ -7,6 +7,8 @@ from PIL import Image, ImageDraw
 import base64
 import io
 import asyncio
+from utils import initialize_model, sample_frame, device
+import torch
 
 app = FastAPI()
 
@@ -36,15 +38,34 @@ def draw_trace(image: np.ndarray, previous_actions: List[Tuple[str, List[int]]])
 
     return np.array(pil_image)
 
+# Initialize the model once at the start of the application
+model = initialize_model("config_csllm.yaml", "yuntian-deng/computer-model")
+
 def predict_next_frame(previous_frames: List[np.ndarray], previous_actions: List[Tuple[str, List[int]]]) -> np.ndarray:
-    width, height = 800, 600
+    width, height = 256, 256
+
+    # Prepare the image sequence for the model
+    image_sequence = previous_frames[-7:]  # Take the last 7 frames
+    while len(image_sequence) < 7:
+        image_sequence.insert(0, np.zeros((height, width, 3), dtype=np.uint8))
+
+    # Convert the image sequence to a tensor normalized to [-1, 1]
+    image_sequence_tensor = torch.from_numpy(np.stack(image_sequence)).permute(0, 3, 1, 2).float() / 127.5 - 1
+    image_sequence_tensor = image_sequence_tensor.unsqueeze(0).to(device)
+
+    # Prepare the prompt based on the previous actions
+    action_descriptions = [f"{action} at ({pos[0]}, {pos[1]})" for action, pos in previous_actions[-7:]]
+    prompt = "A sequence of actions: " + ", ".join(action_descriptions)
+
+    # Generate the next frame
+    new_frame = sample_frame(model, prompt, image_sequence_tensor)
+
+    # Convert the generated frame from CHW floats to an HWC uint8 image
+    new_frame = (new_frame * 255).astype(np.uint8).transpose(1, 2, 0)
 
-    if not previous_frames or previous_actions[-1][0] == "move":
-        # Generate a new random image when there's no previous frame or the mouse moves
-        new_frame = generate_random_image(width, height)
-    else:
-        # Use the last frame if it exists and the action is not a mouse move
-        new_frame = previous_frames[-1].copy()
+    # Resize the frame to 256x256 if necessary
+    if new_frame.shape[:2] != (height, width):
+        new_frame = np.array(Image.fromarray(new_frame).resize((width, height)))
 
     # Draw the trace of previous actions
    new_frame_with_trace = draw_trace(new_frame, previous_actions)
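
The new code relies on initialize_model, sample_frame, and device from a utils module that this commit does not touch. Below is a minimal sketch of the contract main.py appears to assume; the function bodies, the torch.nn.Identity stand-in, and the [0, 1] output range (inferred from the "* 255" conversion above) are all illustrations, not the repository's actual implementation.

# utils.py -- hypothetical sketch; the real module is not part of this commit
import numpy as np
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def initialize_model(config_path: str, model_name: str):
    """Build the network described by config_path and load the weights
    published under model_name (details assumed, not shown in the commit)."""
    model = torch.nn.Identity()  # placeholder for the real network
    return model.to(device).eval()

@torch.no_grad()
def sample_frame(model, prompt: str, frames: torch.Tensor) -> np.ndarray:
    """Predict the next frame from a (1, T, C, H, W) tensor in [-1, 1],
    returning a CHW float array in [0, 1] as main.py expects."""
    # Placeholder sampler: echo the last input frame, rescaled from
    # [-1, 1] back to [0, 1].
    last = frames[0, -1]                   # (C, H, W) in [-1, 1]
    return ((last + 1) / 2).cpu().numpy()  # (C, H, W) in [0, 1]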
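
For a quick smoke test, the updated function can be driven with synthetic inputs. This sketch assumes frames are 256×256 RGB HWC uint8 arrays and actions are (name, [x, y]) tuples, as the type hints suggest; note that the [-1, 1] normalization above maps a uint8 pixel p to p / 127.5 - 1, so 0 becomes -1.0 and 255 becomes 1.0.

import numpy as np
from main import predict_next_frame  # importing main also initializes the model

# Three 256x256 RGB frames (HWC uint8); with fewer than 7 frames available,
# predict_next_frame left-pads the sequence with black frames before stacking.
frames = [np.random.randint(0, 256, (256, 256, 3), dtype=np.uint8) for _ in range(3)]

# Actions as (name, [x, y]) tuples, matching List[Tuple[str, List[int]]]
actions = [("move", [100, 120]), ("click", [100, 120]), ("move", [180, 40])]

next_frame = predict_next_frame(frames, actions)
print(next_frame.shape, next_frame.dtype)  # expected: (256, 256, 3) uint8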