tanbushi commited on
Commit
07df554
·
1 Parent(s): 3a942e3

add cohere api

Browse files
Files changed (3) hide show
  1. app.py +2 -0
  2. routers/cohere.py +105 -0
  3. routers/openai_v1_1.py +2 -0
app.py CHANGED
@@ -17,12 +17,14 @@ set('project_root', parent_dir)
17
  # from routers.webtools_v1 import router as webtools_router
18
  from routers.users_v1 import router as users_router
19
  from routers.openai_v1_1 import router as openai_router
 
20
 
21
  app = FastAPI()
22
 
23
  # app.include_router(webtools_router, prefix="/airs/v1", tags=["webtools"])
24
  app.include_router(users_router, prefix="/airs/v1", tags=["users"])
25
  app.include_router(openai_router, prefix="/airs/v1", tags=["openai"])
 
26
 
27
  @app.get("/")
28
  def greet_json():
 
17
  # from routers.webtools_v1 import router as webtools_router
18
  from routers.users_v1 import router as users_router
19
  from routers.openai_v1_1 import router as openai_router
20
+ from routers.cohere import router as cohere_router
21
 
22
  app = FastAPI()
23
 
24
  # app.include_router(webtools_router, prefix="/airs/v1", tags=["webtools"])
25
  app.include_router(users_router, prefix="/airs/v1", tags=["users"])
26
  app.include_router(openai_router, prefix="/airs/v1", tags=["openai"])
27
+ app.include_router(cohere_router, prefix="/airs/cohere/v2", tags=["cohere"])
28
 
29
  @app.get("/")
30
  def greet_json():
routers/cohere.py ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import APIRouter, Request, Depends
2
+ import requests
3
+
4
+ from global_state import get
5
+ from db_model.user import UserModel
6
+ from auth import get_current_user
7
+ from db.tbs_db import TbsDb
8
+
9
+ router = APIRouter()
10
+
11
+ db_module_filename = f"{get('project_root')}/db/cloudflare.py"
12
+
13
+ print('db_module_filename',db_module_filename)
14
+
15
+ @router.post("/embed")
16
+ async def embed(request:Request, current_user: UserModel = Depends(get_current_user)):
17
+ data = await request.json()
18
+ print(data)
19
+ model = data.get('model', '')
20
+ print(model)
21
+ if (model=='')or(model is None):
22
+ return {'error': 'model is empty'}
23
+ api_key_info = await get_api_key(model)
24
+ api_key = api_key_info.get('api_key', '')
25
+ base_url = api_key_info.get('base_url', '')
26
+ print(api_key_info)
27
+
28
+ headers = {
29
+ 'Content-Type': 'application/json',
30
+ 'Authorization': f'Bearer {api_key}',
31
+ 'User-Agent': 'PostmanRuntime/7.43.0'
32
+ }
33
+
34
+ print(f'base_url: {base_url}')
35
+
36
+ try:
37
+ response = requests.post(url=f"{base_url}/embed", headers=headers, json=data)
38
+ print(response.json())
39
+ return response.json()
40
+ except Exception as e:
41
+ print(e)
42
+ return {'error': e}
43
+
44
+
45
+ @router.post("/rerank")
46
+ async def rerank(request:Request, current_user: UserModel = Depends(get_current_user)):
47
+ data = await request.json()
48
+ print(data)
49
+ model = data.get('model', '')
50
+ print(model)
51
+ if (model=='')or(model is None):
52
+ return {'error': 'model is empty'}
53
+ api_key_info = await get_api_key(model)
54
+ api_key = api_key_info.get('api_key', '')
55
+ base_url = api_key_info.get('base_url', '')
56
+ print(api_key_info)
57
+
58
+ headers = {
59
+ 'Content-Type': 'application/json',
60
+ 'Authorization': f'Bearer {api_key}',
61
+ 'User-Agent': 'PostmanRuntime/7.43.0'
62
+ }
63
+
64
+ print(f'base_url: {base_url}')
65
+
66
+ try:
67
+ response = requests.post(url=f"{base_url}/rerank", headers=headers, json=data)
68
+ print(response.json())
69
+ return response.json()
70
+ except Exception as e:
71
+ print(e)
72
+ return {'error': e}
73
+
74
+ # # 从数据库获取默认模型
75
+ # async def get_default_model():
76
+ # query = f"SELECT * FROM api_names order by default_order limit 1"
77
+ # response = TbsDb(db_module_filename, "Cloudflare").get_item(query)
78
+ # try:
79
+ # result = response['result'][0]['results'][0]['api_name']
80
+ # except:
81
+ # result = ''
82
+ # return result
83
+
84
+ async def get_api_key(model):
85
+ query = f"""
86
+ SELECT an.api_name, ak.api_key, an.base_url, ag.group_name
87
+ FROM api_keys ak
88
+ JOIN api_groups ag ON ak.api_group_id = ag.id
89
+ JOIN api_names an ON an.api_group_id = ag.id
90
+ WHERE an.api_name='{model}' and disabled=0
91
+ ORDER BY ak.last_call_at
92
+ limit 1
93
+ """
94
+ # WHERE ak.category='LLM' and an.api_name='{model}' and disabled=0
95
+
96
+ response = TbsDb(db_module_filename, "Cloudflare").get_item(query)
97
+ try:
98
+ result = response['result'][0]['results'][0]
99
+ api_key = result['api_key']
100
+ except:
101
+ api_key = ''
102
+
103
+ query = f"update api_keys set last_call_at=datetime('now') where api_key='{api_key}'"
104
+ TbsDb(db_module_filename, "Cloudflare").execute_query(query)
105
+ return result
routers/openai_v1_1.py CHANGED
@@ -26,6 +26,8 @@ async def chat_completions(request:Request, current_user: UserModel = Depends(ge
26
  'Authorization': f'Bearer {api_key}',
27
  'User-Agent': 'PostmanRuntime/7.43.0'
28
  }
 
 
29
 
30
  try:
31
  response = requests.post(url=f"{base_url}/chat/completions", headers=headers, json=data)
 
26
  'Authorization': f'Bearer {api_key}',
27
  'User-Agent': 'PostmanRuntime/7.43.0'
28
  }
29
+
30
+ print(f'base_url: {base_url}')
31
 
32
  try:
33
  response = requests.post(url=f"{base_url}/chat/completions", headers=headers, json=data)