Mbonea commited on
Commit
6acbc8e
·
1 Parent(s): 925b038

proxy time

Browse files
Files changed (1) hide show
  1. App/Chat/PoeChatrouter.py +39 -2
App/Chat/PoeChatrouter.py CHANGED
@@ -1,11 +1,48 @@
1
- from fastapi import APIRouter
2
  from .utils.PoeBot import SendMessage, GenerateImage
3
  from .Schemas import BotRequest
4
-
 
5
 
6
  chat_router = APIRouter(tags=["Chat"])
7
 
8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  @chat_router.post("/chat")
10
  async def chat(req: BotRequest):
11
  return await SendMessage(req)
 
1
+ from fastapi import APIRouter, HTTPException
2
  from .utils.PoeBot import SendMessage, GenerateImage
3
  from .Schemas import BotRequest
4
+ from aiohttp import ClientSession
5
+ from pydantic import BaseModel
6
 
7
  chat_router = APIRouter(tags=["Chat"])
8
 
9
 
10
+ class InputData(BaseModel):
11
+ input: dict
12
+ version: str = "727e49a643e999d602a896c774a0658ffefea21465756a6ce24b7ea4165eba6a"
13
+
14
+
15
+ async def fetch_predictions(data):
16
+ async with ClientSession() as session:
17
+ async with session.post(
18
+ "https://replicate.com/api/predictions", json=data
19
+ ) as response:
20
+ return await response.json(), response.status
21
+
22
+
23
+ async def fetch_result(id):
24
+ url = f"https://replicate.com/api/predictions/{id}"
25
+ async with ClientSession() as session:
26
+ async with session.get(url) as response:
27
+ return await response.json(), response.status
28
+
29
+
30
+ @chat_router.post("/predictions")
31
+ async def get_predictions(input_data: InputData):
32
+ data = {
33
+ "input": input_data.input,
34
+ "is_training": False,
35
+ "create_model": "0",
36
+ "stream": False,
37
+ "version": input_data.version,
38
+ }
39
+ try:
40
+ predictions, status_code = await fetch_predictions(data)
41
+ return predictions, status_code
42
+ except Exception as e:
43
+ raise HTTPException(status_code=500, detail=f"Internal Server Error: {str(e)}")
44
+
45
+
46
  @chat_router.post("/chat")
47
  async def chat(req: BotRequest):
48
  return await SendMessage(req)