Spaces:
Sleeping
Sleeping
mvp ok
Browse files- app.py +10 -64
- auth.py +2 -3
- db/cloudflare.py +0 -3
- db/tbs_db.py +1 -0
- db_model/chat.py +11 -0
- routers/openai_v1.py +42 -49
app.py
CHANGED
@@ -1,83 +1,29 @@
|
|
1 |
# uvicorn app:app --host 0.0.0.0 --port 7860 --reload
|
2 |
|
|
|
|
|
3 |
from pathlib import Path
|
|
|
|
|
|
|
4 |
from global_state import set
|
5 |
|
|
|
|
|
6 |
# 获取当前文件的父目录的绝对路径,即:project_root
|
7 |
parent_dir = Path(__file__).resolve().parent
|
8 |
set('project_root', parent_dir)
|
9 |
|
10 |
-
|
11 |
-
from fastapi import FastAPI, HTTPException, Response, Depends
|
12 |
-
from fastapi.responses import HTMLResponse
|
13 |
-
# from starlette.requests import Request
|
14 |
-
|
15 |
# from routers.webtools_v1 import router as webtools_router
|
16 |
from routers.users_v1 import router as users_router
|
17 |
-
|
18 |
-
|
19 |
-
# from db import Db
|
20 |
-
|
21 |
-
from dotenv import load_dotenv
|
22 |
-
import os
|
23 |
-
|
24 |
-
# 当在本地使用.env 文件时加载.env文件
|
25 |
-
load_dotenv()
|
26 |
-
|
27 |
-
# 获取环境变量,本地使用 load_dotenv 和 hf 里直接使用环境变量配置时,都可使用下面的语句
|
28 |
-
DB_PATH = os.getenv("DB_PATH")
|
29 |
-
|
30 |
-
# 依赖注入,获取 DB_PATH
|
31 |
-
def get_db_path():
|
32 |
-
return DB_PATH
|
33 |
|
34 |
app = FastAPI()
|
35 |
|
36 |
# app.include_router(webtools_router, prefix="/airs/v1", tags=["webtools"])
|
37 |
app.include_router(users_router, prefix="/airs/v1", tags=["users"])
|
38 |
-
|
39 |
|
40 |
@app.get("/")
|
41 |
-
def greet_json(
|
42 |
return {"Hello": "World!"}
|
43 |
-
# @app.get("/init_db")
|
44 |
-
# def init_db():
|
45 |
-
# sql = """
|
46 |
-
# CREATE TABLE api_keys (
|
47 |
-
# id INTEGER PRIMARY KEY AUTOINCREMENT,
|
48 |
-
# api_key TEXT NOT NULL,
|
49 |
-
# type TEXT NOT NULL,
|
50 |
-
# status INTEGER NOT NULL,
|
51 |
-
# idx INTEGER NOT NULL,
|
52 |
-
# dest_api_key TEXT NOT NULL
|
53 |
-
# );
|
54 |
-
# """
|
55 |
-
# Db('api_keys.db').execute_query(sql)
|
56 |
-
|
57 |
-
# def create_user_api_key(user_id, api_key, type, status, idx, dest_api_key):
|
58 |
-
# sql = f"""
|
59 |
-
# INSERT INTO api_keys (api_key, type, status, idx, dest_api_key)
|
60 |
-
# """
|
61 |
-
# Db('api_keys.db').execute_query(sql)
|
62 |
-
|
63 |
-
# create_user_api_key('i_am_tanbushi', 'llm', '1', '0', 'dest_api_key')
|
64 |
-
|
65 |
-
# @app.get("/list_files")
|
66 |
-
# def list_files():
|
67 |
-
# # return "cur_path"
|
68 |
-
# directory = os.getcwd()
|
69 |
-
|
70 |
-
# # 遍历目录树
|
71 |
-
# retstr="第一行\n第二行\n第三行\n"
|
72 |
-
# # retstr = retstr + '1' + "\n"
|
73 |
-
# # retstr = retstr + '2' + "\n"
|
74 |
-
# # retstr = retstr + '3' + "\n"
|
75 |
-
# print(retstr)
|
76 |
-
# # for dirpath, dirnames, filenames in os.walk(directory):
|
77 |
-
# # for filename in filenames:
|
78 |
-
# # retstr = retstr + os.path.join(dirpath, filename) +'\\n'
|
79 |
-
# # print(os.path.join(dirpath, filename))
|
80 |
-
|
81 |
-
# return HTMLResponse(f"<pre>{retstr}</pre>")
|
82 |
-
# return HTMLResponse(content=f"<pre>{retstr}</pre>", media_type="text/html")
|
83 |
-
# return retstr
|
|
|
1 |
# uvicorn app:app --host 0.0.0.0 --port 7860 --reload
|
2 |
|
3 |
+
from fastapi import FastAPI, HTTPException, Response, Depends
|
4 |
+
from fastapi.responses import HTMLResponse
|
5 |
from pathlib import Path
|
6 |
+
from dotenv import load_dotenv
|
7 |
+
import os
|
8 |
+
|
9 |
from global_state import set
|
10 |
|
11 |
+
load_dotenv()
|
12 |
+
|
13 |
# 获取当前文件的父目录的绝对路径,即:project_root
|
14 |
parent_dir = Path(__file__).resolve().parent
|
15 |
set('project_root', parent_dir)
|
16 |
|
|
|
|
|
|
|
|
|
|
|
17 |
# from routers.webtools_v1 import router as webtools_router
|
18 |
from routers.users_v1 import router as users_router
|
19 |
+
from routers.openai_v1 import router as openai_router
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
20 |
|
21 |
app = FastAPI()
|
22 |
|
23 |
# app.include_router(webtools_router, prefix="/airs/v1", tags=["webtools"])
|
24 |
app.include_router(users_router, prefix="/airs/v1", tags=["users"])
|
25 |
+
app.include_router(openai_router, prefix="/airs/v1", tags=["openai"])
|
26 |
|
27 |
@app.get("/")
|
28 |
+
def greet_json():
|
29 |
return {"Hello": "World!"}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
auth.py
CHANGED
@@ -8,11 +8,9 @@ from db_model.user import UserModel
|
|
8 |
security = HTTPBearer()
|
9 |
|
10 |
def get_current_user(credentials: HTTPAuthorizationCredentials = Depends(security)):
|
11 |
-
token = credentials.credentials
|
12 |
token = credentials.credentials
|
13 |
# 假设你有一个函数来验证Token并返回用户
|
14 |
user = validate_token(token)
|
15 |
-
print(f"\n\n\n\n{user}")
|
16 |
if user is None:
|
17 |
raise HTTPException(
|
18 |
status_code=status.HTTP_401_UNAUTHORIZED,
|
@@ -26,7 +24,8 @@ def validate_token(token: str):
|
|
26 |
db_module_filename = f"{get('project_root')}/db/cloudflare.py"
|
27 |
query = f"SELECT * FROM users where api_key='{token}'"
|
28 |
response = TbsDb(db_module_filename, "Cloudflare").get_item(query)
|
29 |
-
|
|
|
30 |
result = response['result'][0]['results']
|
31 |
if len(result) == 0:
|
32 |
return None
|
|
|
8 |
security = HTTPBearer()
|
9 |
|
10 |
def get_current_user(credentials: HTTPAuthorizationCredentials = Depends(security)):
|
|
|
11 |
token = credentials.credentials
|
12 |
# 假设你有一个函数来验证Token并返回用户
|
13 |
user = validate_token(token)
|
|
|
14 |
if user is None:
|
15 |
raise HTTPException(
|
16 |
status_code=status.HTTP_401_UNAUTHORIZED,
|
|
|
24 |
db_module_filename = f"{get('project_root')}/db/cloudflare.py"
|
25 |
query = f"SELECT * FROM users where api_key='{token}'"
|
26 |
response = TbsDb(db_module_filename, "Cloudflare").get_item(query)
|
27 |
+
if response is None:
|
28 |
+
return None
|
29 |
result = response['result'][0]['results']
|
30 |
if len(result) == 0:
|
31 |
return None
|
db/cloudflare.py
CHANGED
@@ -1,9 +1,7 @@
|
|
1 |
import requests, os
|
2 |
-
# from dotenv import load_dotenv
|
3 |
|
4 |
class Cloudflare():
|
5 |
def __init__(self):
|
6 |
-
# load_dotenv()
|
7 |
self.CF_ACCOOUNT_ID=os.getenv("CF_ACCOOUNT_ID")
|
8 |
self.DATABASE_ID=os.getenv("DATABASE_ID")
|
9 |
self.X_Auth_Key=os.getenv("X_Auth_Key")
|
@@ -27,6 +25,5 @@ class Cloudflare():
|
|
27 |
}
|
28 |
response = requests.post(url, headers=headers, json=input)
|
29 |
resp_json = response.json()
|
30 |
-
|
31 |
return resp_json
|
32 |
|
|
|
1 |
import requests, os
|
|
|
2 |
|
3 |
class Cloudflare():
|
4 |
def __init__(self):
|
|
|
5 |
self.CF_ACCOOUNT_ID=os.getenv("CF_ACCOOUNT_ID")
|
6 |
self.DATABASE_ID=os.getenv("DATABASE_ID")
|
7 |
self.X_Auth_Key=os.getenv("X_Auth_Key")
|
|
|
25 |
}
|
26 |
response = requests.post(url, headers=headers, json=input)
|
27 |
resp_json = response.json()
|
|
|
28 |
return resp_json
|
29 |
|
db/tbs_db.py
CHANGED
@@ -20,4 +20,5 @@ class TbsDb():
|
|
20 |
|
21 |
def execute_query(self, query):
|
22 |
return self.db_module.execute_query(query)
|
|
|
23 |
|
|
|
20 |
|
21 |
def execute_query(self, query):
|
22 |
return self.db_module.execute_query(query)
|
23 |
+
|
24 |
|
db_model/chat.py
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pydantic import BaseModel
|
2 |
+
|
3 |
+
# 定义消息模型
|
4 |
+
class MessageModel(BaseModel):
|
5 |
+
role: str
|
6 |
+
content: str
|
7 |
+
|
8 |
+
# 定义聊天模型
|
9 |
+
class ChatModel(BaseModel):
|
10 |
+
model: str = None
|
11 |
+
messages: list[MessageModel]
|
routers/openai_v1.py
CHANGED
@@ -1,64 +1,53 @@
|
|
1 |
-
from fastapi import APIRouter,
|
2 |
-
|
3 |
-
from dotenv import load_dotenv
|
4 |
-
import os, json
|
5 |
from langchain_core.output_parsers import StrOutputParser
|
6 |
from langchain_core.prompts import ChatPromptTemplate
|
7 |
|
8 |
-
|
9 |
-
|
|
|
|
|
|
|
10 |
|
11 |
router = APIRouter()
|
12 |
|
13 |
-
|
14 |
-
async def chat_completions(request:Request):
|
15 |
-
print('chat_completions')
|
16 |
-
auth_header = request.headers['authorization']
|
17 |
-
user_api_key = auth_header.split()[1] # 分割字符串并取第二个元素
|
18 |
-
user_id = await get_user_id(user_api_key)
|
19 |
-
if user_id==0:
|
20 |
-
return {'error': {'code': '401', 'message': 'api_key 已过期或验证不正确!'}}
|
21 |
|
22 |
-
|
|
|
23 |
try:
|
24 |
-
model =
|
25 |
except:
|
26 |
model = ''
|
27 |
-
|
28 |
-
if model=='':
|
29 |
model = await get_default_model()
|
30 |
-
api_key_info = await get_api_key(model)
|
31 |
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
except:
|
37 |
-
base_url = ''
|
38 |
|
39 |
-
#
|
40 |
-
|
41 |
-
from langchain_google_genai import ChatGoogleGenerativeAI
|
42 |
-
# 初始化模型
|
43 |
llm = ChatGoogleGenerativeAI(
|
44 |
api_key = api_key,
|
45 |
model = model,
|
46 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
47 |
|
48 |
-
#
|
49 |
-
|
50 |
-
# 初始化 ChatOpenAI 模型
|
51 |
-
llm = ChatOpenAI(
|
52 |
-
model = model,
|
53 |
-
api_key = api_key,
|
54 |
-
base_url = base_url,
|
55 |
-
)
|
56 |
-
|
57 |
-
messages = data['messages']
|
58 |
-
rslt = [(item["role"], item["content"]) for item in messages] # 转换为所需的元组列表格式
|
59 |
parser = StrOutputParser()
|
60 |
-
prompt_template = ChatPromptTemplate.from_messages(
|
61 |
-
|
62 |
chain = prompt_template | llm | parser
|
63 |
try:
|
64 |
result = chain.invoke({})
|
@@ -66,16 +55,18 @@ async def chat_completions(request:Request):
|
|
66 |
return {'error': str(e)}
|
67 |
return result
|
68 |
|
|
|
69 |
async def get_default_model():
|
70 |
-
|
|
|
71 |
try:
|
72 |
-
result = results[0]['api_name']
|
73 |
except:
|
74 |
result = ''
|
75 |
return result
|
76 |
|
77 |
async def get_api_key(model):
|
78 |
-
|
79 |
SELECT an.api_name, ak.api_key, an.base_url, ag.group_name
|
80 |
FROM api_keys ak
|
81 |
JOIN api_groups ag ON ak.api_group_id = ag.id
|
@@ -84,16 +75,18 @@ WHERE ak.category='LLM' and an.api_name='{model}' and disabled=0
|
|
84 |
ORDER BY ak.last_call_at
|
85 |
limit 1
|
86 |
"""
|
87 |
-
|
|
|
|
|
88 |
# return results
|
89 |
try:
|
90 |
-
result = results[0]
|
91 |
api_key = result['api_key']
|
92 |
except:
|
93 |
api_key = ''
|
94 |
|
95 |
-
|
96 |
-
|
97 |
return result
|
98 |
|
99 |
async def get_user_id(api_key):
|
|
|
1 |
+
from fastapi import APIRouter, Depends
|
2 |
+
import os
|
|
|
|
|
3 |
from langchain_core.output_parsers import StrOutputParser
|
4 |
from langchain_core.prompts import ChatPromptTemplate
|
5 |
|
6 |
+
from global_state import get
|
7 |
+
from db.tbs_db import TbsDb
|
8 |
+
from auth import get_current_user
|
9 |
+
from db_model.user import UserModel
|
10 |
+
from db_model.chat import ChatModel
|
11 |
|
12 |
router = APIRouter()
|
13 |
|
14 |
+
db_module_filename = f"{get('project_root')}/db/cloudflare.py"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
15 |
|
16 |
+
@router.post("/chat/completions")
|
17 |
+
async def chat_completions(chat_model:ChatModel, current_user: UserModel = Depends(get_current_user)):
|
18 |
try:
|
19 |
+
model = chat_model.model
|
20 |
except:
|
21 |
model = ''
|
22 |
+
|
23 |
+
if (model=='')or(model is None):
|
24 |
model = await get_default_model()
|
|
|
25 |
|
26 |
+
api_key_info = await get_api_key(model)
|
27 |
+
api_key = api_key_info.get('api_key', '')
|
28 |
+
group_name = api_key_info.get('group_name', '')
|
29 |
+
base_url = api_key_info.get('base_url', '')
|
|
|
|
|
30 |
|
31 |
+
if group_name=='gemini': # google api,生成 gemini 的 llm
|
32 |
+
from langchain_google_genai import ChatGoogleGenerativeAI
|
|
|
|
|
33 |
llm = ChatGoogleGenerativeAI(
|
34 |
api_key = api_key,
|
35 |
model = model,
|
36 |
)
|
37 |
+
else: # 下面就是 chatgpt 兼容 api
|
38 |
+
from langchain_openai import ChatOpenAI
|
39 |
+
# 初始化 ChatOpenAI 模型
|
40 |
+
llm = ChatOpenAI(
|
41 |
+
model = model,
|
42 |
+
api_key = api_key,
|
43 |
+
base_url = base_url,
|
44 |
+
)
|
45 |
|
46 |
+
# 生成prompt模板
|
47 |
+
lc_messages = [(message.role, message.content) for message in chat_model.messages]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
48 |
parser = StrOutputParser()
|
49 |
+
prompt_template = ChatPromptTemplate.from_messages(lc_messages)
|
50 |
+
|
51 |
chain = prompt_template | llm | parser
|
52 |
try:
|
53 |
result = chain.invoke({})
|
|
|
55 |
return {'error': str(e)}
|
56 |
return result
|
57 |
|
58 |
+
# 从数据库获取默认模型
|
59 |
async def get_default_model():
|
60 |
+
query = f"SELECT * FROM api_names order by default_order limit 1"
|
61 |
+
response = TbsDb(db_module_filename, "Cloudflare").get_item(query)
|
62 |
try:
|
63 |
+
result = response['result'][0]['results'][0]['api_name']
|
64 |
except:
|
65 |
result = ''
|
66 |
return result
|
67 |
|
68 |
async def get_api_key(model):
|
69 |
+
query = f"""
|
70 |
SELECT an.api_name, ak.api_key, an.base_url, ag.group_name
|
71 |
FROM api_keys ak
|
72 |
JOIN api_groups ag ON ak.api_group_id = ag.id
|
|
|
75 |
ORDER BY ak.last_call_at
|
76 |
limit 1
|
77 |
"""
|
78 |
+
response = TbsDb(db_module_filename, "Cloudflare").get_item(query)
|
79 |
+
# return response
|
80 |
+
# results = Db(os.getenv("DB_PATH")).list_query(sql)
|
81 |
# return results
|
82 |
try:
|
83 |
+
result = response['result'][0]['results'][0]
|
84 |
api_key = result['api_key']
|
85 |
except:
|
86 |
api_key = ''
|
87 |
|
88 |
+
query = f"update api_keys set last_call_at=datetime('now') where api_key='{api_key}'"
|
89 |
+
TbsDb(db_module_filename, "Cloudflare").execute_query(query)
|
90 |
return result
|
91 |
|
92 |
async def get_user_id(api_key):
|