File size: 3,565 Bytes
ac56577
 
 
c154b96
 
ac56577
 
 
 
 
 
 
 
c154b96
 
 
 
 
 
 
 
 
 
 
 
 
 
ac56577
 
 
 
 
c154b96
ac56577
 
 
 
 
c154b96
 
 
ac56577
 
 
3a942e3
 
ac56577
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import asyncio
import uuid
from datetime import datetime
from functools import partial
from typing import List, Dict

import requests
from fastapi import APIRouter, Depends, Request
from pydantic import BaseModel, Field

from global_state import get
from db.tbs_db import TbsDb
from auth import get_current_user
from db_model.user import UserModel

# FastAPI router for this module; endpoints are attached with @router.post(...).
router = APIRouter()

class Message(BaseModel):
    """One chat message: the sender's ``role`` and the message ``content``.

    Both fields are required strings.
    """

    role: str = Field(...)
    content: str = Field(...)

class GeminiModel(BaseModel):
    """Request body forwarded to models in the ``gemini`` API group.

    The chat endpoint re-validates the incoming JSON through this model and
    forwards ``model_dump()``, so only the fields declared here are sent
    upstream; any other request fields are dropped (pydantic ignores extra
    fields by default).
    """
    model: str
    # 0 is a valid sampling temperature (deterministic output), so allow it.
    temperature: float = Field(..., ge=0)
    top_p: float = Field(..., gt=0)
    # NOTE(review): frequency_penalty / presence_penalty are deliberately not
    # declared, so they are stripped from forwarded requests — presumably
    # unsupported upstream; confirm before adding them.
    n: int = Field(..., ge=0)
    stream: bool
    messages: List[Message]

# Path of db/cloudflare.py under the configured project root; passed together
# with the "Cloudflare" name to every TbsDb(...) instance in this module.
db_module_filename = "{}/db/cloudflare.py".format(get('project_root'))

@router.post("/chat/completions")
async def chat_completions(request: Request, current_user: UserModel = Depends(get_current_user)):
    """Proxy an OpenAI-style chat completion request to the configured upstream.

    The target model is the request body's ``model`` field, or the result of
    ``get_default_model()`` when that field is missing or empty. The API key,
    base URL and group for the model come from ``get_api_key()``. For the
    ``gemini`` group the body is re-validated through ``GeminiModel`` before
    being forwarded.

    Args:
        request: Incoming request; its JSON body is forwarded upstream.
        current_user: Authenticated user; enforces authentication via the
            dependency and is not otherwise used here.

    Returns:
        The upstream JSON response, or ``{'error': <message>}`` when no key is
        configured for the model or the upstream call / JSON parse fails.
    """
    data = await request.json()
    model = data.get('model')
    if not model:
        model = await get_default_model()
        # Forward the resolved model: otherwise upstream (and GeminiModel,
        # where ``model`` is required) would see a missing/empty model.
        data['model'] = model

    api_key_info = await get_api_key(model) or {}
    api_key = api_key_info.get('api_key', '')
    base_url = api_key_info.get('base_url', '')
    group_name = api_key_info.get('group_name', '')
    if not base_url:
        return {'error': f"No API key configured for model '{model}'"}
    if group_name == 'gemini':
        data = GeminiModel(**data).model_dump()

    headers = {
        'Content-Type': 'application/json',
        'Authorization': f'Bearer {api_key}',
        'User-Agent': 'PostmanRuntime/7.43.0'
    }

    try:
        # requests is synchronous: run it in the default thread pool so the
        # event loop keeps serving other requests while waiting on upstream.
        # (connect, read) timeout keeps a hung upstream from blocking forever;
        # the read limit is generous because completions can be slow.
        loop = asyncio.get_running_loop()
        response = await loop.run_in_executor(
            None,
            partial(
                requests.post,
                url=f"{base_url}/chat/completions",
                headers=headers,
                json=data,
                timeout=(10, 600),
            ),
        )
        return response.json()
    except Exception as e:
        print(e)
        # Exception objects are not JSON-serialisable; return the message text.
        return {'error': str(e)}

async def get_default_model():
    """Fetch the default model name from the database.

    Takes the first row of ``api_names`` ordered by ``default_order``.

    Returns:
        That row's ``api_name``, or ``''`` when the table is empty or the
        database response does not have the expected shape.
    """
    query = "SELECT * FROM api_names order by default_order limit 1"
    response = TbsDb(db_module_filename, "Cloudflare").get_item(query)
    try:
        return response['result'][0]['results'][0]['api_name']
    except (KeyError, IndexError, TypeError):
        # Empty result set or unexpected response shape.
        return ''

def _sql_literal(value):
    """Render *value* as a single-quoted SQLite string literal.

    Embedded single quotes are doubled (SQLite's escaping rule for string
    literals), so the value cannot terminate the literal and inject SQL.
    """
    return "'" + str(value).replace("'", "''") + "'"


async def get_api_key(model):
    """Pick an enabled LLM API key for *model* and mark it as just used.

    Selects the matching key with the oldest ``last_call_at`` and then stamps
    that key with the current time, so repeated calls rotate through the
    group's keys (least recently used first).

    Args:
        model: Model name matched against ``api_names.api_name``. It comes
            from the client request, so it is escaped before being embedded
            in SQL.

    Returns:
        The selected row (``api_name``, ``api_key``, ``base_url``,
        ``group_name``), or ``{}`` when no enabled key matches.
    """
    query = f"""
SELECT an.api_name, ak.api_key, an.base_url, ag.group_name
FROM api_keys ak
JOIN api_groups ag ON ak.api_group_id = ag.id
JOIN api_names an ON an.api_group_id = ag.id
WHERE ak.category='LLM' and an.api_name={_sql_literal(model)} and disabled=0
ORDER BY ak.last_call_at
limit 1
"""
    response = TbsDb(db_module_filename, "Cloudflare").get_item(query)
    try:
        result = response['result'][0]['results'][0]
    except (KeyError, IndexError, TypeError):
        # No enabled key for this model, or an unexpected response shape.
        return {}

    try:
        api_key = result['api_key']
    except (KeyError, TypeError):
        api_key = ''

    if api_key:
        query = (
            "update api_keys set last_call_at=datetime('now') "
            f"where api_key={_sql_literal(api_key)}"
        )
        TbsDb(db_module_filename, "Cloudflare").execute_query(query)
    return result

def convert_to_openai_format(original_json):
    """Wrap a chat-model reply in an OpenAI ``chat.completion`` response dict.

    Args:
        original_json: Reply object read via attribute access (despite the
            name, not a JSON dict). It must expose ``content`` (the assistant
            text) and may expose ``usage_metadata``: a mapping with
            ``input_tokens`` / ``output_tokens`` / ``total_tokens``, or None.
            NOTE(review): presumably a LangChain ``AIMessage`` — confirm.

    Returns:
        A dict in OpenAI chat-completion shape with a fresh unique ``id``, the
        current Unix timestamp as ``created``, and a single ``assistant``
        choice. Missing token counts are reported as 0.
    """
    # Tolerate replies without usage data (attribute missing or None).
    usage = getattr(original_json, "usage_metadata", None) or {}
    return {
        # OpenAI-style ids are unique per completion.
        "id": f"chatcmpl-{uuid.uuid4().hex}",
        "object": "chat.completion",
        "created": int(datetime.now().timestamp()),
        "choices": [
            {
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": original_json.content,
                },
                "finish_reason": "stop",
            }
        ],
        "usage": {
            "prompt_tokens": usage.get("input_tokens", 0),
            "completion_tokens": usage.get("output_tokens", 0),
            "total_tokens": usage.get("total_tokens", 0),
        },
    }