"""Router exposing a ``/chat/completions`` endpoint that forwards chat messages
to a LangChain chat model selected through the api-key database (``DB_PATH``)."""
import json
import logging
import os

from dotenv import load_dotenv
from fastapi import APIRouter, HTTPException, Request
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate

from db import Db

logger = logging.getLogger(__name__)

# Load the .env file (provides DB_PATH among others).
load_dotenv()

router = APIRouter()


def _error(code, message):
    """Build the ``{'error': {'code', 'message'}}`` payload this router returns."""
    return {'error': {'code': code, 'message': message}}


@router.post("/chat/completions")
async def chat_completions(request: Request):
    """Forward a chat request to the LLM backend configured for the requested model.

    The caller authenticates with ``Authorization: <scheme> <api_key>``; the
    second whitespace-separated token is looked up in the ``users`` table.
    The JSON body must contain ``messages`` (a list of objects with ``role``
    and ``content``) and may contain ``model``; when ``model`` is missing or
    empty, the model returned by ``get_default_model`` is used.

    Returns:
        The model reply as a plain string, or an ``{'error': ...}`` dict on
        authentication, validation, key-lookup or upstream LLM failures.
    """
    logger.debug('chat_completions')

    # A missing or malformed Authorization header previously raised
    # KeyError/IndexError (HTTP 500); treat it as an authentication failure.
    header_parts = request.headers.get('authorization', '').split()
    if len(header_parts) < 2:
        return _error('401', 'api_key 已过期或验证不正确!')
    user_api_key = header_parts[1]  # second token, e.g. the value after "Bearer"
    user_id = await get_user_id(user_api_key)
    if user_id == 0:
        return _error('401', 'api_key 已过期或验证不正确!')

    try:
        data = await request.json()
    except ValueError:  # covers json.JSONDecodeError and undecodable bytes
        return _error('400', '请求体不是有效的 JSON')
    if not isinstance(data, dict) or not isinstance(data.get('messages'), list):
        return _error('400', '请求体缺少 messages 列表')

    # Convert to the list of (role, content) tuples LangChain chat models accept.
    # Validated before the key lookup so bad requests do not bump last_call_at.
    try:
        history = [(item['role'], item['content']) for item in data['messages']]
    except (KeyError, TypeError):
        return _error('400', 'messages 中的每一项都需要 role 和 content')

    model = data.get('model') or await get_default_model()

    api_key_info = await get_api_key(model)
    if api_key_info is None:
        return _error('404', f'模型 {model} 没有可用的 api_key')
    llm = _build_llm(api_key_info, model)

    # Messages go straight to the model rather than through ChatPromptTemplate:
    # the template parsed "{...}" in user content as template variables, so any
    # message containing braces (code, JSON) failed with a missing-variable error.
    chain = llm | StrOutputParser()
    try:
        # Async API so the upstream call does not block the event loop.
        return await chain.ainvoke(history)
    except Exception as e:  # upstream/provider failure: report it to the caller
        logger.exception('LLM call failed for model %s', model)
        return {'error': str(e)}


def _build_llm(api_key_info, model):
    """Instantiate the LangChain chat model for the key's provider group.

    ``group_name == 'gemini'`` selects Google Generative AI; every other group
    is treated as an OpenAI-compatible endpoint at the row's ``base_url``.
    """
    api_key = api_key_info['api_key']
    if api_key_info['group_name'] == 'gemini':
        # Google Gemini API (imported lazily, only when needed).
        from langchain_google_genai import ChatGoogleGenerativeAI
        return ChatGoogleGenerativeAI(api_key=api_key, model=model)

    # ChatGPT-compatible API. Previously this ran even for Gemini keys and
    # overwrote the Gemini client.
    from langchain_openai import ChatOpenAI
    try:
        # Empty/NULL base_url -> None so the client falls back to its default.
        base_url = api_key_info['base_url'] or None
    except (KeyError, IndexError):
        base_url = None
    return ChatOpenAI(model=model, api_key=api_key, base_url=base_url)


def _sql_literal(value):
    """Render *value* as a single-quoted SQL string literal.

    The Db helper is only ever called here with a complete SQL string, so
    untrusted values (the request ``model``, the Authorization api key) are
    escaped by doubling embedded single quotes (the standard SQL string-literal
    escape, which SQLite also uses) instead of being interpolated raw.
    Without this, an api key such as ``' OR '1'='1`` matched a user and
    bypassed authentication.
    NOTE(review): if Db supports bound parameters, prefer those.
    """
    return "'" + str(value).replace("'", "''") + "'"


def _first_row(sql):
    """Run *sql* through ``Db.list_query`` and return its first row, or None if empty."""
    results = Db(os.getenv("DB_PATH")).list_query(sql)
    return results[0] if results else None


async def get_default_model():
    """Return ``api_name`` of the ``api_names`` row with the lowest ``default_order``.

    Returns '' when the table has no rows.
    """
    row = _first_row("SELECT * FROM api_names order by default_order limit 1")
    if row is None:
        return ''
    return row['api_name']


async def get_api_key(model):
    """Pick the least recently called enabled LLM api key for *model* and mark it used.

    Returns:
        The row with ``api_name``, ``api_key``, ``base_url`` and ``group_name``,
        or None when no enabled key is configured for *model* (previously this
        path raised UnboundLocalError).
    """
    sql = f"""
        SELECT an.api_name, ak.api_key, an.base_url, ag.group_name
        FROM api_keys ak
        JOIN api_groups ag ON ak.api_group_id = ag.id
        JOIN api_names an ON an.api_group_id = ag.id
        WHERE ak.category='LLM' and an.api_name={_sql_literal(model)} and disabled=0
        ORDER BY ak.last_call_at
        limit 1
    """
    result = _first_row(sql)
    if result is None:
        return None
    # Stamp the key as just used so the ORDER BY above rotates through keys.
    update_sql = (
        "update api_keys set last_call_at=datetime('now') "
        f"where api_key={_sql_literal(result['api_key'])}"
    )
    Db(os.getenv("DB_PATH")).update_query(update_sql)
    return result


async def get_user_id(api_key):
    """Return ``users.id`` for the user owning *api_key*, or 0 if none matches."""
    row = _first_row(f"select id from users where api_key={_sql_literal(api_key)}")
    if row is None:
        return 0
    return row['id']