import logging
import uuid
from datetime import datetime
from typing import Dict, List

import requests
from fastapi import APIRouter, Depends, Request
from fastapi.concurrency import run_in_threadpool
from fastapi.exceptions import RequestValidationError
from pydantic import BaseModel, Field, ValidationError

from global_state import get
from db.tbs_db import TbsDb
from auth import get_current_user
from db_model.user import UserModel

logger = logging.getLogger(__name__)

router = APIRouter()

# (connect, read) timeout in seconds for the upstream LLM call. The read timeout is
# generous because completions can take a long time, but a hung upstream must not
# hold a worker thread forever.
_UPSTREAM_TIMEOUT = (10, 600)


class Message(BaseModel):
    """A single chat message in OpenAI chat-completion format."""

    role: str
    content: str


class GeminiModel(BaseModel):
    """Request schema applied to payloads routed to the 'gemini' API group.

    ``model_dump()`` only emits the fields declared here (pydantic ignores extra
    input fields by default), so any other keys in the client payload are dropped
    before forwarding.
    """

    model: str
    temperature: float = Field(..., gt=0)
    top_p: float = Field(..., gt=0)
    # NOTE(review): penalties are deliberately not declared, so they are stripped
    # from forwarded payloads — presumably because the upstream rejects them; confirm.
    # frequency_penalty: float = Field(..., ge=0)
    # presence_penalty: float = Field(..., ge=0)
    n: int = Field(..., ge=0)
    stream: bool
    messages: List[Message]


db_module_filename = f"{get('project_root')}/db/cloudflare.py"


def _db() -> TbsDb:
    """Return a fresh Cloudflare-backed TbsDb handle (one per query, as before)."""
    return TbsDb(db_module_filename, "Cloudflare")


def _sql_literal(value: object) -> str:
    """Render *value* as a single-quoted SQLite string literal.

    Embedded single quotes are doubled, which is the SQL/SQLite escape for string
    literals (backslash is not an escape character in SQLite). This prevents
    request-supplied text from breaking out of the literal.
    TODO(review): switch to bound parameters if TbsDb supports them.
    """
    return "'" + str(value).replace("'", "''") + "'"


@router.post("/chat/completions")
async def chat_completions(request: Request, current_user: UserModel = Depends(get_current_user)):
    """Forward an OpenAI-style chat-completion request to the upstream provider.

    The model named in the JSON body, or the database default when it is empty
    or missing, selects an API key, base URL and group from the database. The
    payload is then POSTed to ``{base_url}/chat/completions``.

    ``current_user`` is only used for its ``Depends(get_current_user)`` side
    effect (presumably enforcing authentication).

    Returns:
        The upstream JSON body. On failure, returns ``{'error': <message>}``.

    Raises:
        RequestValidationError: the payload does not match ``GeminiModel`` for a
            gemini-group model (FastAPI renders this as HTTP 422).

    Note: the whole upstream body is parsed as JSON, so streaming
    (``"stream": true``) responses are not supported here.
    """
    data = await request.json()
    model = data.get('model', '')
    if not model:
        model = await get_default_model()
        # Send the resolved default upstream too; otherwise the payload names no model.
        data['model'] = model

    api_key_info = await get_api_key(model)
    if not api_key_info:
        return {'error': f"No enabled API key found for model '{model}'"}
    api_key = api_key_info.get('api_key', '')
    base_url = api_key_info.get('base_url', '')
    group_name = api_key_info.get('group_name', '')

    if group_name == 'gemini':
        try:
            data = GeminiModel(**data).model_dump()
        except ValidationError as exc:
            # Surface bad client input as 422 instead of an unhandled 500.
            raise RequestValidationError(exc.errors()) from exc

    headers = {
        'Content-Type': 'application/json',
        'Authorization': f'Bearer {api_key}',
        'User-Agent': 'PostmanRuntime/7.43.0',
    }
    try:
        # requests is synchronous: run it in a worker thread so a slow upstream
        # does not stall the event loop for every other request.
        response = await run_in_threadpool(
            requests.post,
            url=f"{base_url}/chat/completions",
            headers=headers,
            json=data,
            timeout=_UPSTREAM_TIMEOUT,
        )
        return response.json()
    except (requests.RequestException, ValueError) as e:
        # ValueError covers a non-JSON upstream body from response.json().
        logger.exception("Upstream chat completion request failed for model %r", model)
        return {'error': str(e)}


async def get_default_model():
    """Fetch the default model from the database.

    Returns:
        The ``api_name`` of the first ``api_names`` row ordered by
        ``default_order``, or ``''`` when no row is available.
    """
    query = "SELECT * FROM api_names order by default_order limit 1"
    response = _db().get_item(query)
    try:
        return response['result'][0]['results'][0]['api_name']
    except (KeyError, IndexError, TypeError):
        return ''


async def get_api_key(model) -> Dict:
    """Pick an enabled 'LLM' API key for *model* and mark it as just used.

    Keys are ordered by ``last_call_at`` ascending. Stamping the chosen key with
    ``datetime('now')`` therefore rotates through keys, least recently used first.

    Args:
        model: the ``api_names.api_name`` to look up (may come from the client).

    Returns:
        The row with ``api_name``, ``api_key``, ``base_url`` and ``group_name``,
        or ``{}`` when no matching key exists.
    """
    query = f"""
    SELECT an.api_name, ak.api_key, an.base_url, ag.group_name
    FROM api_keys ak
    JOIN api_groups ag ON ak.api_group_id = ag.id
    JOIN api_names an ON an.api_group_id = ag.id
    WHERE ak.category='LLM' and an.api_name={_sql_literal(model)} and disabled=0
    ORDER BY ak.last_call_at limit 1
    """
    response = _db().get_item(query)
    try:
        result = response['result'][0]['results'][0]
    except (KeyError, IndexError, TypeError):
        return {}

    api_key = result.get('api_key') or ''
    if api_key:
        update = f"update api_keys set last_call_at=datetime('now') where api_key={_sql_literal(api_key)}"
        _db().execute_query(update)
    return result


def convert_to_openai_format(original_json):
    """Convert a chat-model reply object into an OpenAI ``chat.completion`` dict.

    Args:
        original_json: an object with a ``.content`` attribute and an optional
            ``.usage_metadata`` mapping with ``input_tokens`` / ``output_tokens`` /
            ``total_tokens`` keys (presumably a LangChain ``AIMessage``; confirm
            with callers).

    Returns:
        A dict shaped like an OpenAI chat-completion response. Token counts
        default to 0 when usage metadata is missing.
    """
    # usage_metadata may be absent or None; treat both as "no usage info".
    usage = getattr(original_json, "usage_metadata", None) or {}
    return {
        # Unique per response, as OpenAI clients expect.
        "id": f"chatcmpl-{uuid.uuid4().hex}",
        "object": "chat.completion",
        "created": int(datetime.now().timestamp()),  # current Unix timestamp
        "choices": [
            {
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": original_json.content,  # original reply text
                },
                "finish_reason": "stop",
            }
        ],
        "usage": {
            "prompt_tokens": usage.get("input_tokens", 0),
            "completion_tokens": usage.get("output_tokens", 0),
            "total_tokens": usage.get("total_tokens", 0),
        },
    }