Commit a677593
1 Parent(s): 1679b8f
Update main.py
main.py
CHANGED
@@ -7,6 +7,8 @@ from PIL import Image, ImageDraw
 import base64
 import io
 import asyncio
+from utils import initialize_model, sample_frame, device
+import torch
 
 app = FastAPI()
 
@@ -36,15 +38,34 @@ def draw_trace(image: np.ndarray, previous_actions: List[Tuple[str, List[int]]])
 
     return np.array(pil_image)
 
+# Initialize the model at the start of your application
+initialize_model("config_csllm.yaml", "yuntian-deng/computer-model")
+
 def predict_next_frame(previous_frames: List[np.ndarray], previous_actions: List[Tuple[str, List[int]]]) -> np.ndarray:
-    width, height =
-
-
-        new_frame =
-    else:
-        # Use the last frame if it exists and the action is not a mouse move
-        new_frame = previous_frames[-1].copy()
+    width, height = 256, 256
+
+    # Prepare the image sequence for the model
+    image_sequence = previous_frames[-7:]  # Take the last 7 frames
+    while len(image_sequence) < 7:
+        image_sequence.insert(0, np.zeros((height, width, 3), dtype=np.uint8))
+
+    # Convert the image sequence to a tensor
+    image_sequence_tensor = torch.from_numpy(np.stack(image_sequence)).permute(0, 3, 1, 2).float() / 127.5 - 1
+    image_sequence_tensor = image_sequence_tensor.unsqueeze(0).to(device)
+
+    # Prepare the prompt based on the previous actions
+    action_descriptions = [f"{action} at ({pos[0]}, {pos[1]})" for action, pos in previous_actions[-7:]]
+    prompt = "A sequence of actions: " + ", ".join(action_descriptions)
+
+    # Generate the next frame
+    new_frame = sample_frame(model, prompt, image_sequence_tensor)
+
+    # Convert the generated frame to the correct format
+    new_frame = (new_frame * 255).astype(np.uint8).transpose(1, 2, 0)
+
+    # Resize the frame to 256x256 if necessary
+    if new_frame.shape[:2] != (height, width):
+        new_frame = np.array(Image.fromarray(new_frame).resize((width, height)))
 
     # Draw the trace of previous actions
     new_frame_with_trace = draw_trace(new_frame, previous_actions)
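
For reference, the frame preprocessing that the new predict_next_frame performs can be exercised in isolation. The sketch below is not part of the commit: it reproduces only the padding, normalization, and reshaping steps on dummy frames, and omits the actual model call because the signatures of initialize_model and sample_frame from utils are not shown in this diff.

# Minimal sketch (not from the commit): reproduce the frame preprocessing
# introduced in predict_next_frame, using random dummy frames.
import numpy as np
import torch

height, width = 256, 256
previous_frames = [np.random.randint(0, 256, (height, width, 3), dtype=np.uint8) for _ in range(3)]

# Pad the sequence to 7 frames with black frames, as in the committed code
image_sequence = previous_frames[-7:]
while len(image_sequence) < 7:
    image_sequence.insert(0, np.zeros((height, width, 3), dtype=np.uint8))

# (N, H, W, C) -> (N, C, H, W), scaled from [0, 255] to [-1, 1]
image_sequence_tensor = torch.from_numpy(np.stack(image_sequence)).permute(0, 3, 1, 2).float() / 127.5 - 1
image_sequence_tensor = image_sequence_tensor.unsqueeze(0)  # add batch dimension -> (1, 7, 3, 256, 256)
print(image_sequence_tensor.shape)  # torch.Size([1, 7, 3, 256, 256])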