Spaces · Runtime error

Commit a677593 · Parent(s): 1679b8f
Update main.py

main.py CHANGED
@@ -7,6 +7,8 @@ from PIL import Image, ImageDraw
 import base64
 import io
 import asyncio
+from utils import initialize_model, sample_frame, device
+import torch
 
 app = FastAPI()
 
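The two new imports pull initialize_model, sample_frame, and device from a local utils module that is not part of this diff. Judging only from how main.py calls them (a config path plus a Hugging Face repo id for initialize_model; a model, a text prompt, and a batched frame tensor for sample_frame), a minimal stand-in could look like the sketch below; everything in it, including the placeholder Identity model, is an assumption rather than the project's real implementation. Note also that the new predict_next_frame body in the second hunk calls sample_frame(model, ...) with a model name that is neither imported nor assigned anywhere in the code shown, which would raise a NameError at call time unless it is defined in a part of main.py outside this diff.

# utils.py -- hypothetical stub, inferred only from main.py's calls; the real
# module presumably loads a frame-prediction model from the given config file
# and Hugging Face checkpoint.
from typing import Optional

import numpy as np
import torch

# Device object that main.py imports and moves its input tensor onto.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

_model: Optional[torch.nn.Module] = None  # populated by initialize_model()


def initialize_model(config_path: str, model_name: str) -> torch.nn.Module:
    """Load the model once at startup (placeholder logic, not the real loader)."""
    global _model
    # Placeholder: a real implementation would parse config_path and pull
    # weights from the `model_name` repo before moving them to `device`.
    _model = torch.nn.Identity().to(device)
    return _model


def sample_frame(model: torch.nn.Module, prompt: str, frames: torch.Tensor) -> np.ndarray:
    """Predict the next frame (placeholder: echoes the last conditioning frame).

    main.py multiplies the result by 255, casts to uint8, and transposes with
    (1, 2, 0), so the contract assumed here is a float (3, H, W) array in [0, 1].
    """
    last = frames[0, -1]                      # (3, H, W), values in [-1, 1]
    return ((last + 1.0) / 2.0).clamp(0.0, 1.0).cpu().numpy()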
@@ -36,15 +38,34 @@ def draw_trace(image: np.ndarray, previous_actions: List[Tuple[str, List[int]]])
 
     return np.array(pil_image)
 
+# Initialize the model at the start of your application
+initialize_model("config_csllm.yaml", "yuntian-deng/computer-model")
+
 def predict_next_frame(previous_frames: List[np.ndarray], previous_actions: List[Tuple[str, List[int]]]) -> np.ndarray:
-    width, height =
+    width, height = 256, 256
+
+    # Prepare the image sequence for the model
+    image_sequence = previous_frames[-7:]  # Take the last 7 frames
+    while len(image_sequence) < 7:
+        image_sequence.insert(0, np.zeros((height, width, 3), dtype=np.uint8))
+
+    # Convert the image sequence to a tensor
+    image_sequence_tensor = torch.from_numpy(np.stack(image_sequence)).permute(0, 3, 1, 2).float() / 127.5 - 1
+    image_sequence_tensor = image_sequence_tensor.unsqueeze(0).to(device)
+
+    # Prepare the prompt based on the previous actions
+    action_descriptions = [f"{action} at ({pos[0]}, {pos[1]})" for action, pos in previous_actions[-7:]]
+    prompt = "A sequence of actions: " + ", ".join(action_descriptions)
+
+    # Generate the next frame
+    new_frame = sample_frame(model, prompt, image_sequence_tensor)
+
+    # Convert the generated frame to the correct format
+    new_frame = (new_frame * 255).astype(np.uint8).transpose(1, 2, 0)
 
-
-
-    new_frame =
-    else:
-        # Use the last frame if it exists and the action is not a mouse move
-        new_frame = previous_frames[-1].copy()
+    # Resize the frame to 256x256 if necessary
+    if new_frame.shape[:2] != (height, width):
+        new_frame = np.array(Image.fromarray(new_frame).resize((width, height)))
 
     # Draw the trace of previous actions
     new_frame_with_trace = draw_trace(new_frame, previous_actions)
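Two of the added steps are easy to sanity-check in isolation: the x / 127.5 - 1 normalization maps uint8 pixels from [0, 255] onto [-1, 1], and the prompt is just the last seven actions joined into one string. The snippet below replays both on fabricated inputs (the frame contents and action names are made up for illustration, not taken from the app):

import numpy as np
import torch

# Normalization from the diff: uint8 [0, 255] -> float [-1, 1].
frame = np.zeros((256, 256, 3), dtype=np.uint8)
frame[..., 0] = 255                                    # dummy pure-red frame
t = torch.from_numpy(frame).permute(2, 0, 1).float() / 127.5 - 1
print(t.min().item(), t.max().item())                  # -1.0 1.0

# Prompt construction from the diff, on a made-up action history.
previous_actions = [("move", [100, 150]), ("click", [120, 160])]
action_descriptions = [f"{action} at ({pos[0]}, {pos[1]})" for action, pos in previous_actions[-7:]]
prompt = "A sequence of actions: " + ", ".join(action_descriptions)
print(prompt)   # A sequence of actions: move at (100, 150), click at (120, 160)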
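For a quick check outside the FastAPI app, the new function can be called directly on a fabricated history, assuming the config and checkpoint are reachable and the model name discussed after the import hunk resolves (importing main also runs initialize_model at module level):

import numpy as np

from main import predict_next_frame   # importing main triggers initialize_model()

# Fabricated inputs: three blank 256x256 frames and two made-up actions; the
# function left-pads the history with zero frames until it has seven.
frames = [np.zeros((256, 256, 3), dtype=np.uint8) for _ in range(3)]
actions = [("move", [100, 150]), ("click", [120, 160])]

next_frame = predict_next_frame(frames, actions)
print(next_frame.shape, next_frame.dtype)   # expected: (256, 256, 3) uint8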