# Captured from a Hugging Face Space page (Space status: Sleeping).
import json
import os

from dotenv import load_dotenv
from fastapi import APIRouter, HTTPException, Request
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate

# NOTE(review): the database helpers in this module call `Db`, but its import
# is disabled here, so they raise NameError unless `Db` is provided some other
# way -- confirm before relying on them.
# from db import Db

# Load variables from a .env file into the process environment; the database
# helpers in this module read DB_PATH via os.getenv.
load_dotenv()

# APIRouter instance for this module (no route is attached to it in this file
# as shown -- presumably done elsewhere; verify).
router = APIRouter()
def _error_response(code, message):
    """Build the ``{'error': {'code': ..., 'message': ...}}`` payload sent to clients."""
    return {'error': {'code': code, 'message': message}}


def _bearer_token(request):
    """Return the token that follows the scheme in the Authorization header, or ''.

    ``"Bearer sk-abc"`` yields ``"sk-abc"``; a missing header, or one with
    fewer than two whitespace-separated parts, yields ``''``.
    """
    parts = request.headers.get('authorization', '').split()
    return parts[1] if len(parts) >= 2 else ''


def _build_llm(group_name, model, api_key, base_url):
    """Instantiate the LangChain chat model for the provider group of ``model``.

    The ``'gemini'`` group uses Google's native client; every other group is
    treated as an OpenAI-compatible API reachable at ``base_url``.
    """
    # Provider SDKs are imported lazily so only the one actually used is loaded.
    if group_name == 'gemini':
        from langchain_google_genai import ChatGoogleGenerativeAI
        return ChatGoogleGenerativeAI(api_key=api_key, model=model)
    from langchain_openai import ChatOpenAI
    return ChatOpenAI(model=model, api_key=api_key, base_url=base_url)


async def chat_completions(request: Request):
    """Run a chat completion for an authenticated user and return the reply text.

    Expects an ``Authorization: <scheme> <user api key>`` header (e.g.
    ``Bearer sk-...``) and a JSON object body with ``messages`` -- a list of
    ``{"role": ..., "content": ...}`` objects -- and an optional ``model``.
    When ``model`` is missing or empty, the default model from the database
    is used.

    Returns:
        The model's reply as a plain string on success. On failure a dict:
        ``{'error': {'code': ..., 'message': ...}}`` for authentication and
        request problems, or ``{'error': <str>}`` when building or calling the
        model raises. Errors are returned in the response body, not as HTTP
        error statuses.

    NOTE(review): this function carries no ``router`` route decorator in this
    file; presumably it is registered elsewhere -- confirm.
    """
    print('chat_completions')

    # A missing or malformed Authorization header used to raise
    # KeyError/IndexError (HTTP 500); reject it like any other invalid key.
    user_api_key = _bearer_token(request)
    user_id = await get_user_id(user_api_key) if user_api_key else 0
    if user_id == 0:
        return _error_response('401', 'api_key 已过期或验证不正确!')

    try:
        data = await request.json()
    except ValueError:  # includes json.JSONDecodeError / UnicodeDecodeError
        return _error_response('400', '请求体不是有效的 JSON!')
    if not isinstance(data, dict):
        return _error_response('400', '请求体必须是 JSON 对象!')

    model = data.get('model') or ''
    if model == '':
        model = await get_default_model()

    api_key_info = await get_api_key(model)
    # Guard against a lookup that finds no enabled key for this model.
    if not api_key_info:
        return _error_response('404', f'没有可用于模型 {model} 的 api_key!')
    api_key = api_key_info['api_key']
    group_name = api_key_info['group_name']
    try:
        base_url = api_key_info['base_url']
    except (KeyError, IndexError):
        base_url = ''

    # Convert the OpenAI-style message dicts into (role, content) tuples.
    try:
        chat_messages = [(item['role'], item['content']) for item in data.get('messages')]
    except (KeyError, TypeError):
        chat_messages = []
    if not chat_messages:
        return _error_response('400', 'messages 参数缺失或格式不正确!')

    try:
        # Previously the gemini client was built and then unconditionally
        # overwritten by ChatOpenAI; _build_llm picks exactly one.
        llm = _build_llm(group_name, model, api_key, base_url)
        # The tuples go straight to the model rather than through
        # ChatPromptTemplate: a template parses "{...}" inside message text as
        # variables, so any message containing braces failed.
        chain = llm | StrOutputParser()
        # ainvoke keeps the provider call from blocking the event loop.
        return await chain.ainvoke(chat_messages)
    except Exception as e:
        # Deliberate catch-all: provider/config errors are reported to the client.
        return {'error': str(e)}
async def get_default_model():
    """Return the name of the default model, or '' if none is configured.

    The default is the ``api_name`` of the first ``api_names`` row when ordered
    by ``default_order`` ascending.
    """
    results = Db(os.getenv("DB_PATH")).list_query("SELECT * FROM api_names order by default_order limit 1")
    try:
        return results[0]['api_name']
    except (IndexError, KeyError, TypeError):
        # No rows, or a result shape without 'api_name': no default model.
        return ''
async def get_api_key(model):
    """Pick an enabled LLM API key for ``model`` and mark it as just used.

    Considers keys with ``category='LLM'`` and ``disabled=0`` that belong to the
    API group owning ``model``, takes the one with the oldest ``last_call_at``
    and stamps it with the current time, so successive calls rotate through
    the group's keys.

    Args:
        model: Model name matched against ``api_names.api_name``. In this
            module it comes straight from the client's request body.

    Returns:
        The selected row (``api_name``, ``api_key``, ``base_url``,
        ``group_name``), or ``None`` when no usable key exists.
    """
    # ``model`` is client-controlled: escape it as a SQLite string literal by
    # doubling single quotes (the only escape SQLite recognizes inside '...').
    # NOTE(review): prefer bound parameters if Db supports them.
    model_literal = str(model).replace("'", "''")
    sql = f"""
    SELECT an.api_name, ak.api_key, an.base_url, ag.group_name
    FROM api_keys ak
    JOIN api_groups ag ON ak.api_group_id = ag.id
    JOIN api_names an ON an.api_group_id = ag.id
    WHERE ak.category='LLM' and an.api_name='{model_literal}' and disabled=0
    ORDER BY ak.last_call_at
    limit 1
    """
    results = Db(os.getenv("DB_PATH")).list_query(sql)
    try:
        result = results[0]
    except (IndexError, TypeError):
        # No matching key. Previously this ran a no-op update and then
        # crashed with UnboundLocalError on `return result`.
        return None

    key_literal = str(result['api_key']).replace("'", "''")
    Db(os.getenv("DB_PATH")).update_query(
        f"update api_keys set last_call_at=datetime('now') where api_key='{key_literal}'"
    )
    return result
async def get_user_id(api_key):
    """Return the id of the user whose ``api_key`` matches, or 0 if none does.

    In this module ``api_key`` is the token taken from the client's
    Authorization header, so it must be treated as untrusted input.
    """
    if not api_key:
        # Never look up an empty/None key.
        return 0
    # Escape as a SQLite string literal by doubling single quotes. Interpolating
    # the raw value let a token such as "' OR '1'='1" match the first user.
    # NOTE(review): prefer bound parameters if Db supports them.
    key_literal = str(api_key).replace("'", "''")
    results = Db(os.getenv("DB_PATH")).list_query(f"select id from users where api_key='{key_literal}'")
    try:
        return results[0]['id']
    except (IndexError, KeyError, TypeError):
        # No matching user (or an unexpected result shape).
        return 0