File size: 3,086 Bytes
daa9d8a
f5cf708
daa9d8a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2c7fa70
 
 
 
daa9d8a
 
 
 
 
2c7fa70
daa9d8a
 
 
2c7fa70
daa9d8a
 
2c7fa70
 
 
 
daa9d8a
 
 
 
 
 
 
 
 
 
 
2c7fa70
 
 
 
 
 
 
daa9d8a
 
 
 
 
 
 
2c7fa70
 
 
 
daa9d8a
 
 
 
 
 
 
 
 
 
 
 
2c7fa70
daa9d8a
 
 
 
 
 
 
 
2c7fa70
daa9d8a
 
 
 
 
 
 
 
 
2c7fa70
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
from fastapi import APIRouter, Request, HTTPException
# from db import Db
from dotenv import load_dotenv
import os, json
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate

# Load variables from a .env file into the process environment
# (the DB helpers in this module read DB_PATH via os.getenv).
load_dotenv()

# Route table for this module; the endpoint below registers itself on it.
# NOTE(review): presumably mounted by the app via include_router — confirm.
router: APIRouter = APIRouter()

@router.post("/chat/completions")
async def chat_completions(request:Request):
    """Serve an OpenAI-style ``/chat/completions`` request through LangChain.

    Flow:
    1. Authenticate the caller by the token in the ``Authorization`` header.
    2. Pick the requested model, or the default one.
    3. Fetch an upstream API key for that model.
    4. Build the matching chat model and run the conversation.

    Args:
        request: Incoming request. The JSON body is read for ``model``
            (optional) and ``messages`` (a list of ``{"role", "content"}``).

    Returns:
        The assistant reply as a string on success. Failures are reported in
        the body as a dict with an ``'error'`` key (the HTTP status stays 200).
    """
    print('chat_completions')
    user_api_key = _extract_bearer_token(request.headers.get('authorization'))
    # Never look up an empty token: it could match a user row whose api_key is ''.
    user_id = await get_user_id(user_api_key) if user_api_key else 0
    if user_id==0:
        return {'error': {'code': '401', 'message': 'api_key 已过期或验证不正确!'}}

    try:
        data = await request.json()
    except ValueError:  # json.JSONDecodeError is a ValueError subclass
        data = None
    if not isinstance(data, dict):
        return {'error': {'code': '400', 'message': 'request body must be a JSON object'}}

    # A missing, null or empty "model" falls back to the configured default.
    model = data.get('model') or await get_default_model()
    api_key_info = await get_api_key(model)
    if not api_key_info:
        return {'error': {'code': '404', 'message': f'no available api_key for model: {model}'}}

    try:
        base_url = api_key_info['base_url']
    except (KeyError, IndexError):  # IndexError: sqlite3.Row-style rows
        base_url = None
    # A NULL/empty base_url must become None so the client uses its default
    # endpoint instead of an empty URL.
    llm = _build_llm(api_key_info['group_name'], model, api_key_info['api_key'], base_url or None)

    messages = data.get('messages')
    if not messages:
        return {'error': {'code': '400', 'message': 'messages is required'}}

    try:
        # Hand the (role, content) pairs straight to the model. Routing them
        # through ChatPromptTemplate would treat any "{...}" in the text (code
        # snippets, JSON) as a template variable and fail or substitute it.
        conversation = [(item["role"], item["content"]) for item in messages]
        # ainvoke keeps the upstream network call from blocking the event loop.
        result = await (llm | StrOutputParser()).ainvoke(conversation)
    except Exception as e:
        return {'error': str(e)}
    return result


def _extract_bearer_token(auth_header):
    """Return the token from an ``Authorization: <scheme> <token>`` header, or '' if absent/malformed."""
    parts = (auth_header or '').split()
    return parts[1] if len(parts) >= 2 else ''


def _build_llm(group_name, model, api_key, base_url):
    """Create the LangChain chat model for the provider group of the selected key."""
    if group_name == 'gemini':
        from langchain_google_genai import ChatGoogleGenerativeAI
        return ChatGoogleGenerativeAI(
            api_key = api_key,
            model = model,
        )
    # Every other group is treated as an OpenAI-compatible endpoint.
    from langchain_openai import ChatOpenAI
    return ChatOpenAI(
        model = model,
        api_key = api_key,
        base_url = base_url,
    )

async def get_default_model():
    """Return the name of the default model, or '' if none is configured.

    The default is the ``api_name`` of the first row of ``api_names`` when
    ordered by ``default_order``.
    """
    results = Db(os.getenv("DB_PATH")).list_query("SELECT * FROM api_names order by default_order limit 1")
    try:
        return results[0]['api_name']
    except (IndexError, KeyError, TypeError):
        # No rows (or an unexpected result shape): there is no default model.
        return ''

async def get_api_key(model):
    """Pick the least-recently-used enabled LLM key that serves ``model``.

    The chosen key's ``last_call_at`` is bumped to now, so the next request
    rotates to another key (ORDER BY last_call_at).

    Args:
        model: Model name to match against ``api_names.api_name``. It may come
            from the client request body, so it is treated as untrusted.

    Returns:
        The selected row (columns ``api_name``, ``api_key``, ``base_url``,
        ``group_name``), or ``{}`` if ``model`` is empty or no key matches.
    """
    if not model:
        return {}
    # Escape for a SQL string literal by doubling single quotes so a crafted
    # model name cannot inject SQL. Assumes the backend is SQLite (doubling is
    # its only string escape) — TODO confirm; prefer bound parameters if
    # Db.list_query supports them.
    safe_model = str(model).replace("'", "''")
    sql = f"""
SELECT an.api_name, ak.api_key, an.base_url, ag.group_name
FROM api_keys ak
JOIN api_groups ag ON ak.api_group_id = ag.id
JOIN api_names an ON an.api_group_id = ag.id
WHERE ak.category='LLM' and an.api_name='{safe_model}' and disabled=0
ORDER BY ak.last_call_at
limit 1
"""
    results = Db(os.getenv("DB_PATH")).list_query(sql)
    if not results:
        # Nothing matched: return an empty result instead of touching any key.
        return {}
    result = results[0]

    safe_key = str(result['api_key']).replace("'", "''")
    sql = f"update api_keys set last_call_at=datetime('now') where api_key='{safe_key}'"
    Db(os.getenv("DB_PATH")).update_query(sql)
    return result

async def get_user_id(api_key):
    """Return the id of the user owning ``api_key``, or 0 if there is none.

    Args:
        api_key: Key presented by the client (taken from its Authorization
            header), so it is treated as untrusted input.

    Returns:
        The matching ``users.id``, or 0 when the key is empty or unknown.
    """
    if not api_key:
        # An empty key must never authenticate (it could match rows with api_key='').
        return 0
    # Escape for a SQL string literal by doubling single quotes so a crafted
    # key cannot inject SQL. Assumes the backend is SQLite (doubling is its
    # only string escape) — TODO confirm; prefer bound parameters if
    # Db.list_query supports them.
    safe_key = str(api_key).replace("'", "''")
    sql = f"select id from users where api_key='{safe_key}'"
    results = Db(os.getenv("DB_PATH")).list_query(sql)
    try:
        return results[0]['id']
    except (IndexError, KeyError, TypeError):
        # No matching user (or an unexpected result shape).
        return 0