api-mapper / routers /openai_v1.py
tanbushi's picture
modify users, use cf db
f5cf708
raw
history blame
3.09 kB
from fastapi import APIRouter, Request, HTTPException
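# from db import Db  # NOTE(review): Db is still called by the query helpers below; with this import disabled they raise NameError unless Db is provided elsewhere.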
from dotenv import load_dotenv
import os, json
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
# Load environment variables from the .env file
load_dotenv()
router = APIRouter()
@router.post("/chat/completions")
async def chat_completions(request:Request):
    """Handle an OpenAI-style chat completion request.

    The caller is authenticated through the ``Authorization`` header, the
    requested model (or the default one when none is given) is resolved to
    an upstream API key, the matching LangChain chat model is built and the
    generated text is returned.

    Args:
        request: Incoming FastAPI request. Its JSON body is read for
            ``messages`` (a list of objects with ``role`` and ``content``)
            and an optional ``model``.

    Returns:
        The generated text on success, otherwise a dict with an ``error``
        entry describing the failure.
    """
    print('chat_completions')
    auth_error = {'error': {'code': '401', 'message': 'api_key 已过期或验证不正确!'}}

    # The header is expected to look like "<scheme> <key>". A missing or
    # malformed header is reported as an authentication failure instead of
    # crashing with KeyError / IndexError.
    auth_parts = request.headers.get('authorization', '').split()
    if len(auth_parts) < 2:
        return auth_error
    user_id = await get_user_id(auth_parts[1])
    if user_id == 0:
        return auth_error

    data = await request.json()
    if not isinstance(data, dict):
        return {'error': {'code': '400', 'message': 'request body must be a JSON object'}}

    # Fall back to the configured default when the model is absent or empty.
    model = data.get('model') or ''
    if not model:
        model = await get_default_model()

    api_key_info = await get_api_key(model)
    if api_key_info is None:
        return {'error': {'code': '404', 'message': f'no usable api key for model: {model}'}}
    api_key = api_key_info['api_key']
    group_name = api_key_info['group_name']
    try:
        base_url = api_key_info['base_url']
    except (KeyError, IndexError, TypeError):
        base_url = ''

    # Pick exactly one backend. Without the else branch the Gemini model
    # would always be overwritten by the OpenAI-compatible one.
    if group_name == 'gemini':
        # Google Gemini API
        from langchain_google_genai import ChatGoogleGenerativeAI
        llm = ChatGoogleGenerativeAI(
            api_key=api_key,
            model=model,
        )
    else:
        # Any ChatGPT-compatible API
        from langchain_openai import ChatOpenAI
        llm = ChatOpenAI(
            model=model,
            api_key=api_key,
            # An empty base URL is never meaningful; let the client use its default.
            base_url=base_url or None,
        )

    messages = data.get('messages')
    if not messages:
        return {'error': {'code': '400', 'message': 'messages is required'}}
    try:
        # Convert to the (role, content) tuple form accepted by chat models.
        chat_messages = [(item["role"], item["content"]) for item in messages]
    except (KeyError, TypeError):
        return {'error': {'code': '400', 'message': 'each message needs role and content'}}

    # The messages are passed to the model directly rather than through a
    # prompt template: a template would treat any "{" / "}" in user content
    # as a template variable and fail on ordinary prompts containing braces.
    chain = llm | StrOutputParser()
    try:
        # NOTE(review): invoke() is synchronous and blocks the event loop
        # while the upstream call is in flight.
        result = chain.invoke(chat_messages)
    except Exception as e:
        return {'error': str(e)}
    return result
async def get_default_model():
    """Return the name of the default model.

    The default is the ``api_name`` of the first row of ``api_names`` when
    ordered by ``default_order``.

    Returns:
        The model name, or an empty string when the query yields no usable
        row.
    """
    results = Db(os.getenv("DB_PATH")).list_query("SELECT * FROM api_names order by default_order limit 1")
    try:
        return results[0]['api_name']
    except (IndexError, KeyError, TypeError):
        # No row, missing column, or a non-indexable result: no default.
        return ''
async def get_api_key(model):
    """Select the least recently used enabled LLM API key for a model.

    The chosen key's ``last_call_at`` is set to the current time so that
    later calls rotate through the available keys.

    Args:
        model: Model name to look up in ``api_names.api_name``.

    Returns:
        The first matching row (with ``api_name``, ``api_key``,
        ``base_url`` and ``group_name``), or ``None`` when no key matches.
    """
    # The model name originates from the request body, so single quotes are
    # doubled before it is embedded in the SQL literal.
    # NOTE(review): prefer bound parameters if Db supports them.
    safe_model = str(model).replace("'", "''")
    sql = f"""
    SELECT an.api_name, ak.api_key, an.base_url, ag.group_name
    FROM api_keys ak
    JOIN api_groups ag ON ak.api_group_id = ag.id
    JOIN api_names an ON an.api_group_id = ag.id
    WHERE ak.category='LLM' and an.api_name='{safe_model}' and disabled=0
    ORDER BY ak.last_call_at
    limit 1
    """
    results = Db(os.getenv("DB_PATH")).list_query(sql)
    try:
        result = results[0]
        api_key = result['api_key']
    except (IndexError, KeyError, TypeError):
        # Nothing matched: there is no key to touch and nothing to return.
        return None

    safe_key = str(api_key).replace("'", "''")
    sql = f"update api_keys set last_call_at=datetime('now') where api_key='{safe_key}'"
    Db(os.getenv("DB_PATH")).update_query(sql)
    return result
async def get_user_id(api_key):
    """Look up the id of the user that owns an API key.

    Args:
        api_key: Key presented by the caller.

    Returns:
        The user's ``id``, or ``0`` when no user has this key.
    """
    # The key is supplied by the client, so single quotes are doubled before
    # it is embedded in the SQL literal.
    # NOTE(review): prefer bound parameters if Db supports them.
    safe_key = str(api_key).replace("'", "''")
    sql = f"select id from users where api_key='{safe_key}'"
    results = Db(os.getenv("DB_PATH")).list_query(sql)
    try:
        return results[0]['id']
    except (IndexError, KeyError, TypeError):
        return 0