import json import logging import os import random import time import uuid import re import socket from concurrent.futures import ThreadPoolExecutor from functools import lru_cache, wraps from typing import Dict, Any, Callable, List, Tuple import requests import tiktoken from flask import Flask, Response, jsonify, request, stream_with_context from flask_cors import CORS from requests.adapters import HTTPAdapter from urllib3.util.connection import create_connection import urllib3 from cachetools import TTLCache import threading from time import sleep from datetime import datetime, timedelta # 新增导入 import register_bot # Constants CHAT_COMPLETION_CHUNK = 'chat.completion.chunk' CHAT_COMPLETION = 'chat.completion' CONTENT_TYPE_EVENT_STREAM = 'text/event-stream' _BASE_URL = "https://chat.notdiamond.ai" _API_BASE_URL = "https://spuckhogycrxcbomznwo.supabase.co" _USER_AGENT = 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/128.0.0.0 Safari/537.36' # 从环境变量获取API密钥和特定URL API_KEY = os.getenv('API_KEY') _PASTE_API_URL = os.getenv('PASTE_API_URL') _PASTE_API_PASSWORD = os.getenv('PASTE_API_PASSWORD') if not API_KEY: raise ValueError("API_KEY environment variable must be set") if not _PASTE_API_URL: raise ValueError("PASTE_API_URL environment variable must be set") app = Flask(__name__) logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) CORS(app, resources={r"/*": {"origins": "*"}}) executor = ThreadPoolExecutor(max_workers=10) proxy_url = os.getenv('PROXY_URL') NOTDIAMOND_IP = os.getenv('NOTDIAMOND_IP') NOTDIAMOND_DOMAIN = os.getenv('NOTDIAMOND_DOMAIN') if not NOTDIAMOND_IP: logger.error("NOTDIAMOND_IP environment variable is not set!") raise ValueError("NOTDIAMOND_IP must be set") # API密钥验证装饰器 def require_api_key(f): @wraps(f) def decorated_function(*args, **kwargs): auth_header = request.headers.get('Authorization') if not auth_header: return jsonify({'error': 'No API key provided'}), 401 try: # 从 Bearer token 中提取API密钥 provided_key = auth_header.split('Bearer ')[-1].strip() if provided_key != API_KEY: return jsonify({'error': 'Invalid API key'}), 401 except Exception: return jsonify({'error': 'Invalid Authorization header format'}), 401 return f(*args, **kwargs) return decorated_function refresh_token_cache = TTLCache(maxsize=1000, ttl=3600) headers_cache = TTLCache(maxsize=1, ttl=3600) # 1小时过期 token_refresh_lock = threading.Lock() # 自定义连接函数 def patched_create_connection(address, *args, **kwargs): host, port = address if host == NOTDIAMOND_DOMAIN: logger.info(f"Connecting to {NOTDIAMOND_DOMAIN} using IP: {NOTDIAMOND_IP}") return create_connection((NOTDIAMOND_IP, port), *args, **kwargs) return create_connection(address, *args, **kwargs) # 替换 urllib3 的默认连接函数 urllib3.util.connection.create_connection = patched_create_connection # 自定义 HTTPAdapter class CustomHTTPAdapter(HTTPAdapter): def init_poolmanager(self, *args, **kwargs): kwargs['socket_options'] = kwargs.get('socket_options', []) kwargs['socket_options'] += [(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)] return super(CustomHTTPAdapter, self).init_poolmanager(*args, **kwargs) # 创建自定义的 Session def create_custom_session(): session = requests.Session() adapter = CustomHTTPAdapter() session.mount('https://', adapter) session.mount('http://', adapter) return session # 添加速率限制相关的常量 AUTH_RETRY_DELAY = 60 # 认证重试延迟(秒) AUTH_BACKOFF_FACTOR = 2 # 退避因子 AUTH_MAX_RETRIES = 3 # 最大重试次数 AUTH_CHECK_INTERVAL = 300 # 健康检查间隔(秒) AUTH_RATE_LIMIT_WINDOW = 3600 # 速率限制窗口(秒) AUTH_MAX_REQUESTS = 100 # 每个窗口最大请求数 class AuthManager: def __init__(self, email: str, password: str): self._email: str = email self._password: str = password self._max_retries: int = 3 self._retry_delay: int = 1 self._api_key: str = "" self._user_info: Dict[str, Any] = {} self._refresh_token: str = "" self._access_token: str = "" self._token_expiry: float = 0 self._session: requests.Session = create_custom_session() self._logger: logging.Logger = logging.getLogger(__name__) self.model_status = {model: True for model in MODEL_INFO.keys()} # 添加新的属性来跟踪认证请求 self._last_auth_attempt = 0 self._auth_attempts = 0 self._auth_window_start = time.time() self._backoff_delay = AUTH_RETRY_DELAY def _should_attempt_auth(self) -> bool: """检查是否应该尝试认证请求""" current_time = time.time() # 检查是否在退避期内 if current_time - self._last_auth_attempt < self._backoff_delay: return False # 检查速率限制窗口 if current_time - self._auth_window_start > AUTH_RATE_LIMIT_WINDOW: # 重置窗口 self._auth_window_start = current_time self._auth_attempts = 0 self._backoff_delay = AUTH_RETRY_DELAY # 检查请求数量 if self._auth_attempts >= AUTH_MAX_REQUESTS: return False return True def login(self) -> bool: """改进的登录方法,包含速率限制和退避机制""" if not self._should_attempt_auth(): logger.warning(f"Rate limit reached for {self._email}, waiting {self._backoff_delay}s") return False try: self._last_auth_attempt = time.time() self._auth_attempts += 1 url = f"{_API_BASE_URL}/auth/v1/token?grant_type=password" headers = self._get_headers(with_content_type=True) data = { "email": self._email, "password": self._password, "gotrue_meta_security": {} } response = self._make_request('POST', url, headers=headers, json=data) if response.status_code == 429: self._backoff_delay *= AUTH_BACKOFF_FACTOR logger.warning(f"Rate limit hit, increasing backoff to {self._backoff_delay}s") return False response.raise_for_status() self._user_info = response.json() self._refresh_token = self._user_info.get('refresh_token', '') self._access_token = self._user_info.get('access_token', '') self._token_expiry = time.time() + self._user_info.get('expires_in', 3600) # 重置退避延迟 self._backoff_delay = AUTH_RETRY_DELAY self._log_values() return True except requests.RequestException as e: logger.error(f"\033[91m登录请求错误: {e}\033[0m") self._backoff_delay *= AUTH_BACKOFF_FACTOR return False def refresh_user_token(self) -> bool: url = f"{_API_BASE_URL}/auth/v1/token?grant_type=refresh_token" headers = self._get_headers(with_content_type=True) data = {"refresh_token": self._refresh_token} try: response = self._make_request('POST', url, headers=headers, json=data) self._user_info = response.json() self._refresh_token = self._user_info.get('refresh_token', '') self._access_token = self._user_info.get('access_token', '') self._token_expiry = time.time() + self._user_info.get('expires_in', 3600) self._log_values() return True except requests.RequestException as e: self._logger.error(f"刷新令牌请求错误: {e}") # 尝试重新登录 if self.login(): return True return False def get_jwt_value(self) -> str: """返回访问令牌。""" return self._access_token def is_token_valid(self) -> bool: """检查当前的访问令牌是否有效。""" return bool(self._access_token) and time.time() < self._token_expiry def ensure_valid_token(self) -> bool: """改进的token验证方法""" if self.is_token_valid(): return True if not self._should_attempt_auth(): return False if self._refresh_token and self.refresh_user_token(): return True return self.login() def clear_auth(self) -> None: """清除当前的授权信息。""" self._user_info = {} self._refresh_token = "" self._access_token = "" self._token_expiry = 0 def _log_values(self) -> None: """记录刷新令牌到日志中。""" self._logger.info(f"\033[92mRefresh Token: {self._refresh_token}\033[0m") self._logger.info(f"\033[92mAccess Token: {self._access_token}\033[0m") def _fetch_apikey(self) -> str: """获取API密钥。""" if self._api_key: return self._api_key try: login_url = f"{_BASE_URL}/login" response = self._make_request('GET', login_url) match = re.search(r'