|
import asyncio
import hmac
import json
import logging
import os
import sys
import time
import uuid
from collections import deque
from dataclasses import dataclass
from enum import Enum
from typing import Any, Dict, List, Optional, Union

import aiohttp
from aiohttp import web
|
|
|
|
|
# Root logging: INFO and above, timestamped, written to stdout.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[logging.StreamHandler(sys.stdout)],
    level=logging.INFO,
)

# Module-level logger shared by every component below.
logger = logging.getLogger(__name__)
|
|
|
class SystemRoleMode(Enum):
    """How ``system`` messages are treated before forwarding upstream."""

    # Merge the leading system messages into one system message.
    KEEP = "keep"
    # Re-label every system message as a user message.
    CONVERT = "convert"
|
|
|
@dataclass()
class TokenInfo:
    """Bookkeeping record for one upstream token."""

    # The raw bearer token string.
    token: str
    # Consecutive failures since the last successful use.
    failed_count: int = 0
    # time.time() of the last checkout (0 = never).
    last_used: float = 0
    # time.time() of the last balance probe (0 = never).
    last_balance_check: float = 0
|
|
|
class ConfigManager:
    """Proxy configuration, read once from environment variables at construction."""

    def __init__(self):
        # Upstream endpoints and the client-facing key(s).
        self.API_KEY = os.getenv('API_KEY', 'sk-123456')
        self.TARGET_URL = os.getenv('TARGET_URL', 'https://miler-kiloai.deno.dev')
        self.BALANCE_CHECK_URL = os.getenv('BALANCE_CHECK_URL', 'https://kilocode.ai/api/profile/balance')

        # Headers sent with every upstream chat request.
        kilo_version = '4.58.0'
        self.TARGET_HEADERS = {
            'Content-Type': 'application/json',
            'User-Agent': f'Kilo-Code/{kilo_version}',
            'Accept': 'application/json',
            'Accept-Encoding': 'br, gzip, deflate',
            'X-Stainless-Retry-Count': '0',
            'X-Stainless-Lang': 'js',
            'X-Stainless-Package-Version': '5.5.1',
            'X-Stainless-OS': 'Windows',
            'X-Stainless-Arch': 'x64',
            'X-Stainless-Runtime': 'node',
            'X-Stainless-Runtime-Version': 'v20.19.0',
            'HTTP-Referer': 'https://kilocode.ai',
            'X-Title': 'Kilo Code',
            'X-KiloCode-Version': kilo_version,
            'accept-language': '*',
            'sec-fetch-mode': 'cors',
        }

        # Headers sent with every balance probe.
        self.BALANCE_CHECK_HEADERS = {
            'User-Agent': 'axios/1.9.0',
            'Connection': 'close',
            'Accept': 'application/json, text/plain, */*',
            'Accept-Encoding': 'gzip, compress, deflate, br',
            'Content-Type': 'application/json',
        }

        # Numeric tuning knobs (a non-numeric value raises ValueError here).
        self.MAX_RETRIES = self._int_from_env('MAX_RETRIES', '3')
        self.MAX_CONCURRENT = self._int_from_env('MAX_CONCURRENT', '10')
        self.PORT = self._int_from_env('PORT', '25526')
        # Raises ValueError unless the value is exactly 'keep' or 'convert'.
        self.SYSTEM_ROLE_MODE = SystemRoleMode(os.getenv('SYSTEM_ROLE_MODE', 'keep'))

        # Public model id (what clients send) -> upstream model id.
        # NOTE(review): 'grok-4-07-09-thingking' is misspelled; because it does
        # not contain "thinking", the request builder will not enable reasoning
        # for it. Kept as-is since clients may already send this exact id.
        self.MODEL_MAPPING = {
            'gemini-2.5-flash': 'google/gemini-2.5-flash',
            'gemini-2.5-flash-thinking': 'google/gemini-2.5-flash',
            'gemini-2.5-pro-thinking': 'google/gemini-2.5-pro',
            'grok-4-07-09-thingking': 'x-ai/grok-4',
            'claude-3-7-sonnet-20250219': 'anthropic/claude-3.7-sonnet',
            'claude-3-7-sonnet-20250219-thinking': 'anthropic/claude-3.7-sonnet',
            'claude-opus-4-20250514': 'anthropic/claude-opus-4',
            'claude-opus-4-20250514-thinking': 'anthropic/claude-opus-4',
            'claude-sonnet-4-20250514': 'anthropic/claude-sonnet-4',
            'claude-sonnet-4-20250514-thinking': 'anthropic/claude-sonnet-4',
        }

        # Token pool and failure / balance-recovery policy.
        self.TOKEN_POOL = self._load_token_pool()
        self.TOKEN_FAILURE_THRESHOLD = self._int_from_env('TOKEN_FAILURE_THRESHOLD', '3')
        self.BALANCE_CHECK_INTERVAL = self._int_from_env('BALANCE_CHECK_INTERVAL', '3600')
        self.MIN_BALANCE_THRESHOLD = float(os.getenv('MIN_BALANCE_THRESHOLD', '1.0'))

    @staticmethod
    def _int_from_env(name: str, default: str) -> int:
        """Return environment variable *name* parsed as int, or *default* parsed as int."""
        return int(os.getenv(name, default))

    def _load_token_pool(self) -> List[str]:
        """Split the comma-separated TOKEN_POOL variable into non-empty, stripped tokens."""
        raw_pool = os.getenv('TOKEN_POOL', '')
        pieces = (piece.strip() for piece in raw_pool.split(','))
        return [piece for piece in pieces if piece]
|
|
|
class MessageProcessor:
    """Normalise OpenAI-style chat ``messages`` before forwarding them upstream.

    Behaviour depends on ``config.SYSTEM_ROLE_MODE``:

    * ``KEEP``: the leading run of ``system`` messages is merged into a single
      system message; any ``system`` message appearing later is re-labelled
      as ``user``.
    * ``CONVERT``: every ``system`` message is re-labelled as ``user``.

    In both modes, string content is wrapped as a single text part. Consecutive
    messages with the same role whose content is exactly one text part are
    merged, joined with a newline, unless either of them is a tool message or
    carries ``tool_calls``. Input dicts are shallow-copied, never mutated.
    """

    def __init__(self, config: ConfigManager):
        self.config = config

    def process_messages(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Return a processed copy of *messages* according to the configured mode."""
        if self.config.SYSTEM_ROLE_MODE == SystemRoleMode.KEEP:
            return self._process_keep_system_mode(messages)
        return self._process_convert_system_mode(messages)

    def _process_keep_system_mode(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """KEEP mode: merge leading system messages, demote later ones to ``user``."""
        if not messages:
            return messages

        index = 0
        system_texts: List[str] = []
        while index < len(messages) and messages[index].get('role') == 'system':
            content = messages[index].get('content', '')
            if content:  # empty system messages contribute nothing
                system_texts.append(self._extract_text_content(content))
            index += 1

        result: List[Dict[str, Any]] = []
        if system_texts:
            result.append({
                'role': 'system',
                'content': [{'type': 'text', 'text': '\n'.join(system_texts)}]
            })

        return self._append_and_merge(result, messages[index:])

    def _process_convert_system_mode(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """CONVERT mode: demote every system message to ``user``, then merge."""
        if not messages:
            return messages
        return self._append_and_merge([], messages)

    def _append_and_merge(self, result: List[Dict[str, Any]],
                          messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Append *messages* to *result*, demoting ``system``, normalising and merging same-role text turns."""
        for original in messages:
            # Shallow copy: we only ever reassign keys, so the caller's dicts stay intact.
            message = original.copy()
            if message.get('role') == 'system':
                message['role'] = 'user'
            message = self._normalize_message_content(message)

            if result and self._can_merge_messages(result[-1], message):
                prev_text = self._extract_text_content(result[-1]['content'])
                curr_text = self._extract_text_content(message['content'])
                result[-1]['content'] = [{'type': 'text', 'text': f"{prev_text}\n{curr_text}"}]
            else:
                result.append(message)
        return result

    def _can_merge_messages(self, previous: Dict[str, Any], current: Dict[str, Any]) -> bool:
        """True when *current* can be folded into *previous* without losing data."""
        for message in (previous, current):
            # Merging would drop tool_calls / tool_call_id of the absorbed message.
            if message.get('role') == 'tool' or message.get('tool_calls') is not None:
                return False
        return (previous.get('role') == current.get('role')
                and self._can_merge_content(previous.get('content'))
                and self._can_merge_content(current.get('content')))

    def _normalize_message_content(self, message: Dict[str, Any]) -> Dict[str, Any]:
        """Wrap scalar content as ``[{'type': 'text', 'text': ...}]`` (in place on *message*).

        Tool messages, messages with ``tool_calls``, list content and ``None``
        content are returned unchanged (``None`` previously became the literal
        text "None").
        """
        if message.get('role') == 'tool' or message.get('tool_calls') is not None:
            return message

        content = message.get('content')
        if content is None or isinstance(content, list):
            return message

        text = content if isinstance(content, str) else str(content)
        message['content'] = [{'type': 'text', 'text': text}]
        return message

    def _can_merge_content(self, content: Any) -> bool:
        """True when *content* is a list holding exactly one text part."""
        return (isinstance(content, list)
                and len(content) == 1
                and isinstance(content[0], dict)
                and content[0].get('type') == 'text')

    def _extract_text_content(self, content: Any) -> str:
        """Return *content* as plain text.

        Strings pass through. A non-empty list made only of text parts yields
        their texts joined by newlines. Anything else falls back to ``str()``.
        """
        if isinstance(content, str):
            return content
        if (isinstance(content, list) and content
                and all(isinstance(part, dict) and part.get('type') == 'text' for part in content)):
            return '\n'.join(str(part.get('text', '')) for part in content)
        return str(content)
|
|
|
class TokenManager:
    """Pool of upstream tokens with exclusive checkout and failure tracking.

    ``get_token`` removes a token from ``available_tokens``; ``return_token``
    puts it back. A token whose consecutive failures reach
    ``TOKEN_FAILURE_THRESHOLD`` is parked in ``failed_tokens``. A background
    task, plus an immediate check whenever the pool runs dry, queries each
    parked token's balance. A token moves back to the pool once its balance
    is above ``MIN_BALANCE_THRESHOLD`` and it is not depleted.
    """

    def __init__(self, config: ConfigManager):
        self.config = config
        self.available_tokens = deque([TokenInfo(token) for token in config.TOKEN_POOL])
        self.failed_tokens = deque()
        # token string -> list of TokenInfo records currently checked out, so the
        # failure counter survives a get_token / return_token round trip.
        self._checked_out = {}
        self.lock = asyncio.Lock()
        self.balance_check_task = None
        self._shutdown_event = asyncio.Event()

    async def start_balance_checker(self):
        """Start the background balance-check task (idempotent)."""
        if self.balance_check_task is None:
            self.balance_check_task = asyncio.create_task(self._balance_check_loop())
            logger.info("余额检测后台任务已启动")

    async def stop_balance_checker(self):
        """Signal the background task to stop and wait up to 5s for it."""
        if self.balance_check_task:
            self._shutdown_event.set()
            try:
                await asyncio.wait_for(self.balance_check_task, timeout=5.0)
            except asyncio.TimeoutError:
                # wait_for already cancels on timeout; cancel() is a harmless no-op then.
                self.balance_check_task.cancel()
            except Exception as e:
                # A crashed task must not prevent the caller's remaining shutdown steps.
                logger.error(f"余额检测后台任务异常退出: {str(e)}")
            self.balance_check_task = None
            logger.info("余额检测后台任务已停止")

    async def get_token(self) -> Optional[str]:
        """Check out a token, or return None when none is available."""
        async with self.lock:
            needs_recovery = not self.available_tokens and bool(self.failed_tokens)

        if needs_recovery:
            # Must run WITHOUT holding self.lock: the balance check acquires it
            # itself and asyncio.Lock is not reentrant (holding it deadlocked).
            await self._immediate_recovery_check()

        async with self.lock:
            if not self.available_tokens:
                return None
            token_info = self.available_tokens.popleft()
            token_info.last_used = time.time()
            self._checked_out.setdefault(token_info.token, []).append(token_info)
            return token_info.token

    async def return_token(self, token: str, success: bool = True):
        """Return a checked-out token, updating its consecutive-failure count.

        Tokens that are not currently checked out (e.g. a duplicate return)
        are ignored so the pool never holds duplicates.
        """
        async with self.lock:
            token_info = self._pop_checked_out(token)
            if token_info is None:
                logger.warning(f"忽略未借出的token归还: {token[:10]}...")
                return

            if success:
                token_info.failed_count = 0
                self.available_tokens.append(token_info)
                return

            token_info.failed_count += 1
            if token_info.failed_count >= self.config.TOKEN_FAILURE_THRESHOLD:
                self.failed_tokens.append(token_info)
                logger.warning(f"Token已移至失败池: {token[:10]}...")
            else:
                self.available_tokens.append(token_info)

    def _pop_checked_out(self, token: str):
        """Remove and return one checked-out record for *token* (caller holds the lock), or None."""
        records = self._checked_out.get(token)
        if not records:
            return None
        token_info = records.pop()
        if not records:
            del self._checked_out[token]
        return token_info

    async def _balance_check_loop(self):
        """Background loop: every BALANCE_CHECK_INTERVAL seconds, probe failed tokens."""
        logger.info(f"开始余额检测循环,检测间隔: {self.config.BALANCE_CHECK_INTERVAL}秒")

        while not self._shutdown_event.is_set():
            try:
                # Sleep for one interval, waking early on shutdown.
                await asyncio.wait_for(
                    self._shutdown_event.wait(),
                    timeout=self.config.BALANCE_CHECK_INTERVAL
                )
                break
            except asyncio.TimeoutError:
                pass

            try:
                await self._check_failed_tokens_balance()
            except Exception as e:
                # One failed sweep must not kill the background task.
                logger.error(f"余额检测循环出错: {str(e)}")

    async def _immediate_recovery_check(self):
        """Run a balance sweep right away because the pool is empty."""
        logger.info("没有可用token,立即执行恢复检测")
        await self._check_failed_tokens_balance()

    async def _check_failed_tokens_balance(self):
        """Probe failed tokens not checked within BALANCE_CHECK_INTERVAL; restore those with balance."""
        if not self.failed_tokens:
            return

        current_time = time.time()
        async with self.lock:
            tokens_to_check = [
                token_info for token_info in self.failed_tokens
                if current_time - token_info.last_balance_check >= self.config.BALANCE_CHECK_INTERVAL
            ]
            # Stamp before the slow network probes so concurrent callers skip these tokens.
            for token_info in tokens_to_check:
                token_info.last_balance_check = current_time

        if not tokens_to_check:
            return

        logger.info(f"开始检测 {len(tokens_to_check)} 个失败token的余额")

        results = await asyncio.gather(
            *(self._check_single_token_balance(token_info) for token_info in tokens_to_check),
            return_exceptions=True
        )

        recovered_tokens = []
        async with self.lock:
            for token_info, result in zip(tokens_to_check, results):
                # BaseException so a cancelled probe is not mistaken for a truthy "recovered".
                if isinstance(result, BaseException):
                    logger.warning(f"Token {token_info.token[:10]}... 余额检测失败: {str(result)}")
                    continue
                if not result:
                    continue
                try:
                    self.failed_tokens.remove(token_info)
                except ValueError:
                    continue  # already removed elsewhere
                token_info.failed_count = 0
                recovered_tokens.append(token_info)
            self.available_tokens.extend(recovered_tokens)

        if recovered_tokens:
            logger.info(f"成功恢复 {len(recovered_tokens)} 个token到可用池")

    async def _check_single_token_balance(self, token_info: TokenInfo) -> bool:
        """Return True when the balance endpoint reports balance > MIN_BALANCE_THRESHOLD and not depleted.

        Any HTTP error, non-200 status or exception yields False.
        """
        try:
            headers = self.config.BALANCE_CHECK_HEADERS.copy()
            headers['Authorization'] = f'Bearer {token_info.token}'

            timeout = aiohttp.ClientTimeout(total=10)

            async with aiohttp.ClientSession(timeout=timeout) as session:
                async with session.get(
                    self.config.BALANCE_CHECK_URL,
                    headers=headers
                ) as response:
                    if response.status == 200:
                        balance_data = await response.json()
                        balance = balance_data.get('balance', 0)
                        is_depleted = balance_data.get('isDepleted', True)

                        logger.info(f"Token {token_info.token[:10]}... 余额检测: balance={balance}, isDepleted={is_depleted}")

                        if balance > self.config.MIN_BALANCE_THRESHOLD and not is_depleted:
                            logger.info(f"Token {token_info.token[:10]}... 余额充足,可以恢复使用")
                            return True
                        logger.info(f"Token {token_info.token[:10]}... 余额不足或已耗尽")
                        return False

                    error_text = await response.text()
                    logger.warning(f"Token {token_info.token[:10]}... 余额检测失败: 状态码={response.status}, 错误={error_text}")
                    return False

        except Exception as e:
            logger.error(f"Token {token_info.token[:10]}... 余额检测异常: {str(e)}")
            return False
|
|
|
class RequestHandler:
    """Handle ``POST /v1/chat/completions``: authenticate, translate and proxy to the Kilo upstream."""

    def __init__(self, config: ConfigManager, message_processor: MessageProcessor, token_manager: TokenManager):
        self.config = config
        self.message_processor = message_processor
        self.token_manager = token_manager
        # Caps concurrently handled requests (held for the whole request, including streaming).
        self.semaphore = asyncio.Semaphore(config.MAX_CONCURRENT)

    @staticmethod
    def _error_response(message: str, error_type: str, status: int) -> web.Response:
        """Build an OpenAI-style JSON error body."""
        return web.json_response(
            {"error": {"message": message, "type": error_type}},
            status=status
        )

    async def handle_chat_completion(self, request: web.Request) -> web.Response:
        """Entry point for chat completions.

        Responses: 401 on a bad API key; 400 on malformed JSON, a non-object
        body, or invalid/unsupported parameters; otherwise whatever
        ``_execute_request`` produces; 500 on unexpected errors.
        """
        async with self.semaphore:
            try:
                if not self._validate_api_key(request):
                    logger.warning("API密钥验证失败")
                    return self._error_response("Invalid API key", "authentication_error", 401)

                try:
                    request_data = await request.json()
                except ValueError:
                    # json.JSONDecodeError and UnicodeDecodeError are both ValueError subclasses.
                    logger.warning("请求体不是有效的JSON")
                    return self._error_response("Invalid JSON body", "invalid_request_error", 400)

                if not isinstance(request_data, dict):
                    return self._error_response("Request body must be a JSON object", "invalid_request_error", 400)

                try:
                    extracted_params = self._extract_openai_params(request_data)
                except ValueError as e:
                    # Client mistakes (missing/unsupported model, bad messages) are 400, not 500.
                    logger.warning(f"请求参数无效: {str(e)}")
                    return self._error_response(str(e), "invalid_request_error", 400)

                logger.info(f"收到请求: 模型={extracted_params['model']}, 消息数={len(extracted_params['messages'])}")

                processed_messages = self.message_processor.process_messages(extracted_params['messages'])
                # Full conversation content: keep out of INFO logs.
                logger.debug(f"处理后的消息: {json.dumps(processed_messages, ensure_ascii=False)}")

                target_request = self._build_target_request(extracted_params, processed_messages)
                logger.debug(f"构建的目标请求体: {json.dumps(target_request, ensure_ascii=False, indent=4)}")

                return await self._execute_request(request, target_request, extracted_params.get('stream', False))

            except Exception as e:
                logger.exception(f"请求处理错误: {str(e)}")
                return self._error_response("Internal server error", "server_error", 500)

    def _validate_api_key(self, request: web.Request) -> bool:
        """Accept ``Authorization: Bearer <key>`` only if <key> exactly equals one configured key.

        ``config.API_KEY`` may list several keys separated by commas. The
        previous ``api_key in self.config.API_KEY`` was a substring test that
        accepted an empty key or any fragment of the real key.
        """
        auth_header = request.headers.get('Authorization', '')
        if not auth_header.startswith('Bearer '):
            return False

        supplied = auth_header[len('Bearer '):].strip()
        if not supplied:
            return False

        allowed_keys = [key.strip() for key in self.config.API_KEY.split(',') if key.strip()]
        supplied_bytes = supplied.encode('utf-8')
        # Constant-time comparison to avoid leaking key prefixes via timing.
        return any(hmac.compare_digest(supplied_bytes, key.encode('utf-8')) for key in allowed_keys)

    def _extract_openai_params(self, request_data: Dict[str, Any]) -> Dict[str, Any]:
        """Pick the supported OpenAI parameters out of *request_data*.

        Raises:
            ValueError: model missing or not in MODEL_MAPPING, or ``messages`` not a list.
        """
        params: Dict[str, Any] = {}
        params['messages'] = request_data.get('messages', [])
        params['model'] = request_data.get('model', None)

        if params['model'] is None:
            raise ValueError("model 不能为空")
        elif not isinstance(params['model'], str) or params['model'] not in self.config.MODEL_MAPPING:
            raise ValueError(f"model {params['model']} 不支持")

        if not isinstance(params['messages'], list):
            raise ValueError("messages 必须是数组")

        optional_params = [
            'stream', 'max_tokens', 'temperature', 'top_p', 'reasoning',
            'include_reasoning', 'stop', 'frequency_penalty', 'presence_penalty',
            'seed', 'repetition_penalty', 'logit_bias', 'tools', 'tool_choice',
            'stream_options'
        ]

        for param in optional_params:
            if param in request_data:
                params[param] = request_data[param]

        return params

    def _build_target_request(self, params: Dict[str, Any], processed_messages: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Build the upstream request body, mapping the model id and enabling reasoning for "thinking" aliases.

        For a public model id containing "thinking": with ``max_tokens`` the
        reasoning budget is half of it; without it, max_tokens=4096 and a
        2048-token reasoning budget are used.
        """
        target_request = {
            'messages': processed_messages,
            'model': self.config.MODEL_MAPPING.get(params['model'], params['model'])
        }

        for key, value in params.items():
            if key not in ['messages', 'model']:
                target_request[key] = value

        if "thinking" in params['model']:
            if "max_tokens" in params:
                target_request['reasoning'] = {'max_tokens': int(params['max_tokens'] / 2)}
            else:
                target_request['max_tokens'] = 4096
                target_request['reasoning'] = {'max_tokens': 2048}
        logger.info(f"目标模型: {target_request['model']}")
        return target_request

    async def _execute_request(self, original_request: web.Request, target_request: Dict[str, Any], is_stream: bool) -> web.Response:
        """Send *target_request* upstream, retrying with a fresh token up to MAX_RETRIES times.

        Every checked-out token is returned to the pool exactly once per
        attempt, including when the request is cancelled.
        """
        for attempt in range(self.config.MAX_RETRIES):
            token = await self.token_manager.get_token()
            if not token:
                logger.error("没有可用的token")
                return self._error_response("No available tokens", "server_error", 503)

            token_returned = False
            try:
                headers = self.config.TARGET_HEADERS.copy()
                headers['authorization'] = f'Bearer {token}'
                headers['X-KiloCode-TaskId'] = str(uuid.uuid4())

                timeout = aiohttp.ClientTimeout(total=3000)

                logger.info(f"尝试第 {attempt + 1} 次请求 Kilo API")

                async with aiohttp.ClientSession(timeout=timeout) as session:
                    async with session.post(
                        self.config.TARGET_URL,
                        json=target_request,
                        headers=headers
                    ) as response:
                        if response.status == 200:
                            await self.token_manager.return_token(token, success=True)
                            token_returned = True
                            logger.info(f"请求成功: 状态码={response.status}, 流式={is_stream}")

                            if is_stream:
                                return await self._handle_stream_response(original_request, response)
                            return await self._handle_non_stream_response(response)

                        await self.token_manager.return_token(token, success=False)
                        token_returned = True
                        error_text = await response.text()
                        logger.error(f"请求失败: 状态码={response.status}, 错误={error_text}")

                        if attempt == self.config.MAX_RETRIES - 1:
                            return self._error_response(error_text, "api_error", response.status)

            except asyncio.CancelledError:
                # Cancelled (e.g. client went away): not the token's fault, so hand it
                # back without counting a failure, then let cancellation propagate.
                if not token_returned:
                    await self.token_manager.return_token(token, success=True)
                raise
            except Exception as e:
                if not token_returned:
                    await self.token_manager.return_token(token, success=False)
                logger.error(f"请求尝试 {attempt + 1} 失败: {str(e)}")

                if attempt == self.config.MAX_RETRIES - 1:
                    return self._error_response("Request failed after retries", "server_error", 500)

        # Only reachable when MAX_RETRIES <= 0.
        return self._error_response("Max retries exceeded", "server_error", 500)

    async def _handle_stream_response(self, original_request: web.Request, response: aiohttp.ClientResponse) -> web.Response:
        """Relay the upstream SSE stream, re-serialising each ``data:`` JSON chunk.

        Non-``data:`` lines and unparsable chunks are dropped. The upstream
        ``[DONE]`` is swallowed, and one ``data: [DONE]`` is written at the end
        unless the client disconnected. Once the response is prepared, errors
        are logged rather than raised, because the caller can no longer send a
        different response.
        """
        stream_response = web.StreamResponse(
            status=200,
            headers={
                'Content-Type': 'text/event-stream',
                'Cache-Control': 'no-cache',
                'Connection': 'keep-alive',
                'Access-Control-Allow-Origin': '*'
            }
        )

        await stream_response.prepare(original_request)
        logger.info("开始处理流式响应")

        client_disconnected = False
        try:
            async for line in response.content:
                logger.debug(f"流式响应: {line}")

                transport = original_request.transport
                if transport is None or transport.is_closing():
                    logger.info("客户端在流式传输期间断开连接")
                    client_disconnected = True
                    break

                if not line:
                    continue

                try:
                    line_str = line.decode('utf-8').strip()
                    if not line_str.startswith('data: '):
                        continue

                    json_str = line_str[6:]
                    if json_str == '[DONE]':
                        continue  # a single [DONE] is emitted in the finally block

                    openai_chunk = json.loads(json_str)
                    sse_line = f"data: {json.dumps(openai_chunk, ensure_ascii=False)}\n\n"
                    await stream_response.write(sse_line.encode('utf-8'))

                except json.JSONDecodeError:
                    continue
                except Exception as e:
                    logger.warning(f"处理流式数据时出错: {str(e)}")
                    continue

        except Exception as e:
            logger.error(f"流式响应错误: {str(e)}")
        finally:
            if not client_disconnected:
                try:
                    await stream_response.write(b"data: [DONE]\n\n")
                except Exception as e:
                    # Writing to a vanished client raises; never let that escape after prepare().
                    logger.info(f"写入流结束标记失败: {str(e)}")
            logger.info("流式响应处理完成")

        return stream_response

    async def _handle_non_stream_response(self, response: aiohttp.ClientResponse) -> web.Response:
        """Forward the upstream JSON body unchanged (parsed regardless of Content-Type)."""
        logger.info("开始处理非流式响应")
        response_data = await response.json(content_type=None)
        logger.info("非流式响应处理完成")
        return web.json_response(response_data)

    def _convert_chunk_to_openai_format(self, kilo_chunk: Dict[str, Any]) -> Dict[str, Any]:
        """Convert an upstream streaming chunk into an OpenAI ``chat.completion.chunk`` dict.

        NOTE(review): not called from anywhere in this module.
        """
        openai_chunk = {
            "id": kilo_chunk.get("id", ""),
            "object": "chat.completion.chunk",
            "created": kilo_chunk.get("created", int(time.time())),
            "model": kilo_chunk.get("model", "gpt-3.5-turbo"),
            "choices": []
        }

        if "choices" in kilo_chunk:
            for choice in kilo_chunk["choices"]:
                openai_choice = {
                    "index": choice.get("index", 0),
                    "delta": {},
                    "finish_reason": choice.get("finish_reason")
                }

                if "delta" in choice:
                    delta = choice["delta"]
                    if "role" in delta:
                        openai_choice["delta"]["role"] = delta["role"]
                    if "content" in delta:
                        openai_choice["delta"]["content"] = delta["content"]

                openai_chunk["choices"].append(openai_choice)

        if "usage" in kilo_chunk:
            openai_chunk["usage"] = kilo_chunk["usage"]

        return openai_chunk

    def _convert_to_openai_format(self, kilo_response: Dict[str, Any]) -> Dict[str, Any]:
        """Convert an upstream non-streaming response into an OpenAI ``chat.completion`` dict.

        NOTE(review): not called from anywhere in this module.
        """
        openai_response = {
            "id": kilo_response.get("id", ""),
            "object": "chat.completion",
            "created": kilo_response.get("created", int(time.time())),
            "model": kilo_response.get("model", "gpt-3.5-turbo"),
            "choices": [],
            "usage": kilo_response.get("usage", {})
        }

        if "choices" in kilo_response:
            for choice in kilo_response["choices"]:
                openai_choice = {
                    "index": choice.get("index", 0),
                    "message": choice.get("message", {}),
                    "finish_reason": choice.get("finish_reason", "stop")
                }
                openai_response["choices"].append(openai_choice)

        return openai_response
|
|
|
class ModelListHandler:
    """Serve the OpenAI-compatible ``GET /v1/models`` listing."""

    def __init__(self, config: ConfigManager):
        self.config = config

    async def handle_models(self, request: web.Request) -> web.Response:
        """List every public model alias in MODEL_MAPPING as an OpenAI model object."""
        created = int(time.time())
        models = [
            {
                "id": alias,
                "object": "model",
                "created": created,
                "owned_by": "kilo-proxy",
                "permission": [],
                "root": alias,
                "parent": None,
            }
            for alias in self.config.MODEL_MAPPING
        ]

        logger.info(f"返回 {len(models)} 个可用模型")
        payload = {"object": "list", "data": models}
        return web.json_response(payload)
|
|
|
class ProxyServer:
    """Top-level server: wires the components together and runs the aiohttp app."""

    def __init__(self):
        self.config = ConfigManager()
        self.message_processor = MessageProcessor(self.config)
        self.token_manager = TokenManager(self.config)
        self.request_handler = RequestHandler(self.config, self.message_processor, self.token_manager)
        self.model_handler = ModelListHandler(self.config)
        self.app = self._create_app()

    def _create_app(self) -> web.Application:
        """Build the aiohttp application with its routes and CORS middleware."""
        app = web.Application()

        app.router.add_post('/v1/chat/completions', self.request_handler.handle_chat_completion)
        app.router.add_get('/v1/models', self.model_handler.handle_models)

        app.middlewares.append(self._cors_middleware)

        return app

    @staticmethod
    @web.middleware
    async def _cors_middleware(request: web.Request, handler) -> web.StreamResponse:
        """New-style CORS middleware (replaces the deprecated old-style factory).

        Answers OPTIONS preflights directly. On other responses, and on HTTP
        exceptions, it adds ``Access-Control-Allow-Origin: *``.
        """
        if request.method == 'OPTIONS':
            return web.Response(
                headers={
                    'Access-Control-Allow-Origin': '*',
                    'Access-Control-Allow-Methods': 'GET, POST, OPTIONS',
                    'Access-Control-Allow-Headers': 'Content-Type, Authorization'
                }
            )

        try:
            response = await handler(request)
        except web.HTTPException as exc:
            exc.headers['Access-Control-Allow-Origin'] = '*'
            raise

        # A streamed response has already sent its headers (and sets this header itself).
        if not response.prepared:
            response.headers['Access-Control-Allow-Origin'] = '*'
        return response

    async def start(self):
        """Bind the HTTP site, start the balance checker, and serve until cancelled.

        Shutdown always stops the balance checker and then cleans up the
        runner, even if stopping the checker fails.
        """
        runner = web.AppRunner(self.app)
        await runner.setup()

        try:
            site = web.TCPSite(runner, '0.0.0.0', self.config.PORT)
            await site.start()

            logger.info(f"Kilo代理服务器已启动 http://127.0.0.1:{self.config.PORT}")
            logger.info(f"系统角色模式: {self.config.SYSTEM_ROLE_MODE.value}")
            logger.info(f"可用token数量: {len(self.config.TOKEN_POOL)}")
            logger.info(f"余额检测间隔: {self.config.BALANCE_CHECK_INTERVAL}秒")
            logger.info(f"最小余额阈值: {self.config.MIN_BALANCE_THRESHOLD}")
            logger.info(f"模型映射: {self.config.MODEL_MAPPING}")
            logger.info(f"目标URL: {self.config.TARGET_URL}")
            logger.info(f"余额检测URL: {self.config.BALANCE_CHECK_URL}")

            await self.token_manager.start_balance_checker()

            while True:
                await asyncio.sleep(3600)
        except (KeyboardInterrupt, asyncio.CancelledError):
            # asyncio.run() delivers Ctrl-C as cancellation of this task (3.11+).
            logger.info("正在关闭服务器...")
            raise
        finally:
            try:
                await self.token_manager.stop_balance_checker()
            finally:
                await runner.cleanup()
|
|
|
def main():
    """Validate required environment variables, then run the proxy until interrupted.

    Exits with status 1 when ``API_KEY`` or ``TOKEN_POOL`` is unset/empty, or
    when building or running the server raises an unexpected error.
    """
    required_env_vars = ['API_KEY', 'TOKEN_POOL']
    missing_vars = [var for var in required_env_vars if not os.getenv(var)]

    if missing_vars:
        logger.error(f"缺少必需的环境变量: {missing_vars}")
        logger.error("请设置以下环境变量:")
        logger.error("- API_KEY: 您的API密钥(多个用逗号分隔)")
        logger.error("- TOKEN_POOL: Kilo token池(多个token用逗号分隔)")
        sys.exit(1)

    async def _serve() -> None:
        # Build the server inside the running event loop. On Python < 3.10,
        # asyncio.Lock/Event/Semaphore bind to the loop current at construction,
        # and asyncio.run() creates a new one, so building them earlier breaks them.
        server = ProxyServer()
        await server.start()

    try:
        logger.info("正在启动Kilo代理服务器...")
        asyncio.run(_serve())
    except KeyboardInterrupt:
        logger.info("服务器已被用户停止")
    except Exception as e:
        logger.error(f"服务器错误: {str(e)}")
        sys.exit(1)


if __name__ == '__main__':
    main()