yuntian-deng commited on
Commit
e8de28c
·
1 Parent(s): c7afaf1

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +3 -2
main.py CHANGED
@@ -7,7 +7,7 @@ from PIL import Image, ImageDraw
7
  import base64
8
  import io
9
  import asyncio
10
- from utils import initialize_model, sample_frame, device
11
  import torch
12
 
13
  app = FastAPI()
@@ -40,7 +40,8 @@ def draw_trace(image: np.ndarray, previous_actions: List[Tuple[str, List[int]]])
40
 
41
  # Initialize the model at the start of your application
42
  model = initialize_model("config_csllm.yaml", "yuntian-deng/computer-model")
43
-
 
44
  def predict_next_frame(previous_frames: List[np.ndarray], previous_actions: List[Tuple[str, List[int]]]) -> np.ndarray:
45
  width, height = 256, 256
46
 
 
7
  import base64
8
  import io
9
  import asyncio
10
+ from utils import initialize_model, sample_frame
11
  import torch
12
 
13
  app = FastAPI()
 
40
 
41
  # Initialize the model at the start of your application
42
  model = initialize_model("config_csllm.yaml", "yuntian-deng/computer-model")
43
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
44
+ model = model.to(device)
45
  def predict_next_frame(previous_frames: List[np.ndarray], previous_actions: List[Tuple[str, List[int]]]) -> np.ndarray:
46
  width, height = 256, 256
47