# api-mapper / routers/openai_v1_1.py
# Author: tanbushi · commit c154b96 ("gemini ok") · 3.57 kB
import uuid
from datetime import datetime
from typing import Dict, List, Optional

import requests
from fastapi import APIRouter, Depends, Request
from fastapi.concurrency import run_in_threadpool
from pydantic import BaseModel, Field

from auth import get_current_user
from db.tbs_db import TbsDb
from db_model.user import UserModel
from global_state import get
# Router holding this module's OpenAI-style endpoints (POST /chat/completions).
router = APIRouter()
class Message(BaseModel):
    """One chat turn of the conversation forwarded upstream."""

    # Who produced the turn; passed through as-is.
    role: str = Field(...)
    # Text body of the turn.
    content: str = Field(...)
class GeminiModel(BaseModel):
    """Request body forwarded to models in the ``gemini`` group.

    Acts as a whitelist: the client body is re-serialised through this model,
    so only the fields declared here are forwarded (pydantic ignores undeclared
    fields by default). Sampling options are optional because OpenAI-style
    clients commonly omit them; unset ones are ``None``, so dump with
    ``exclude_none=True`` to avoid sending explicit nulls upstream.
    """

    model: str
    # 0 is a valid (deterministic) temperature, so the lower bound is inclusive.
    temperature: Optional[float] = Field(None, ge=0)
    top_p: Optional[float] = Field(None, gt=0)
    # frequency_penalty / presence_penalty were deliberately left undeclared
    # (previously commented out), so they are not forwarded.
    # NOTE(review): the OpenAI API requires n >= 1; ge=0 kept for compatibility.
    n: Optional[int] = Field(None, ge=0)
    stream: Optional[bool] = None
    messages: List[Message]
# Path of the Cloudflare DB module handed to every TbsDb(...) call in this file.
# Evaluated once at import time, so 'project_root' must already be set in
# global_state when this module is imported.
db_module_filename = f"{get('project_root')}/db/cloudflare.py"
# Upper bound (seconds) for one upstream completion call; generous because LLM
# responses can be slow, but finite so a dead upstream cannot hang a worker.
_UPSTREAM_TIMEOUT_SECONDS = 300


@router.post("/chat/completions")
async def chat_completions(request: Request, current_user: UserModel = Depends(get_current_user)):
    """Proxy an OpenAI-style chat completion request to the configured upstream.

    The requested model (or the database default when none is given) selects an
    API key, base URL and provider group via ``get_api_key``. For the
    ``gemini`` group the body is filtered through ``GeminiModel`` first.
    ``current_user`` is injected by ``get_current_user`` and presumably serves
    only to reject unauthenticated callers.

    Returns:
        The upstream JSON response, or ``{'error': <message>}`` when no key is
        configured for the model or the upstream call/JSON decoding fails.
    """
    data = await request.json()

    model = data.get('model')
    if not model:
        model = await get_default_model()
        # Forward the resolved model as well; otherwise the upstream (and the
        # GeminiModel validation below) receives a missing/empty model.
        data['model'] = model

    api_key_info = await get_api_key(model) or {}
    api_key = api_key_info.get('api_key', '')
    base_url = api_key_info.get('base_url', '')
    group_name = api_key_info.get('group_name', '')
    if not base_url:
        return {'error': f"no API key configured for model '{model}'"}

    if group_name == 'gemini':
        # Whitelist the fields Gemini accepts and drop options left unset.
        data = GeminiModel(**data).model_dump(exclude_none=True)

    headers = {
        'Content-Type': 'application/json',
        'Authorization': f'Bearer {api_key}',
        'User-Agent': 'PostmanRuntime/7.43.0',
    }
    try:
        # requests is blocking; run it in a worker thread so the event loop
        # keeps serving other requests meanwhile.
        response = await run_in_threadpool(
            requests.post,
            url=f"{base_url}/chat/completions",
            headers=headers,
            json=data,
            timeout=_UPSTREAM_TIMEOUT_SECONDS,
        )
        # NOTE(review): a stream=True request presumably gets an event stream
        # back, which .json() cannot parse; streaming is not supported here.
        return response.json()
    except Exception as e:
        print(e)
        # Exception objects are not JSON-serialisable; return the message.
        return {'error': str(e)}
# Fetch the default model from the database.
async def get_default_model():
    """Return the name of the default model, or ``''`` if none can be read.

    The default is the first ``api_names`` row ordered by ``default_order``;
    its ``api_name`` is read from ``response['result'][0]['results'][0]``.
    """
    query = "SELECT * FROM api_names order by default_order limit 1"
    response = TbsDb(db_module_filename, "Cloudflare").get_item(query)
    try:
        return response['result'][0]['results'][0]['api_name']
    except (KeyError, IndexError, TypeError):
        # No rows, or a response that is None / not shaped as expected.
        return ''
def _sql_string_literal(value):
    """Return *value* as a single-quoted SQL string literal.

    Embedded single quotes are doubled (the SQL-standard escape, which SQLite
    uses), so the value cannot terminate the literal and inject SQL.
    """
    return "'" + str(value).replace("'", "''") + "'"


async def get_api_key(model):
    """Pick the least-recently-used enabled LLM key that serves *model*.

    The chosen key's ``last_call_at`` is stamped with the current time, so the
    ``ORDER BY ak.last_call_at`` selection rotates through the available keys.

    Args:
        model: Model name to match against ``api_names.api_name``. It comes
            from the client request body, so it is escaped before use in SQL.

    Returns:
        The selected row (with ``api_name``, ``api_key``, ``base_url`` and
        ``group_name``), or ``{}`` when no matching key is found.
    """
    query = f"""
SELECT an.api_name, ak.api_key, an.base_url, ag.group_name
FROM api_keys ak
JOIN api_groups ag ON ak.api_group_id = ag.id
JOIN api_names an ON an.api_group_id = ag.id
WHERE ak.category='LLM' and an.api_name={_sql_string_literal(model)} and disabled=0
ORDER BY ak.last_call_at
limit 1
"""
    response = TbsDb(db_module_filename, "Cloudflare").get_item(query)
    try:
        result = response['result'][0]['results'][0]
        api_key = result['api_key']
    except (KeyError, IndexError, TypeError):
        # No matching enabled key, or a response not shaped as expected;
        # nothing was selected, so there is no key to stamp.
        return {}

    update_query = (
        "update api_keys set last_call_at=datetime('now') "
        f"where api_key={_sql_string_literal(api_key)}"
    )
    TbsDb(db_module_filename, "Cloudflare").execute_query(update_query)
    return result
def convert_to_openai_format(original_json, completion_id=None):
    """Wrap a single assistant reply in an OpenAI ``chat.completion`` payload.

    Args:
        original_json: Despite the name, an object (not a dict) exposing a
            ``content`` attribute and, optionally, a ``usage_metadata``
            mapping with ``input_tokens`` / ``output_tokens`` /
            ``total_tokens`` keys. If ``usage_metadata`` is missing or
            ``None``, all token counts are reported as 0.
        completion_id: Value for the ``id`` field. When omitted, a unique
            ``chatcmpl-<hex>`` id is generated for this response.

    Returns:
        A dict with ``id``, ``object``, ``created`` (current Unix time in
        seconds), one ``stop`` choice carrying the content, and ``usage``.
    """
    if completion_id is None:
        # A fixed placeholder id would make every response look identical.
        completion_id = f"chatcmpl-{uuid.uuid4().hex}"
    usage = getattr(original_json, "usage_metadata", None) or {}
    return {
        "id": completion_id,
        "object": "chat.completion",
        "created": int(datetime.now().timestamp()),
        "choices": [
            {
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": original_json.content,
                },
                "finish_reason": "stop",
            }
        ],
        "usage": {
            "prompt_tokens": usage.get("input_tokens", 0),
            "completion_tokens": usage.get("output_tokens", 0),
            "total_tokens": usage.get("total_tokens", 0),
        },
    }