yuntian-deng committed on
Commit 8435838 · 1 Parent(s): 19b663b

Update main.py

Files changed (1)
  1. main.py +40 -13
main.py CHANGED
@@ -1,20 +1,47 @@
- from fastapi import FastAPI
+ from fastapi import FastAPI, WebSocket
+ from fastapi.responses import HTMLResponse
  from fastapi.staticfiles import StaticFiles
- from fastapi.responses import FileResponse
-
- from transformers import pipeline
+ from typing import List
+ import numpy as np
 
  app = FastAPI()
 
- pipe_flan = pipeline("text2text-generation", model="google/flan-t5-small")
+ # Mount the static directory to serve HTML, JavaScript, and CSS files
+ app.mount("/static", StaticFiles(directory="static"), name="static")
+
+ # Serve the index.html file at the root URL
+ @app.get("/")
+ async def get():
+     return HTMLResponse(open("static/index.html").read())
 
- @app.get("/infer_t5")
- def t5(input):
-     output = pipe_flan(input)
-     return {"output": output[0]["generated_text"]}
+ # Simulate your diffusion model (placeholder)
+ def predict_next_frame(previous_frames: List[np.ndarray], previous_actions: List[str]) -> np.ndarray:
+     return np.zeros((800, 600, 3), dtype=np.uint8)
 
- app.mount("/", StaticFiles(directory="static", html=True), name="static")
+ # WebSocket endpoint for continuous user interaction
+ @app.websocket("/ws")
+ async def websocket_endpoint(websocket: WebSocket):
+     await websocket.accept()
+     previous_frames = []
+     previous_actions = []
+
+     try:
+         while True:
+             # Receive user input (mouse movement, click, etc.)
+             data = await websocket.receive_json()
+             action_type = data.get("action_type")
+             mouse_position = data.get("mouse_position")
+
+             # Store the actions
+             previous_actions.append((action_type, mouse_position))
+
+             # Predict the next frame based on the previous frames and actions
+             next_frame = predict_next_frame(previous_frames, previous_actions)
+             previous_frames.append(next_frame)
+
+             # Send the generated frame back to the client (encoded as base64 or similar)
+             await websocket.send_text("Next frame generated")  # Replace with real image sending logic
 
- @app.get("/")
- def index() -> FileResponse:
-     return FileResponse(path="/app/static/index.html", media_type="text/html")
+     except Exception as e:
+         print(f"Error: {e}")
+         await websocket.close()
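
The closing send_text call in the new WebSocket loop is only a placeholder. A minimal sketch of how the generated frame could actually be shipped to the client, assuming Pillow is available and using a hypothetical encode_frame helper in place of the placeholder line, might look like this:

import base64
import io

import numpy as np
from PIL import Image  # assumption: Pillow is installed alongside numpy

def encode_frame(frame: np.ndarray) -> str:
    # Convert the HxWx3 uint8 array to PNG bytes, then to a base64 string.
    image = Image.fromarray(frame)
    buffer = io.BytesIO()
    image.save(buffer, format="PNG")
    return base64.b64encode(buffer.getvalue()).decode("ascii")

# Inside the while loop, in place of send_text("Next frame generated"):
#     await websocket.send_json({"frame": encode_frame(next_frame)})

The client could then decode the base64 string and draw it, for example by using it as a data URL for an image element.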
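
To exercise the /ws endpoint outside the browser, a rough client sketch, assuming the third-party websockets package and a server running locally on port 8000 (both assumptions), could be:

import asyncio
import json

import websockets

async def main():
    async with websockets.connect("ws://localhost:8000/ws") as ws:
        # Send one simulated action in the JSON shape the endpoint expects.
        await ws.send(json.dumps({"action_type": "move", "mouse_position": [400, 300]}))
        # The current server replies with plain text; a real implementation
        # would return the encoded frame instead.
        print(await ws.recv())

asyncio.run(main())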