from fastapi import APIRouter, Depends, Request
import requests
from datetime import datetime
from pydantic import BaseModel, Field
from typing import List
from global_state import get
from db.tbs_db import TbsDb
from auth import get_current_user
from db_model.user import UserModel

router = APIRouter()


class Message(BaseModel):
    role: str
    content: str
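

# Subset of OpenAI chat-completion parameters accepted by the Gemini-compatible
# upstream. Extra fields in the incoming request (e.g. frequency_penalty,
# presence_penalty) are dropped when the payload is re-dumped through this model.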
class GeminiModel(BaseModel):
    model: str
    temperature: float = Field(..., gt=0)
    top_p: float = Field(..., gt=0)
    # frequency_penalty: float = Field(..., ge=0)
    # presence_penalty: float = Field(..., ge=0)
    n: int = Field(..., ge=0)
    stream: bool
    messages: List[Message]
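

# Location of the Cloudflare DB access module that TbsDb loads at runtime.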
db_module_filename = f"{get('project_root')}/db/cloudflare.py"
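

# OpenAI-compatible chat proxy: resolves an API key for the requested model and
# forwards the request body to the upstream /chat/completions endpoint.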
@router.post("/chat/completions")  # route path is an assumption; it mirrors the upstream endpoint
async def chat_completions(request: Request, current_user: UserModel = Depends(get_current_user)):
    data = await request.json()
    model = data.get('model', '')
    if not model:
        # Fall back to the configured default model when none was supplied.
        model = await get_default_model()
    api_key_info = await get_api_key(model)
    api_key = api_key_info.get('api_key', '')
    base_url = api_key_info.get('base_url', '')
    group_name = api_key_info.get('group_name', '')
    if group_name == 'gemini':
        # Gemini-compatible endpoints accept only a subset of the OpenAI
        # parameters, so filter the payload through GeminiModel first.
        data = GeminiModel(**data).model_dump()
    headers = {
        'Content-Type': 'application/json',
        'Authorization': f'Bearer {api_key}',
        'User-Agent': 'PostmanRuntime/7.43.0'
    }
    try:
        response = requests.post(url=f"{base_url}/chat/completions", headers=headers, json=data)
        return response.json()
    except Exception as e:
        print(e)
        return {'error': str(e)}


# Fetch the default model (lowest default_order) from the database.
async def get_default_model():
    query = "SELECT * FROM api_names order by default_order limit 1"
    response = TbsDb(db_module_filename, "Cloudflare").get_item(query)
    try:
        result = response['result'][0]['results'][0]['api_name']
    except (KeyError, IndexError, TypeError):
        result = ''
    return result
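

# Resolve the least-recently-used, enabled API key for the requested model and
# stamp its last_call_at so keys rotate across calls.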
async def get_api_key(model):
    query = f"""
        SELECT an.api_name, ak.api_key, an.base_url, ag.group_name
        FROM api_keys ak
        JOIN api_groups ag ON ak.api_group_id = ag.id
        JOIN api_names an ON an.api_group_id = ag.id
        WHERE ak.category='LLM' and an.api_name='{model}' and disabled=0
        ORDER BY ak.last_call_at
        limit 1
    """
    response = TbsDb(db_module_filename, "Cloudflare").get_item(query)
    result = {}
    api_key = ''
    try:
        result = response['result'][0]['results'][0]
        api_key = result['api_key']
    except (KeyError, IndexError, TypeError):
        pass
    if api_key:
        # Record the call time so the next lookup rotates to another key.
        query = f"update api_keys set last_call_at=datetime('now') where api_key='{api_key}'"
        TbsDb(db_module_filename, "Cloudflare").execute_query(query)
    return result
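

# Map a response object that exposes `.content` and `.usage_metadata` (for
# example, a LangChain-style message) onto the OpenAI chat.completion shape.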
def convert_to_openai_format(original_json):
    # Build a new response object in the OpenAI format.
    new_json = {
        "id": "chatcmpl-123",  # a unique ID could be generated here, or the incoming id reused
        "object": "chat.completion",
        "created": int(datetime.now().timestamp()),  # current timestamp
        "choices": [
            {
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": original_json.content  # reuse the original content
                },
                "finish_reason": "stop"
            }
        ],
        "usage": {
            "prompt_tokens": original_json.usage_metadata.get("input_tokens", 0),
            "completion_tokens": original_json.usage_metadata.get("output_tokens", 0),
            "total_tokens": original_json.usage_metadata.get("total_tokens", 0)
        }
    }
    return new_json
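

# Illustrative usage sketch (module path and URL prefix are assumptions, not
# taken from this file):
#
#   from fastapi import FastAPI
#   import chat_proxy  # hypothetical module name for this file
#
#   app = FastAPI()
#   app.include_router(chat_proxy.router, prefix="/v1")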