tanbushi commited on
Commit
c154b96
·
1 Parent(s): 30c165e
Files changed (2) hide show
  1. routers/cohere.py +0 -2
  2. routers/openai_v1_1.py +20 -4
routers/cohere.py CHANGED
@@ -10,8 +10,6 @@ router = APIRouter()
10
 
11
  db_module_filename = f"{get('project_root')}/db/cloudflare.py"
12
 
13
- print('db_module_filename',db_module_filename)
14
-
15
  @router.post("/embed")
16
  async def embed(request:Request, current_user: UserModel = Depends(get_current_user)):
17
  data = await request.json()
 
10
 
11
  db_module_filename = f"{get('project_root')}/db/cloudflare.py"
12
 
 
 
13
  @router.post("/embed")
14
  async def embed(request:Request, current_user: UserModel = Depends(get_current_user)):
15
  data = await request.json()
routers/openai_v1_1.py CHANGED
@@ -1,6 +1,8 @@
1
  from fastapi import APIRouter, Depends, Request
2
  import requests
3
  from datetime import datetime
 
 
4
 
5
  from global_state import get
6
  from db.tbs_db import TbsDb
@@ -9,29 +11,43 @@ from db_model.user import UserModel
9
 
10
  router = APIRouter()
11
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  db_module_filename = f"{get('project_root')}/db/cloudflare.py"
13
 
14
  @router.post("/chat/completions")
15
  async def chat_completions(request:Request, current_user: UserModel = Depends(get_current_user)):
16
  data = await request.json()
17
- model = data.get('model', '')
18
  if (model=='')or(model is None):
19
  model = await get_default_model()
20
  api_key_info = await get_api_key(model)
21
  api_key = api_key_info.get('api_key', '')
22
  base_url = api_key_info.get('base_url', '')
 
 
 
23
 
24
  headers = {
25
  'Content-Type': 'application/json',
26
  'Authorization': f'Bearer {api_key}',
27
  'User-Agent': 'PostmanRuntime/7.43.0'
28
  }
29
-
30
- print(f'base_url: {base_url}')
31
 
32
  try:
33
  response = requests.post(url=f"{base_url}/chat/completions", headers=headers, json=data)
34
- # print(response.json())
35
  return response.json()
36
  except Exception as e:
37
  print(e)
 
1
  from fastapi import APIRouter, Depends, Request
2
  import requests
3
  from datetime import datetime
4
+ from pydantic import BaseModel, Field
5
+ from typing import List, Dict
6
 
7
  from global_state import get
8
  from db.tbs_db import TbsDb
 
11
 
12
  router = APIRouter()
13
 
14
+ class Message(BaseModel):
15
+ role: str
16
+ content: str
17
+
18
+ class GeminiModel(BaseModel):
19
+ model: str
20
+ temperature: float = Field(..., gt=0)
21
+ top_p: float = Field(..., gt=0)
22
+ # frequency_penalty: float = Field(..., ge=0)
23
+ # presence_penalty: float = Field(..., ge=0)
24
+ n: int = Field(..., ge=0)
25
+ stream: bool
26
+ messages: List[Message]
27
+
28
  db_module_filename = f"{get('project_root')}/db/cloudflare.py"
29
 
30
  @router.post("/chat/completions")
31
  async def chat_completions(request:Request, current_user: UserModel = Depends(get_current_user)):
32
  data = await request.json()
33
+ model = data.get('model', '')
34
  if (model=='')or(model is None):
35
  model = await get_default_model()
36
  api_key_info = await get_api_key(model)
37
  api_key = api_key_info.get('api_key', '')
38
  base_url = api_key_info.get('base_url', '')
39
+ group_name=api_key_info.get('group_name', '')
40
+ if group_name=='gemini':
41
+ data = GeminiModel(**data).model_dump()
42
 
43
  headers = {
44
  'Content-Type': 'application/json',
45
  'Authorization': f'Bearer {api_key}',
46
  'User-Agent': 'PostmanRuntime/7.43.0'
47
  }
 
 
48
 
49
  try:
50
  response = requests.post(url=f"{base_url}/chat/completions", headers=headers, json=data)
 
51
  return response.json()
52
  except Exception as e:
53
  print(e)