tanbushi commited on
Commit
10bcd3f
·
1 Parent(s): e66bcc7
Files changed (6) hide show
  1. app.py +10 -64
  2. auth.py +2 -3
  3. db/cloudflare.py +0 -3
  4. db/tbs_db.py +1 -0
  5. db_model/chat.py +11 -0
  6. routers/openai_v1.py +42 -49
app.py CHANGED
@@ -1,83 +1,29 @@
1
  # uvicorn app:app --host 0.0.0.0 --port 7860 --reload
2
 
 
 
3
  from pathlib import Path
 
 
 
4
  from global_state import set
5
 
 
 
6
  # 获取当前文件的父目录的绝对路径,即:project_root
7
  parent_dir = Path(__file__).resolve().parent
8
  set('project_root', parent_dir)
9
 
10
-
11
- from fastapi import FastAPI, HTTPException, Response, Depends
12
- from fastapi.responses import HTMLResponse
13
- # from starlette.requests import Request
14
-
15
  # from routers.webtools_v1 import router as webtools_router
16
  from routers.users_v1 import router as users_router
17
- # from routers.openai_v1 import router as openai_router
18
-
19
- # from db import Db
20
-
21
- from dotenv import load_dotenv
22
- import os
23
-
24
- # 当在本地使用.env 文件时加载.env文件
25
- load_dotenv()
26
-
27
- # 获取环境变量,本地使用 load_dotenv 和 hf 里直接使用环境变量配置时,都可使用下面的语句
28
- DB_PATH = os.getenv("DB_PATH")
29
-
30
- # 依赖注入,获取 DB_PATH
31
- def get_db_path():
32
- return DB_PATH
33
 
34
  app = FastAPI()
35
 
36
  # app.include_router(webtools_router, prefix="/airs/v1", tags=["webtools"])
37
  app.include_router(users_router, prefix="/airs/v1", tags=["users"])
38
- # app.include_router(openai_router, prefix="/airs/v1", tags=["openai"])
39
 
40
  @app.get("/")
41
- def greet_json(db_path: str = Depends(get_db_path)):
42
  return {"Hello": "World!"}
43
- # @app.get("/init_db")
44
- # def init_db():
45
- # sql = """
46
- # CREATE TABLE api_keys (
47
- # id INTEGER PRIMARY KEY AUTOINCREMENT,
48
- # api_key TEXT NOT NULL,
49
- # type TEXT NOT NULL,
50
- # status INTEGER NOT NULL,
51
- # idx INTEGER NOT NULL,
52
- # dest_api_key TEXT NOT NULL
53
- # );
54
- # """
55
- # Db('api_keys.db').execute_query(sql)
56
-
57
- # def create_user_api_key(user_id, api_key, type, status, idx, dest_api_key):
58
- # sql = f"""
59
- # INSERT INTO api_keys (api_key, type, status, idx, dest_api_key)
60
- # """
61
- # Db('api_keys.db').execute_query(sql)
62
-
63
- # create_user_api_key('i_am_tanbushi', 'llm', '1', '0', 'dest_api_key')
64
-
65
- # @app.get("/list_files")
66
- # def list_files():
67
- # # return "cur_path"
68
- # directory = os.getcwd()
69
-
70
- # # 遍历目录树
71
- # retstr="第一行\n第二行\n第三行\n"
72
- # # retstr = retstr + '1' + "\n"
73
- # # retstr = retstr + '2' + "\n"
74
- # # retstr = retstr + '3' + "\n"
75
- # print(retstr)
76
- # # for dirpath, dirnames, filenames in os.walk(directory):
77
- # # for filename in filenames:
78
- # # retstr = retstr + os.path.join(dirpath, filename) +'\\n'
79
- # # print(os.path.join(dirpath, filename))
80
-
81
- # return HTMLResponse(f"<pre>{retstr}</pre>")
82
- # return HTMLResponse(content=f"<pre>{retstr}</pre>", media_type="text/html")
83
- # return retstr
 
1
  # uvicorn app:app --host 0.0.0.0 --port 7860 --reload
2
 
3
+ from fastapi import FastAPI, HTTPException, Response, Depends
4
+ from fastapi.responses import HTMLResponse
5
  from pathlib import Path
6
+ from dotenv import load_dotenv
7
+ import os
8
+
9
  from global_state import set
10
 
11
+ load_dotenv()
12
+
13
  # 获取当前文件的父目录的绝对路径,即:project_root
14
  parent_dir = Path(__file__).resolve().parent
15
  set('project_root', parent_dir)
16
 
 
 
 
 
 
17
  # from routers.webtools_v1 import router as webtools_router
18
  from routers.users_v1 import router as users_router
19
+ from routers.openai_v1 import router as openai_router
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
 
21
  app = FastAPI()
22
 
23
  # app.include_router(webtools_router, prefix="/airs/v1", tags=["webtools"])
24
  app.include_router(users_router, prefix="/airs/v1", tags=["users"])
25
+ app.include_router(openai_router, prefix="/airs/v1", tags=["openai"])
26
 
27
  @app.get("/")
28
+ def greet_json():
29
  return {"Hello": "World!"}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
auth.py CHANGED
@@ -8,11 +8,9 @@ from db_model.user import UserModel
8
  security = HTTPBearer()
9
 
10
  def get_current_user(credentials: HTTPAuthorizationCredentials = Depends(security)):
11
- token = credentials.credentials
12
  token = credentials.credentials
13
  # 假设你有一个函数来验证Token并返回用户
14
  user = validate_token(token)
15
- print(f"\n\n\n\n{user}")
16
  if user is None:
17
  raise HTTPException(
18
  status_code=status.HTTP_401_UNAUTHORIZED,
@@ -26,7 +24,8 @@ def validate_token(token: str):
26
  db_module_filename = f"{get('project_root')}/db/cloudflare.py"
27
  query = f"SELECT * FROM users where api_key='{token}'"
28
  response = TbsDb(db_module_filename, "Cloudflare").get_item(query)
29
- print(f"\n\n\n\n{response}")
 
30
  result = response['result'][0]['results']
31
  if len(result) == 0:
32
  return None
 
8
  security = HTTPBearer()
9
 
10
  def get_current_user(credentials: HTTPAuthorizationCredentials = Depends(security)):
 
11
  token = credentials.credentials
12
  # 假设你有一个函数来验证Token并返回用户
13
  user = validate_token(token)
 
14
  if user is None:
15
  raise HTTPException(
16
  status_code=status.HTTP_401_UNAUTHORIZED,
 
24
  db_module_filename = f"{get('project_root')}/db/cloudflare.py"
25
  query = f"SELECT * FROM users where api_key='{token}'"
26
  response = TbsDb(db_module_filename, "Cloudflare").get_item(query)
27
+ if response is None:
28
+ return None
29
  result = response['result'][0]['results']
30
  if len(result) == 0:
31
  return None
db/cloudflare.py CHANGED
@@ -1,9 +1,7 @@
1
  import requests, os
2
- # from dotenv import load_dotenv
3
 
4
  class Cloudflare():
5
  def __init__(self):
6
- # load_dotenv()
7
  self.CF_ACCOOUNT_ID=os.getenv("CF_ACCOOUNT_ID")
8
  self.DATABASE_ID=os.getenv("DATABASE_ID")
9
  self.X_Auth_Key=os.getenv("X_Auth_Key")
@@ -27,6 +25,5 @@ class Cloudflare():
27
  }
28
  response = requests.post(url, headers=headers, json=input)
29
  resp_json = response.json()
30
-
31
  return resp_json
32
 
 
1
  import requests, os
 
2
 
3
  class Cloudflare():
4
  def __init__(self):
 
5
  self.CF_ACCOOUNT_ID=os.getenv("CF_ACCOOUNT_ID")
6
  self.DATABASE_ID=os.getenv("DATABASE_ID")
7
  self.X_Auth_Key=os.getenv("X_Auth_Key")
 
25
  }
26
  response = requests.post(url, headers=headers, json=input)
27
  resp_json = response.json()
 
28
  return resp_json
29
 
db/tbs_db.py CHANGED
@@ -20,4 +20,5 @@ class TbsDb():
20
 
21
  def execute_query(self, query):
22
  return self.db_module.execute_query(query)
 
23
 
 
20
 
21
  def execute_query(self, query):
22
  return self.db_module.execute_query(query)
23
+
24
 
db_model/chat.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pydantic import BaseModel
2
+
3
+ # 定义消息模型
4
+ class MessageModel(BaseModel):
5
+ role: str
6
+ content: str
7
+
8
+ # 定义聊天模型
9
+ class ChatModel(BaseModel):
10
+ model: str = None
11
+ messages: list[MessageModel]
routers/openai_v1.py CHANGED
@@ -1,64 +1,53 @@
1
- from fastapi import APIRouter, Request, HTTPException
2
- # from db import Db
3
- from dotenv import load_dotenv
4
- import os, json
5
  from langchain_core.output_parsers import StrOutputParser
6
  from langchain_core.prompts import ChatPromptTemplate
7
 
8
- # 加载.env文件
9
- load_dotenv()
 
 
 
10
 
11
  router = APIRouter()
12
 
13
- @router.post("/chat/completions")
14
- async def chat_completions(request:Request):
15
- print('chat_completions')
16
- auth_header = request.headers['authorization']
17
- user_api_key = auth_header.split()[1] # 分割字符串并取第二个元素
18
- user_id = await get_user_id(user_api_key)
19
- if user_id==0:
20
- return {'error': {'code': '401', 'message': 'api_key 已过期或验证不正确!'}}
21
 
22
- data = await request.json()
 
23
  try:
24
- model = data['model']
25
  except:
26
  model = ''
27
-
28
- if model=='':
29
  model = await get_default_model()
30
- api_key_info = await get_api_key(model)
31
 
32
- api_key = api_key_info['api_key']
33
- group_name = api_key_info['group_name']
34
- try:
35
- base_url = api_key_info['base_url']
36
- except:
37
- base_url = ''
38
 
39
- # 下面是google api
40
- if group_name=='gemini':
41
- from langchain_google_genai import ChatGoogleGenerativeAI
42
- # 初始化模型
43
  llm = ChatGoogleGenerativeAI(
44
  api_key = api_key,
45
  model = model,
46
  )
 
 
 
 
 
 
 
 
47
 
48
- # 下面就是 chatgpt 兼容 api
49
- from langchain_openai import ChatOpenAI
50
- # 初始化 ChatOpenAI 模型
51
- llm = ChatOpenAI(
52
- model = model,
53
- api_key = api_key,
54
- base_url = base_url,
55
- )
56
-
57
- messages = data['messages']
58
- rslt = [(item["role"], item["content"]) for item in messages] # 转换为所需的元组列表格式
59
  parser = StrOutputParser()
60
- prompt_template = ChatPromptTemplate.from_messages(rslt)
61
-
62
  chain = prompt_template | llm | parser
63
  try:
64
  result = chain.invoke({})
@@ -66,16 +55,18 @@ async def chat_completions(request:Request):
66
  return {'error': str(e)}
67
  return result
68
 
 
69
  async def get_default_model():
70
- results = Db(os.getenv("DB_PATH")).list_query("SELECT * FROM api_names order by default_order limit 1")
 
71
  try:
72
- result = results[0]['api_name']
73
  except:
74
  result = ''
75
  return result
76
 
77
  async def get_api_key(model):
78
- sql = f"""
79
  SELECT an.api_name, ak.api_key, an.base_url, ag.group_name
80
  FROM api_keys ak
81
  JOIN api_groups ag ON ak.api_group_id = ag.id
@@ -84,16 +75,18 @@ WHERE ak.category='LLM' and an.api_name='{model}' and disabled=0
84
  ORDER BY ak.last_call_at
85
  limit 1
86
  """
87
- results = Db(os.getenv("DB_PATH")).list_query(sql)
 
 
88
  # return results
89
  try:
90
- result = results[0]
91
  api_key = result['api_key']
92
  except:
93
  api_key = ''
94
 
95
- sql = f"update api_keys set last_call_at=datetime('now') where api_key='{api_key}'"
96
- Db(os.getenv("DB_PATH")).update_query(sql)
97
  return result
98
 
99
  async def get_user_id(api_key):
 
1
+ from fastapi import APIRouter, Depends
2
+ import os
 
 
3
  from langchain_core.output_parsers import StrOutputParser
4
  from langchain_core.prompts import ChatPromptTemplate
5
 
6
+ from global_state import get
7
+ from db.tbs_db import TbsDb
8
+ from auth import get_current_user
9
+ from db_model.user import UserModel
10
+ from db_model.chat import ChatModel
11
 
12
  router = APIRouter()
13
 
14
+ db_module_filename = f"{get('project_root')}/db/cloudflare.py"
 
 
 
 
 
 
 
15
 
16
+ @router.post("/chat/completions")
17
+ async def chat_completions(chat_model:ChatModel, current_user: UserModel = Depends(get_current_user)):
18
  try:
19
+ model = chat_model.model
20
  except:
21
  model = ''
22
+
23
+ if (model=='')or(model is None):
24
  model = await get_default_model()
 
25
 
26
+ api_key_info = await get_api_key(model)
27
+ api_key = api_key_info.get('api_key', '')
28
+ group_name = api_key_info.get('group_name', '')
29
+ base_url = api_key_info.get('base_url', '')
 
 
30
 
31
+ if group_name=='gemini': # google api,生成 gemini 的 llm
32
+ from langchain_google_genai import ChatGoogleGenerativeAI
 
 
33
  llm = ChatGoogleGenerativeAI(
34
  api_key = api_key,
35
  model = model,
36
  )
37
+ else: # 下面就是 chatgpt 兼容 api
38
+ from langchain_openai import ChatOpenAI
39
+ # 初始化 ChatOpenAI 模型
40
+ llm = ChatOpenAI(
41
+ model = model,
42
+ api_key = api_key,
43
+ base_url = base_url,
44
+ )
45
 
46
+ # 生成prompt模板
47
+ lc_messages = [(message.role, message.content) for message in chat_model.messages]
 
 
 
 
 
 
 
 
 
48
  parser = StrOutputParser()
49
+ prompt_template = ChatPromptTemplate.from_messages(lc_messages)
50
+
51
  chain = prompt_template | llm | parser
52
  try:
53
  result = chain.invoke({})
 
55
  return {'error': str(e)}
56
  return result
57
 
58
+ # 从数据库获取默认模型
59
  async def get_default_model():
60
+ query = f"SELECT * FROM api_names order by default_order limit 1"
61
+ response = TbsDb(db_module_filename, "Cloudflare").get_item(query)
62
  try:
63
+ result = response['result'][0]['results'][0]['api_name']
64
  except:
65
  result = ''
66
  return result
67
 
68
  async def get_api_key(model):
69
+ query = f"""
70
  SELECT an.api_name, ak.api_key, an.base_url, ag.group_name
71
  FROM api_keys ak
72
  JOIN api_groups ag ON ak.api_group_id = ag.id
 
75
  ORDER BY ak.last_call_at
76
  limit 1
77
  """
78
+ response = TbsDb(db_module_filename, "Cloudflare").get_item(query)
79
+ # return response
80
+ # results = Db(os.getenv("DB_PATH")).list_query(sql)
81
  # return results
82
  try:
83
+ result = response['result'][0]['results'][0]
84
  api_key = result['api_key']
85
  except:
86
  api_key = ''
87
 
88
+ query = f"update api_keys set last_call_at=datetime('now') where api_key='{api_key}'"
89
+ TbsDb(db_module_filename, "Cloudflare").execute_query(query)
90
  return result
91
 
92
  async def get_user_id(api_key):