tanbushi commited on
Commit
2c7fa70
·
1 Parent(s): daa9d8a

test sqlite-web

Browse files
Files changed (3) hide show
  1. requirements.txt +2 -1
  2. routers/openai_v1.py +32 -12
  3. routers/webtools_v1.py +1 -1
requirements.txt CHANGED
@@ -3,4 +3,5 @@ uvicorn[standard]
3
  captcha
4
  python-dotenv
5
  langchain_google_genai==1.0.10
6
- langchain-openai==0.1.25 # 使用 openai 或兼容 openai
 
 
3
  captcha
4
  python-dotenv
5
  langchain_google_genai==1.0.10
6
+ langchain-openai==0.1.25 # 使用 openai 或兼容 openai
7
+ sqlite-web
routers/openai_v1.py CHANGED
@@ -15,17 +15,26 @@ async def chat_completions(request:Request):
15
  print('chat_completions')
16
  auth_header = request.headers['authorization']
17
  user_api_key = auth_header.split()[1] # 分割字符串并取第二个元素
 
 
 
 
18
  data = await request.json()
19
  try:
20
  model = data['model']
21
  except:
22
  model = ''
 
23
  if model=='':
24
  model = await get_default_model()
25
  api_key_info = await get_api_key(model)
 
26
  api_key = api_key_info['api_key']
27
  group_name = api_key_info['group_name']
28
-
 
 
 
29
 
30
  # 下面是google api
31
  if group_name=='gemini':
@@ -35,17 +44,15 @@ async def chat_completions(request:Request):
35
  api_key = api_key,
36
  model = model,
37
  )
38
-
39
 
40
  # 下面就是 chatgpt 兼容 api
41
- if group_name=='glm':
42
- from langchain_openai import ChatOpenAI
43
- # 初始化 ChatOpenAI 模型
44
- llm = ChatOpenAI(
45
- model = model,
46
- api_key = api_key,
47
- base_url = 'https://open.bigmodel.cn/api/paas/v4', # 下一步将base_url 写到数据库里
48
- )
49
 
50
  messages = data['messages']
51
  rslt = [(item["role"], item["content"]) for item in messages] # 转换为所需的元组列表格式
@@ -53,7 +60,10 @@ async def chat_completions(request:Request):
53
  prompt_template = ChatPromptTemplate.from_messages(rslt)
54
 
55
  chain = prompt_template | llm | parser
56
- result = chain.invoke({})
 
 
 
57
  return result
58
 
59
  async def get_default_model():
@@ -66,7 +76,7 @@ async def get_default_model():
66
 
67
  async def get_api_key(model):
68
  sql = f"""
69
- SELECT an.api_name, ak.api_key, ag.group_name
70
  FROM api_keys ak
71
  JOIN api_groups ag ON ak.api_group_id = ag.id
72
  JOIN api_names an ON an.api_group_id = ag.id
@@ -75,6 +85,7 @@ ORDER BY ak.last_call_at
75
  limit 1
76
  """
77
  results = Db(os.getenv("DB_PATH")).list_query(sql)
 
78
  try:
79
  result = results[0]
80
  api_key = result['api_key']
@@ -84,3 +95,12 @@ limit 1
84
  sql = f"update api_keys set last_call_at=datetime('now') where api_key='{api_key}'"
85
  Db(os.getenv("DB_PATH")).update_query(sql)
86
  return result
 
 
 
 
 
 
 
 
 
 
15
  print('chat_completions')
16
  auth_header = request.headers['authorization']
17
  user_api_key = auth_header.split()[1] # 分割字符串并取第二个元素
18
+ user_id = await get_user_id(user_api_key)
19
+ if user_id==0:
20
+ return {'error': {'code': '401', 'message': 'api_key 已过期或验证不正确!'}}
21
+
22
  data = await request.json()
23
  try:
24
  model = data['model']
25
  except:
26
  model = ''
27
+
28
  if model=='':
29
  model = await get_default_model()
30
  api_key_info = await get_api_key(model)
31
+
32
  api_key = api_key_info['api_key']
33
  group_name = api_key_info['group_name']
34
+ try:
35
+ base_url = api_key_info['base_url']
36
+ except:
37
+ base_url = ''
38
 
39
  # 下面是google api
40
  if group_name=='gemini':
 
44
  api_key = api_key,
45
  model = model,
46
  )
 
47
 
48
  # 下面就是 chatgpt 兼容 api
49
+ from langchain_openai import ChatOpenAI
50
+ # 初始化 ChatOpenAI 模型
51
+ llm = ChatOpenAI(
52
+ model = model,
53
+ api_key = api_key,
54
+ base_url = base_url,
55
+ )
 
56
 
57
  messages = data['messages']
58
  rslt = [(item["role"], item["content"]) for item in messages] # 转换为所需的元组列表格式
 
60
  prompt_template = ChatPromptTemplate.from_messages(rslt)
61
 
62
  chain = prompt_template | llm | parser
63
+ try:
64
+ result = chain.invoke({})
65
+ except Exception as e:
66
+ return {'error': str(e)}
67
  return result
68
 
69
  async def get_default_model():
 
76
 
77
  async def get_api_key(model):
78
  sql = f"""
79
+ SELECT an.api_name, ak.api_key, an.base_url, ag.group_name
80
  FROM api_keys ak
81
  JOIN api_groups ag ON ak.api_group_id = ag.id
82
  JOIN api_names an ON an.api_group_id = ag.id
 
85
  limit 1
86
  """
87
  results = Db(os.getenv("DB_PATH")).list_query(sql)
88
+ # return results
89
  try:
90
  result = results[0]
91
  api_key = result['api_key']
 
95
  sql = f"update api_keys set last_call_at=datetime('now') where api_key='{api_key}'"
96
  Db(os.getenv("DB_PATH")).update_query(sql)
97
  return result
98
+
99
+ async def get_user_id(api_key):
100
+ sql = f"select id from users where api_key='{api_key}'"
101
+ results = Db(os.getenv("DB_PATH")).list_query(sql)
102
+ try:
103
+ result = results[0]['id']
104
+ except:
105
+ result = 0
106
+ return result
routers/webtools_v1.py CHANGED
@@ -1,5 +1,5 @@
1
  from fastapi import FastAPI, APIRouter, Response, Depends
2
- from captcha.image import ImageCaptcha
3
  from io import BytesIO
4
 
5
  router = APIRouter()
 
1
  from fastapi import FastAPI, APIRouter, Response, Depends
2
+ # from captcha.image import ImageCaptcha
3
  from io import BytesIO
4
 
5
  router = APIRouter()