# --- Early imports (needed before the HF-Space database bootstrap below) ---
# NOTE(review): several names imported here (RepoFile, AudioSegment, silence,
# or_, np, wave) are unused in this chunk — presumably used further down the
# file; confirm before removing any of them.
import os
from huggingface_hub import HfApi, hf_hub_download
from apscheduler.schedulers.background import BackgroundScheduler
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
import threading  # Added for locking
from huggingface_hub.hf_api import RepoFile
from pydub import AudioSegment, silence
from sqlalchemy import or_  # Added for vote counting query
import hashlib
import numpy as np
import wave

# Snapshot of the current year/month, captured once at import time.
year = datetime.now().year
month = datetime.now().month

# Check if running in a Hugging Face Space
IS_SPACES = False
if os.getenv("SPACE_REPO_NAME"):
    print("Running in a Hugging Face Space 🤗")
    IS_SPACES = True

    # Setup database sync for HF Spaces: pull the SQLite DB from the HF
    # dataset on first boot (Spaces have no persistent local storage).
    if not os.path.exists("instance/tts_arena.db"):
        os.makedirs("instance", exist_ok=True)
        try:
            print("Database not found, downloading from HF dataset...")
            hf_hub_download(
                repo_id="kemuriririn/database-arena",
                filename="tts_arena.db",
                repo_type="dataset",
                local_dir="instance",
                token=os.getenv("HF_TOKEN"),
            )
            print("Database downloaded successfully ✅")
        except Exception as e:
            # Best-effort: on failure the app continues with a fresh DB.
            print(f"Error downloading database from HF dataset: {str(e)} ⚠️")

# Flask / app-level imports are deliberately placed AFTER the DB download
# above so the database file exists before SQLAlchemy is configured.
from flask import (
    Flask,
    render_template,
    g,
    request,
    jsonify,
    send_file,
    redirect,
    url_for,
    session,
)
from flask_login import LoginManager, current_user
from models import *
from auth import auth, init_oauth, is_admin
from admin import admin
import os  # NOTE(review): duplicate of the import at the top; harmless.
from dotenv import load_dotenv
from flask_limiter import Limiter
from flask_limiter.util import get_remote_address
import uuid
import tempfile
import shutil
from tts import predict_tts
import random
import json
from datetime import datetime, timedelta  # re-import; adds timedelta
from flask_migrate import Migrate
import requests

# Load environment variables
if not IS_SPACES:
    load_dotenv()  # Only load .env if not running in a Hugging Face Space

app = Flask(__name__)
# NOTE(review): the os.urandom(24) fallback means all sessions are invalidated
# on every process restart unless SECRET_KEY is set — confirm this is intended.
app.config["SECRET_KEY"] = os.getenv("SECRET_KEY", os.urandom(24))
app.config["SQLALCHEMY_DATABASE_URI"] = os.getenv(
    "DATABASE_URI", "sqlite:///tts_arena.db"
)
app.config["SQLALCHEMY_TRACK_MODIFICATIONS"] = False
app.config["SESSION_COOKIE_SECURE"] = True
app.config["SESSION_COOKIE_SAMESITE"] = (
    "None" if IS_SPACES else "Lax"
)  # HF Spaces uses iframes to load the app, so we need to set SAMESITE to None
app.config["PERMANENT_SESSION_LIFETIME"] = timedelta(days=30)  # Set to desired duration

# Force HTTPS when running in HuggingFace Spaces
if IS_SPACES:
    app.config["PREFERRED_URL_SCHEME"] = "https"

# Cloudflare Turnstile settings (bot-protection challenge; disabled unless
# TURNSTILE_ENABLED=true in the environment).
app.config["TURNSTILE_ENABLED"] = (
    os.getenv("TURNSTILE_ENABLED", "False").lower() == "true"
)
app.config["TURNSTILE_SITE_KEY"] = os.getenv("TURNSTILE_SITE_KEY", "")
app.config["TURNSTILE_SECRET_KEY"] = os.getenv("TURNSTILE_SECRET_KEY", "")
app.config["TURNSTILE_VERIFY_URL"] = (
    "https://challenges.cloudflare.com/turnstile/v0/siteverify"
)

migrate = Migrate(app, db)

# Initialize extensions
db.init_app(app)
login_manager = LoginManager()
login_manager.init_app(app)
login_manager.login_view = "auth.login"

# Initialize OAuth
init_oauth(app)

# Configure rate limits.
# NOTE(review): memory:// storage is per-process — limits reset on restart and
# are not shared across workers; confirm acceptable for the deployment.
limiter = Limiter(
    app=app,
    key_func=get_remote_address,
    default_limits=["2000 per day", "50 per minute"],
    storage_uri="memory://",
)

# TTS Cache Configuration - Read from environment
TTS_CACHE_SIZE = int(os.getenv("TTS_CACHE_SIZE", "10"))
CACHE_AUDIO_SUBDIR = "cache"
tts_cache = {}  # sentence -> {model_a, model_b, audio_a, audio_b, created_at}
tts_cache_lock = threading.Lock()   # guards tts_cache
preload_cache_lock = threading.Lock()
SMOOTHING_FACTOR_MODEL_SELECTION = 500  # For weighted random model selection
# Increased max_workers to 8 for concurrent generation/refill
cache_executor = ThreadPoolExecutor(max_workers=8, thread_name_prefix='CacheReplacer')
all_harvard_sentences = []  # Keep the full list available

# Create temp directories for generated/cached audio
TEMP_AUDIO_DIR = os.path.join(tempfile.gettempdir(), "tts_arena_audio")
CACHE_AUDIO_DIR = os.path.join(TEMP_AUDIO_DIR, CACHE_AUDIO_SUBDIR)
os.makedirs(TEMP_AUDIO_DIR, exist_ok=True)
os.makedirs(CACHE_AUDIO_DIR, exist_ok=True)
# Ensure cache subdir exists
# Preload and runtime cache directories (preload = pre-generated audio,
# runtime = audio generated on demand).
PRELOAD_CACHE_DIR = os.path.join(CACHE_AUDIO_DIR, "preload")
RUNTIME_CACHE_DIR = os.path.join(CACHE_AUDIO_DIR, "runtime")
os.makedirs(PRELOAD_CACHE_DIR, exist_ok=True)
os.makedirs(RUNTIME_CACHE_DIR, exist_ok=True)

# --- Reference voice download & management ---
REFERENCE_AUDIO_DIR = os.path.join(TEMP_AUDIO_DIR, "reference_audios")
REFERENCE_AUDIO_DATASET = os.getenv("REFERENCE_AUDIO_DATASET", "kemuriririn/arena-files")
REFERENCE_AUDIO_PATTERN = os.getenv("REFERENCE_AUDIO_PATTERN", "reference_audios/")
CACHE_AUDIO_PATTERN = os.getenv("CACHE_AUDIO_PATTERN", "cache_audios/")
reference_audio_files = []
predefined_texts = []  # predefined text library
predefined_prompts = {}  # predefined prompt library; stores local paths


def download_reference_audios():
    """Download reference audios from the Hugging Face dataset to a local
    directory and build the file list.

    Populates the module-level ``reference_audio_files`` list; on any error
    the list is reset to empty.
    """
    global reference_audio_files
    os.makedirs(REFERENCE_AUDIO_DIR, exist_ok=True)
    try:
        api = HfApi(token=os.getenv("HF_TOKEN"))
        files = api.list_repo_files(repo_id=REFERENCE_AUDIO_DATASET, repo_type="dataset")
        # Only download wav files under the configured prefix.
        wav_files = [f for f in files if f.startswith(REFERENCE_AUDIO_PATTERN) and f.endswith(".wav")]
        for f in wav_files:
            local_path = hf_hub_download(
                repo_id=REFERENCE_AUDIO_DATASET,
                filename=f,
                repo_type="dataset",
                local_dir=REFERENCE_AUDIO_DIR,
                token=os.getenv("HF_TOKEN"),
            )
            # NOTE(review): appends without clearing first — calling this
            # twice would duplicate entries; confirm it runs only once.
            reference_audio_files.append(local_path)
        print(f"Downloaded {len(reference_audio_files)} reference audios.")
    except Exception as e:
        print(f"Error downloading reference audios: {e}")
        reference_audio_files = []


# Store active TTS sessions (in-memory, per-process)
app.tts_sessions = {}
tts_sessions = app.tts_sessions
# Store active conversational sessions
app.conversational_sessions = {}
conversational_sessions = app.conversational_sessions

# Register blueprints
app.register_blueprint(auth, url_prefix="/auth")
app.register_blueprint(admin)


@login_manager.user_loader
def load_user(user_id):
    """Flask-Login user loader: resolve a session user id to a User row."""
    return User.query.get(int(user_id))


@app.before_request
def before_request():
    """Per-request hook: expose user/admin flags on ``g``, force HTTPS on
    HF Spaces, and enforce Cloudflare Turnstile verification (with expiry)
    for all endpoints except the verification pages and static files."""
    g.user = current_user
    g.is_admin = is_admin(current_user)

    # Ensure HTTPS for HuggingFace Spaces environment
    if IS_SPACES and request.headers.get("X-Forwarded-Proto") == "http":
        url = request.url.replace("http://", "https://", 1)
        return redirect(url, code=301)

    # Check if Turnstile verification is required
    if app.config["TURNSTILE_ENABLED"]:
        # Exclude verification routes
        excluded_routes = ["verify_turnstile", "turnstile_page", "static"]
        if request.endpoint not in excluded_routes:
            # Check if user is verified
            if not session.get("turnstile_verified"):
                # Save original URL for redirect after verification
                redirect_url = request.url
                # Force HTTPS in HuggingFace Spaces
                if IS_SPACES and redirect_url.startswith("http://"):
                    redirect_url = redirect_url.replace("http://", "https://", 1)
                # If it's an API request, return a JSON response
                if request.path.startswith("/api/"):
                    return jsonify({"error": "Turnstile verification required"}), 403
                # For regular requests, redirect to verification page
                return redirect(url_for("turnstile_page", redirect_url=redirect_url))
            else:
                # Check if verification has expired (default: 24 hours)
                verification_timeout = (
                    int(os.getenv("TURNSTILE_TIMEOUT_HOURS", "24")) * 3600
                )  # Convert hours to seconds
                verified_at = session.get("turnstile_verified_at", 0)
                current_time = datetime.utcnow().timestamp()
                if current_time - verified_at > verification_timeout:
                    # Verification expired, clear status and redirect to verification page
                    session.pop("turnstile_verified", None)
                    session.pop("turnstile_verified_at", None)
                    redirect_url = request.url
                    # Force HTTPS in HuggingFace Spaces
                    if IS_SPACES and redirect_url.startswith("http://"):
                        redirect_url = redirect_url.replace("http://", "https://", 1)
                    if request.path.startswith("/api/"):
                        return jsonify({"error": "Turnstile verification expired"}), 403
                    return redirect(
                        url_for("turnstile_page", redirect_url=redirect_url)
                    )


@app.route("/turnstile", methods=["GET"])
def turnstile_page():
    """Display Cloudflare Turnstile verification page"""
    redirect_url = request.args.get("redirect_url", url_for("arena", _external=True))
    # Force HTTPS in HuggingFace Spaces
    if IS_SPACES and redirect_url.startswith("http://"):
        redirect_url = redirect_url.replace("http://", "https://", 1)
    return render_template(
        "turnstile.html",
        turnstile_site_key=app.config["TURNSTILE_SITE_KEY"],
        redirect_url=redirect_url,
    )


@app.route("/verify-turnstile", methods=["POST"])
def verify_turnstile():
    """Verify Cloudflare Turnstile token"""
    token = request.form.get("cf-turnstile-response")
    redirect_url = request.form.get("redirect_url", url_for("arena", _external=True))
    # Force HTTPS in HuggingFace Spaces
    if IS_SPACES and redirect_url.startswith("http://"):
        redirect_url = redirect_url.replace("http://", "https://", 1)
    if not token:
        # If AJAX request, return JSON error
        if request.headers.get("X-Requested-With") == "XMLHttpRequest":
            return (
                jsonify({"success": False, "error": "Missing verification token"}),
                400,
            )
        # Otherwise redirect back to turnstile page
        return redirect(url_for("turnstile_page", redirect_url=redirect_url))

    # Verify token with Cloudflare
    data = {
        "secret": app.config["TURNSTILE_SECRET_KEY"],
        "response": token,
        "remoteip": request.remote_addr,
    }
    try:
        response = requests.post(app.config["TURNSTILE_VERIFY_URL"], data=data)
        result = response.json()
        if result.get("success"):
            # Set verification status in session
            session["turnstile_verified"] = True
            session["turnstile_verified_at"] = datetime.utcnow().timestamp()
            # Determine response type based on request
            is_xhr = request.headers.get("X-Requested-With") == "XMLHttpRequest"
            accepts_json = "application/json" in request.headers.get("Accept", "")
            # If AJAX or JSON request, return success JSON
            if is_xhr or accepts_json:
                return jsonify({"success": True, "redirect": redirect_url})
            # For regular form submissions, redirect to the target URL
            return redirect(redirect_url)
        else:
            # Verification failed
            app.logger.warning(f"Turnstile verification failed: {result}")
            # If AJAX request, return JSON error
            if request.headers.get("X-Requested-With") == "XMLHttpRequest":
                return jsonify({"success": False, "error": "Verification failed"}), 403
            # Otherwise redirect back to turnstile page
            return redirect(url_for("turnstile_page", redirect_url=redirect_url))
    except Exception as e:
        app.logger.error(f"Turnstile verification error: {str(e)}")
        # If AJAX request, return JSON error
        if request.headers.get("X-Requested-With") == "XMLHttpRequest":
            return (
                jsonify(
                    {"success": False, "error": "Server error during verification"}
                ),
                500,
            )
        # Otherwise redirect back to turnstile page
        return redirect(url_for("turnstile_page", redirect_url=redirect_url))


def init_predefined_texts_and_prompts():
    """Initialize the predefined text and prompt libraries.

    Fills ``predefined_texts`` from init_sentences.txt (first 50 lines,
    de-duplicated) and ``predefined_prompts`` with the MD5 of up to 50
    randomly-sampled reference audio files. Both fall back to empty on error.
    """
    global predefined_texts, predefined_prompts
    # Load predefined texts
    predefined_texts = []
    try:
        # Curated sentences can be sourced from sentences.txt /
        # emotional_sentences.txt.
        with open("init_sentences.txt", "r") as f:
            predefined_texts.extend([line.strip() for line in f.readlines() if line.strip()][:50])
        # Ensure text uniqueness (NOTE: set() does not preserve order)
        predefined_texts = list(set(predefined_texts))
        print(f"加载了 {len(predefined_texts)} 条预置文本")
    except Exception as e:
        print(f"加载预置文本出错: {e}")
        predefined_texts = []

    # Build predefined prompt audios; values are the file MD5 digests
    predefined_prompts = {}
    try:
        # Sample from reference_audio_files as predefined prompts
        # NOTE(review): samples up to 50 files even though an earlier comment
        # said 10 — confirm intended count.
        if reference_audio_files:
            selected_prompts = random.sample(reference_audio_files, min(50, len(reference_audio_files)))
            for i, prompt_path in enumerate(selected_prompts):
                prompt_name = f"preset_prompt_{i + 1}"
                # Compute the file MD5 (used later for cache lookups)
                with open(prompt_path, 'rb') as f:
                    prompt_md5 = hashlib.md5(f.read()).hexdigest()
                predefined_prompts[prompt_name] = prompt_md5
        print(f"设置了 {len(predefined_prompts)} 个预置prompt音频")
    except Exception as e:
        print(f"设置预置prompt音频出错: {e}")
        predefined_prompts = {}


# Load the full sentence pool at import time.
# NOTE(review): raises (and aborts import) if init_sentences.txt is missing —
# confirm the file ships with the app.
with open("init_sentences.txt", "r") as f:
    all_harvard_sentences = [line.strip() for line in f.readlines() if line.strip()]

# Shuffle for initial random selection if needed, but main list remains ordered
initial_sentences = random.sample(all_harvard_sentences, min(len(all_harvard_sentences), 500))  # Limit initial pass for template
@app.route("/")
def arena():
    """Main arena page; ships a JSON subset of sentences for the random button fallback."""
    # Pass a subset of sentences for the random button fallback
    return render_template("arena.html", harvard_sentences=json.dumps(initial_sentences))


@app.route("/leaderboard")
def leaderboard():
    """Leaderboard page: global + personal rankings, top voters, and
    historical date dropdowns for both TTS and conversational models."""
    tts_leaderboard = get_leaderboard_data(ModelType.TTS)
    conversational_leaderboard = get_leaderboard_data(ModelType.CONVERSATIONAL)
    top_voters = get_top_voters(10)  # Get top 10 voters

    # Initialize personal leaderboard data
    tts_personal_leaderboard = None
    conversational_personal_leaderboard = None
    user_leaderboard_visibility = None

    # If user is logged in, get their personal leaderboard and visibility setting
    if current_user.is_authenticated:
        tts_personal_leaderboard = get_user_leaderboard(current_user.id, ModelType.TTS)
        conversational_personal_leaderboard = get_user_leaderboard(
            current_user.id, ModelType.CONVERSATIONAL
        )
        user_leaderboard_visibility = current_user.show_in_leaderboard

    # Get key dates for the timeline
    tts_key_dates = get_key_historical_dates(ModelType.TTS)
    conversational_key_dates = get_key_historical_dates(ModelType.CONVERSATIONAL)

    # Format dates for display in the dropdown
    formatted_tts_dates = [date.strftime("%B %Y") for date in tts_key_dates]
    formatted_conversational_dates = [
        date.strftime("%B %Y") for date in conversational_key_dates
    ]

    return render_template(
        "leaderboard.html",
        tts_leaderboard=tts_leaderboard,
        conversational_leaderboard=conversational_leaderboard,
        tts_personal_leaderboard=tts_personal_leaderboard,
        conversational_personal_leaderboard=conversational_personal_leaderboard,
        tts_key_dates=tts_key_dates,
        conversational_key_dates=conversational_key_dates,
        formatted_tts_dates=formatted_tts_dates,
        formatted_conversational_dates=formatted_conversational_dates,
        top_voters=top_voters,
        user_leaderboard_visibility=user_leaderboard_visibility
    )


# NOTE(review): this route string appears to have lost its URL converter in
# transit — the view takes ``model_type`` so the rule was presumably
# "/api/historical-leaderboard/<model_type>". Confirm against VCS history.
@app.route("/api/historical-leaderboard/")
def historical_leaderboard(model_type):
    """Get historical leaderboard data for a specific date"""
    # NOTE(review): model_type from a URL is a str; comparing it against
    # ModelType enum members may never match — confirm ModelType semantics.
    if model_type not in [ModelType.TTS, ModelType.CONVERSATIONAL]:
        return jsonify({"error": "Invalid model type"}), 400

    # Get date from query parameter
    date_str = request.args.get("date")
    if not date_str:
        return jsonify({"error": "Date parameter is required"}), 400

    try:
        # Parse date from URL parameter (format: YYYY-MM-DD)
        target_date = datetime.strptime(date_str, "%Y-%m-%d")
        # Get historical leaderboard data
        leaderboard_data = get_historical_leaderboard_data(model_type, target_date)
        return jsonify(
            {"date": target_date.strftime("%B %d, %Y"), "leaderboard": leaderboard_data}
        )
    except ValueError:
        return jsonify({"error": "Invalid date format. Use YYYY-MM-DD"}), 400


@app.route("/about")
def about():
    """Static about page."""
    return render_template("about.html")


# --- TTS Caching Functions ---
def get_cached_audio_path(model_name, text, prompt_audio_path, cache_type="runtime"):
    """Build the local cache path for a (model, text, prompt) combination.

    ``cache_type`` selects the preload vs runtime cache tree. The key comes
    from ``get_tts_cache_key`` (defined elsewhere in this file).
    NOTE(review): ``model_name`` must be a str — os.path.join raises on int.
    """
    key = get_tts_cache_key(model_name, text, prompt_audio_path)
    if cache_type == "preload":
        return os.path.join(PRELOAD_CACHE_DIR, model_name, f"{key}.wav")
    else:
        return os.path.join(RUNTIME_CACHE_DIR, model_name, f"{key}.wav")


def find_cached_audio(model_name, text, prompt_audio_path):
    """Look up a cached audio file; the preload cache takes priority over the
    runtime cache. Returns the path, or None when neither exists."""
    preload_path = get_cached_audio_path(model_name, text, prompt_audio_path, cache_type="preload")
    if os.path.exists(preload_path):
        return preload_path
    runtime_path = get_cached_audio_path(model_name, text, prompt_audio_path, cache_type="runtime")
    if os.path.exists(runtime_path):
        return runtime_path
    return None


def generate_and_save_tts(text, model_id, output_dir, prompt_audio_path=None):
    """Generates TTS and saves it to a specific directory, returning the full path.

    Returns ``(dest_path, reference_audio_path)`` on success, ``(None, None)``
    on failure (the temp file from predict_tts is cleaned up on error).
    NOTE(review): ``output_dir`` is unused — the destination always comes from
    get_cached_audio_path(..., cache_type="runtime"); and no makedirs is done
    for the per-model subdirectory before shutil.move — confirm it is created
    elsewhere.
    """
    temp_audio_path = None  # Initialize to None
    try:
        app.logger.debug(f"[TTS Gen {model_id}] Starting generation for: '{text[:30]}...'")
        # The reference audio can be supplied via prompt_audio_path
        reference_audio_path = prompt_audio_path
        temp_audio_path = predict_tts(text, model_id, reference_audio_path=reference_audio_path)
        app.logger.debug(f"[TTS Gen {model_id}] predict_tts returned: {temp_audio_path}")
        if not temp_audio_path or not os.path.exists(temp_audio_path):
            app.logger.warning(f"[TTS Gen {model_id}] predict_tts failed or returned invalid path: {temp_audio_path}")
            raise ValueError("predict_tts did not return a valid path or file does not exist")
        # Name the file by cache key
        model_name = str(model_id)
        dest_path = get_cached_audio_path(model_name, text, reference_audio_path, cache_type="runtime")
        shutil.move(temp_audio_path, dest_path)
        app.logger.debug(f"[TTS Gen {model_id}] Move successful. Returning {dest_path}")
        return dest_path, reference_audio_path
    except Exception as e:
        app.logger.error(f"Error generating/saving TTS for model {model_id} and text '{text[:30]}...': {str(e)}")
        # Ensure temporary file from predict_tts (if any) is cleaned up
        if temp_audio_path and os.path.exists(temp_audio_path):
            try:
                app.logger.debug(f"[TTS Gen {model_id}] Cleaning up temporary file {temp_audio_path} after error.")
                os.remove(temp_audio_path)
            except OSError:
                pass  # Ignore error if file couldn't be removed
        return None, None


def _generate_cache_entry_task(sentence):
    """Task function to generate audio for a sentence and add to cache."""
    # Wrap the entire task in an application context
    with app.app_context():
        if not sentence:
            # Select a new sentence if not provided (for replacement)
            with tts_cache_lock:
                cached_keys = set(tts_cache.keys())
            available_sentences = [s for s in all_harvard_sentences if s not in cached_keys]
            if not available_sentences:
                app.logger.warning("No more unique Harvard sentences available for caching.")
                return
            sentence = random.choice(available_sentences)

        # app.logger.info removed duplicate log
        print(f"[Cache Task] Querying models for: '{sentence[:50]}...'")
        available_models = Model.query.filter_by(
            model_type=ModelType.TTS, is_active=True
        ).all()
        if len(available_models) < 2:
            app.logger.error("Not enough active TTS models to generate cache entry.")
            return

        try:
            models = get_weighted_random_models(available_models, 2, ModelType.TTS)
            model_a_id = models[0].id
            model_b_id = models[1].id

            # Generate audio concurrently using a local executor for clarity within the task
            with ThreadPoolExecutor(max_workers=2, thread_name_prefix='AudioGen') as audio_executor:
                future_a = audio_executor.submit(generate_and_save_tts, sentence, model_a_id, CACHE_AUDIO_DIR)
                future_b = audio_executor.submit(generate_and_save_tts, sentence, model_b_id, CACHE_AUDIO_DIR)
                timeout_seconds = 120
                audio_a_path, ref_a = future_a.result(timeout=timeout_seconds)
                audio_b_path, ref_b = future_b.result(timeout=timeout_seconds)

            if audio_a_path and audio_b_path:
                with tts_cache_lock:
                    # Only add if the sentence isn't already back in the cache
                    # And ensure cache size doesn't exceed limit
                    if sentence not in tts_cache and len(tts_cache) < TTS_CACHE_SIZE:
                        tts_cache[sentence] = {
                            "model_a": model_a_id,
                            "model_b": model_b_id,
                            "audio_a": audio_a_path,
                            "audio_b": audio_b_path,
                            "ref_a": ref_a,
                            "ref_b": ref_b,
                            "created_at": datetime.utcnow(),
                        }
                        app.logger.info(f"Successfully cached entry for: '{sentence[:50]}...'")
                    elif sentence in tts_cache:
                        app.logger.warning(
                            f"Sentence '{sentence[:50]}...' already re-cached. Discarding new generation.")
                        # Clean up the newly generated files if not added
                        if os.path.exists(audio_a_path):
                            os.remove(audio_a_path)
                        if os.path.exists(audio_b_path):
                            os.remove(audio_b_path)
                    else:  # Cache is full
                        app.logger.warning(
                            f"Cache is full ({len(tts_cache)} entries). Discarding new generation for '{sentence[:50]}...'.")
                        # Clean up the newly generated files if not added
                        if os.path.exists(audio_a_path):
                            os.remove(audio_a_path)
                        if os.path.exists(audio_b_path):
                            os.remove(audio_b_path)
            else:
                app.logger.error(f"Failed to generate one or both audio files for cache: '{sentence[:50]}...'")
                # Clean up whichever file might have been created
                if audio_a_path and os.path.exists(audio_a_path):
                    os.remove(audio_a_path)
                if audio_b_path and os.path.exists(audio_b_path):
                    os.remove(audio_b_path)
        except Exception as e:
            # Log the exception within the app context
            app.logger.error(f"Exception in _generate_cache_entry_task for '{sentence[:50]}...': {str(e)}", exc_info=True)

# --- End TTS Caching Functions ---


@app.route("/api/tts/generate", methods=["POST"])
@limiter.limit("10 per minute")  # Keep limit, cached responses are still requests
def generate_tts():
    """Generate (or fetch from cache) two competing TTS renditions of a text.

    Returns a session id plus two audio URLs; the session expires in 30 min.
    """
    # If verification not setup, handle it first
    # NOTE(review): bracket access raises KeyError (-> 500) when the
    # 'x-ip-token' header is absent, and user_token is never used afterwards
    # in this chunk — confirm whether request.headers.get() was intended.
    user_token = request.headers['x-ip-token']
    if app.config["TURNSTILE_ENABLED"] and not session.get("turnstile_verified"):
        return jsonify({"error": "Turnstile verification required"}), 403

    # New: support multipart/form-data so an audio file can be attached
    if request.content_type and request.content_type.startswith('multipart/form-data'):
        text = request.form.get("text", "").strip()
        voice_file = request.files.get("voice_file")
        reference_audio_path = None
        if voice_file:
            temp_voice_path = os.path.join(TEMP_AUDIO_DIR, f"ref_{uuid.uuid4()}.wav")
            voice_file.save(temp_voice_path)
            reference_audio_path = temp_voice_path
    else:
        # NOTE(review): request.json is None for non-JSON bodies -> AttributeError.
        data = request.json
        text = data.get("text", "").strip()  # Ensure text is stripped
        reference_audio_path = None

    if not text or len(text) > 1000:
        return jsonify({"error": "Invalid or too long text"}), 400

    # MD5 of the uploaded reference audio, used to match predefined prompts.
    prompt_md5 = ''
    if reference_audio_path and os.path.exists(reference_audio_path):
        with open(reference_audio_path, 'rb') as f:
            prompt_md5 = hashlib.md5(f.read()).hexdigest()

    # --- Cache Check ---
    cache_hit = False
    session_data_from_cache = None
    with tts_cache_lock:
        if text in tts_cache:
            cache_hit = True
            cached_entry = tts_cache.pop(text)  # Remove from cache immediately
            app.logger.info(f"TTS Cache HIT for: '{text[:50]}...'")
            # Prepare session data using cached info
            session_id = str(uuid.uuid4())
            session_data_from_cache = {
                "model_a": cached_entry["model_a"],
                "model_b": cached_entry["model_b"],
                "audio_a": cached_entry["audio_a"],  # Paths are now from cache_dir
                "audio_b": cached_entry["audio_b"],
                "text": text,
                "created_at": datetime.utcnow(),
                "expires_at": datetime.utcnow() + timedelta(minutes=30),
                "voted": False,
            }
            app.tts_sessions[session_id] = session_data_from_cache

    if cache_hit and session_data_from_cache:
        # Return response using cached data
        # Note: The files are now managed by the session lifecycle (cleanup_session)
        return jsonify(
            {
                "session_id": session_id,
                "audio_a": f"/api/tts/audio/{session_id}/a",
                "audio_b": f"/api/tts/audio/{session_id}/b",
                "expires_in": 1800,  # 30 minutes in seconds
                "cache_hit": True,
            }
        )
    # --- End Cache Check ---

    # --- Cache Miss: Local File Cache ---
    # For predefined texts paired with predefined prompts, check the on-disk cache
    if text in predefined_texts and prompt_md5 in predefined_prompts.values():
        app.logger.warning(f"TTS Cache MISS for: '{text[:50]}...'. Finding in local cache.")
        available_models = Model.query.filter_by(
            model_type=ModelType.TTS, is_active=True
        ).all()
        if len(available_models) < 2:
            return jsonify({"error": "Not enough TTS models available"}), 500

        # New: both model A and model B must pass the cache check
        candidate_models = available_models.copy()
        valid_models = []
        invalid_models = []
        for model in candidate_models:
            # NOTE(review): model.id is passed where get_cached_audio_path
            # joins it into a path — os.path.join raises TypeError on int;
            # generate_and_save_tts uses str(model_id). Confirm id type.
            audio_path = find_cached_audio(model.id, text, prompt_audio_path=reference_audio_path)
            app.logger.warning(f"Checking cached audio for model {model.id}: {audio_path}")
            if audio_path and os.path.exists(audio_path):
                valid_models.append(model)
            else:
                invalid_models.append(model)
        if len(valid_models) < 2:
            return jsonify({"error": "Not enough valid TTS model results available"}), 500
        apply_filter_penalty_and_redistribute(invalid_models, valid_models, penalty_amount=1.0)

        # Randomly pick two of the models that have cached results
        model_a, model_b = random.sample(valid_models, 2)
        audio_a_path = find_cached_audio(model_a.id, text, prompt_audio_path=reference_audio_path)
        audio_b_path = find_cached_audio(model_b.id, text, prompt_audio_path=reference_audio_path)

        session_id = str(uuid.uuid4())
        app.tts_sessions[session_id] = {
            "model_a": model_a.id,
            "model_b": model_b.id,
            "audio_a": audio_a_path,
            "audio_b": audio_b_path,
            "text": text,
            "created_at": datetime.utcnow(),
            "expires_at": datetime.utcnow() + timedelta(minutes=30),
            "voted": False,
        }
        # Clean up the temporary reference audio file
        if reference_audio_path and os.path.exists(reference_audio_path):
            os.remove(reference_audio_path)
        return jsonify({
            "session_id": session_id,
            "audio_a": f"/api/tts/audio/{session_id}/a",
            "audio_b": f"/api/tts/audio/{session_id}/b",
            "expires_in": 1800,
            "cache_hit": True,
        })
    # --- End Cache Miss ---
    else:
        app.logger.warning(f"TTS Cache MISS for: '{text[:50]}...'. Generating on the fly.")
        available_models = Model.query.filter_by(
            model_type=ModelType.TTS, is_active=True
        ).all()
        if len(available_models) < 2:
            return jsonify({"error": "Not enough TTS models available"}), 500

        # Get two random models with weighted selection
        models = get_weighted_random_models(available_models, 2, ModelType.TTS)

        # Generate audio concurrently using a local executor for clarity within the request
        with ThreadPoolExecutor(max_workers=2, thread_name_prefix='AudioGen') as audio_executor:
            future_a = audio_executor.submit(generate_and_save_tts, text, models[0].id, RUNTIME_CACHE_DIR, prompt_audio_path=reference_audio_path)
            future_b = audio_executor.submit(generate_and_save_tts, text, models[1].id, RUNTIME_CACHE_DIR, prompt_audio_path=reference_audio_path)
            timeout_seconds = 120
            audio_a_path, ref_a = future_a.result(timeout=timeout_seconds)
            audio_b_path, ref_b = future_b.result(timeout=timeout_seconds)

        if not audio_a_path or not audio_b_path:
            return jsonify({"error": "Failed to generate TTS audio"}), 500

        session_id = str(uuid.uuid4())
        app.tts_sessions[session_id] = {
            "model_a": models[0].id,
            "model_b": models[1].id,
            "audio_a": audio_a_path,
            "audio_b": audio_b_path,
            "text": text,
            "created_at": datetime.utcnow(),
            "expires_at": datetime.utcnow() + timedelta(minutes=30),
            "voted": False,
        }
        # Clean up temporary reference audio file if it was provided
        if reference_audio_path and os.path.exists(reference_audio_path):
            os.remove(reference_audio_path)

        # Return response with session ID and audio URLs
        return jsonify(
            {
                "session_id": session_id,
                "audio_a": f"/api/tts/audio/{session_id}/a",
                "audio_b": f"/api/tts/audio/{session_id}/b",
                "expires_in": 1800,  # 30 minutes in seconds
                "cache_hit": False,
            }
        )


# NOTE(review): this route string also appears to have lost its converters —
# the view takes (session_id, model_key), so the rule was presumably
# "/api/tts/audio/<session_id>/<model_key>". Confirm against VCS history.
@app.route("/api/tts/audio//")
def get_audio(session_id, model_key):
    """Stream one of the two generated wav files for an active session."""
    # If verification not setup, handle it first
    if app.config["TURNSTILE_ENABLED"] and not session.get("turnstile_verified"):
        return jsonify({"error": "Turnstile verification required"}), 403

    if session_id not in app.tts_sessions:
        return jsonify({"error": "Invalid or expired session"}), 404

    session_data = app.tts_sessions[session_id]
    # Check if session expired
    if datetime.utcnow() > session_data["expires_at"]:
        cleanup_session(session_id)
        return jsonify({"error": "Session expired"}), 410

    if model_key == "a":
        audio_path = session_data["audio_a"]
    elif model_key == "b":
        audio_path = session_data["audio_b"]
    else:
        return jsonify({"error": "Invalid model key"}), 400

    # Check if file exists
    if not os.path.exists(audio_path):
        return jsonify({"error": "Audio file not found"}), 404

    return send_file(audio_path, mimetype="audio/wav")


@app.route("/api/tts/vote", methods=["POST"])
@limiter.limit("30 per minute")
def submit_vote():
    """Record a vote for one of the two models in a TTS session and archive
    the audio pair + metadata under ./votes for later upload."""
    # If verification not setup, handle it first
    if app.config["TURNSTILE_ENABLED"] and not session.get("turnstile_verified"):
        return jsonify({"error": "Turnstile verification required"}), 403

    data = request.json
    session_id = data.get("session_id")
    chosen_model_key = data.get("chosen_model")  # "a" or "b"

    if not session_id or session_id not in app.tts_sessions:
        return jsonify({"error": "Invalid or expired session"}), 404
    if not chosen_model_key or chosen_model_key not in ["a", "b"]:
        return jsonify({"error": "Invalid chosen model"}), 400

    session_data = app.tts_sessions[session_id]
    # Check if session expired
    if datetime.utcnow() > session_data["expires_at"]:
        cleanup_session(session_id)
        return jsonify({"error": "Session expired"}), 410
    # Check if already voted
    if session_data["voted"]:
        return jsonify({"error": "Vote already submitted for this session"}), 400

    # Get model IDs and audio paths
    chosen_id = (
        session_data["model_a"] if chosen_model_key == "a" else session_data["model_b"]
    )
    rejected_id = (
        session_data["model_b"] if chosen_model_key == "a" else session_data["model_a"]
    )
    chosen_audio_path = (
        session_data["audio_a"] if chosen_model_key == "a" else session_data["audio_b"]
    )
    rejected_audio_path = (
        session_data["audio_b"] if chosen_model_key == "a" else session_data["audio_a"]
    )

    # Record vote in database
    user_id = current_user.id if current_user.is_authenticated else None
    vote, error = record_vote(
        user_id, session_data["text"], chosen_id, rejected_id, ModelType.TTS
    )
    if error:
        return jsonify({"error": error}), 500

    # --- Save preference data ---
    # NOTE(review): chosen_model_obj / rejected_model_obj are bound inside this
    # try; if an exception fires before they are assigned, the jsonify below
    # raises NameError — confirm and consider hoisting the lookups.
    try:
        vote_uuid = str(uuid.uuid4())
        vote_dir = os.path.join("./votes", vote_uuid)
        os.makedirs(vote_dir, exist_ok=True)
        # Copy audio files
        shutil.copy(chosen_audio_path, os.path.join(vote_dir, "chosen.wav"))
        shutil.copy(rejected_audio_path, os.path.join(vote_dir, "rejected.wav"))
        # Create metadata
        chosen_model_obj = Model.query.get(chosen_id)
        rejected_model_obj = Model.query.get(rejected_id)
        metadata = {
            "text": session_data["text"],
            "chosen_model": chosen_model_obj.name if chosen_model_obj else "Unknown",
            "chosen_model_id": chosen_model_obj.id if chosen_model_obj else "Unknown",
            "rejected_model": rejected_model_obj.name if rejected_model_obj else "Unknown",
            "rejected_model_id": rejected_model_obj.id if rejected_model_obj else "Unknown",
            "session_id": session_id,
            "timestamp": datetime.utcnow().isoformat(),
            "username": current_user.username if current_user.is_authenticated else None,
            "model_type": "TTS"
        }
        with open(os.path.join(vote_dir, "metadata.json"), "w") as f:
            json.dump(metadata, f, indent=2)
    except Exception as e:
        app.logger.error(f"Error saving preference data for vote {session_id}: {str(e)}")
        # Continue even if saving preference data fails, vote is already recorded

    # Mark session as voted
    session_data["voted"] = True

    # Return updated models (use previously fetched objects)
    return jsonify(
        {
            "success": True,
            "chosen_model": {"id": chosen_id, "name": chosen_model_obj.name if chosen_model_obj else "Unknown"},
            "rejected_model": {
                "id": rejected_id,
                "name": rejected_model_obj.name if rejected_model_obj else "Unknown",
            },
            "names": {
                "a": (
                    chosen_model_obj.name
                    if chosen_model_key == "a"
                    else rejected_model_obj.name
                    if chosen_model_obj and rejected_model_obj
                    else "Unknown"
                ),
                "b": (
                    rejected_model_obj.name
                    if chosen_model_key == "a"
                    else chosen_model_obj.name
                    if chosen_model_obj and rejected_model_obj
                    else "Unknown"
                ),
            },
        }
    )


def cleanup_session(session_id):
    """Remove session and its audio files"""
    if session_id in app.tts_sessions:
        # NOTE(review): local name shadows the flask `session` proxy — safe
        # here, but easy to trip over in later edits.
        session = app.tts_sessions[session_id]
        # Remove audio files
        for audio_file in [session["audio_a"], session["audio_b"]]:
            if os.path.exists(audio_file):
                try:
                    os.remove(audio_file)
                except Exception as e:
                    app.logger.error(f"Error removing audio file: {str(e)}")
        # Remove session
        del app.tts_sessions[session_id]


# Schedule periodic cleanup
def setup_cleanup():
    """Start a background scheduler that removes expired TTS sessions
    (and their audio files) every 15 minutes."""
    def cleanup_expired_sessions():
        with app.app_context():  # Ensure app context for logging
            current_time = datetime.utcnow()
            # Cleanup TTS sessions
            expired_tts_sessions = [
                sid
                for sid, session_data in app.tts_sessions.items()
                if current_time > session_data["expires_at"]
            ]
            for sid in expired_tts_sessions:
                cleanup_session(sid)
            app.logger.info(
                f"Cleaned up {len(expired_tts_sessions)} TTS.")
            # Also cleanup potentially expired cache entries (e.g., > 1 hour old)
            # This prevents stale cache entries if generation is slow or failing
            # cleanup_stale_cache_entries()

    # Run cleanup every 15 minutes
    scheduler = BackgroundScheduler(daemon=True)  # Run scheduler as daemon thread
    scheduler.add_job(cleanup_expired_sessions, "interval", minutes=15)
    scheduler.start()
    print("Cleanup scheduler started")  # Use print for startup messages


# Schedule periodic tasks (database sync and preference upload)
def setup_periodic_tasks():
    """Setup periodic database synchronization and preference data upload for Spaces"""
    if not IS_SPACES:
        return

    db_path = app.config["SQLALCHEMY_DATABASE_URI"].replace("sqlite:///", "instance/")  # Get relative path
    preferences_repo_id = "kemuriririn/arena-preferences"
    database_repo_id = "kemuriririn/database-arena"
    votes_dir = "./votes"

    def sync_database():
        """Uploads the database to HF dataset"""
        with app.app_context():  # Ensure app context for logging
            try:
                if not os.path.exists(db_path):
                    app.logger.warning(f"Database file not found at {db_path}, skipping sync.")
                    return
                api = HfApi(token=os.getenv("HF_TOKEN"))
                api.upload_file(
                    path_or_fileobj=db_path,
                    path_in_repo="tts_arena.db",
                    repo_id=database_repo_id,
                    repo_type="dataset",
                )
                app.logger.info(f"Database uploaded to {database_repo_id} at {datetime.utcnow()}")
            except Exception as e:
                app.logger.error(f"Error uploading database to {database_repo_id}: {str(e)}")

    def sync_preferences_data():
        """Zips and uploads preference data folders in batches to HF dataset"""
        with app.app_context():  # Ensure app context for logging
            if not os.path.isdir(votes_dir):
                return  # Don't log every 5 mins if dir doesn't exist yet

            temp_batch_dir = None  # Initialize to manage cleanup
            temp_individual_zip_dir = None  # Initialize for individual zips
            local_batch_zip_path = None  # Initialize for batch zip path
            try:
                api = HfApi(token=os.getenv("HF_TOKEN"))
                vote_uuids = [d for d in os.listdir(votes_dir) if os.path.isdir(os.path.join(votes_dir, d))]
                if not vote_uuids:
                    return  # No data to process

                app.logger.info(f"Found {len(vote_uuids)} vote directories to process.")
                # Create temporary directories
                temp_batch_dir = tempfile.mkdtemp(prefix="hf_batch_")
                temp_individual_zip_dir = tempfile.mkdtemp(prefix="hf_indiv_zips_")
                app.logger.debug(f"Created temp directories: {temp_batch_dir}, {temp_individual_zip_dir}")

                processed_vote_dirs = []
                individual_zips_in_batch = []

                # 1. Create individual zips and move them to the batch directory
                for vote_uuid in vote_uuids:
                    dir_path = os.path.join(votes_dir, vote_uuid)
                    individual_zip_base_path = os.path.join(temp_individual_zip_dir, vote_uuid)
                    individual_zip_path = f"{individual_zip_base_path}.zip"
                    try:
                        shutil.make_archive(individual_zip_base_path, 'zip', dir_path)
                        app.logger.debug(f"Created individual zip: {individual_zip_path}")
                        # Move the created zip into the batch directory
                        final_individual_zip_path = os.path.join(temp_batch_dir, f"{vote_uuid}.zip")
                        shutil.move(individual_zip_path, final_individual_zip_path)
                        app.logger.debug(f"Moved individual zip to batch dir: {final_individual_zip_path}")
                        processed_vote_dirs.append(dir_path)  # Mark original dir for later cleanup
                        individual_zips_in_batch.append(final_individual_zip_path)
                    except Exception as zip_err:
                        app.logger.error(f"Error creating or moving zip for {vote_uuid}: {str(zip_err)}")
                        # Clean up partial zip if it exists
                        if os.path.exists(individual_zip_path):
                            try:
                                os.remove(individual_zip_path)
                            except OSError:
                                pass
                        # Continue processing other votes

                # Clean up the temporary dir used for creating individual zips
                shutil.rmtree(temp_individual_zip_dir)
                temp_individual_zip_dir = None  # Mark as cleaned
                app.logger.debug("Cleaned up temporary individual zip directory.")

                if not individual_zips_in_batch:
                    app.logger.warning("No individual zips were successfully created for batching.")
                    # Clean up batch dir if it's empty or only contains failed attempts
                    if temp_batch_dir and os.path.exists(temp_batch_dir):
                        shutil.rmtree(temp_batch_dir)
                        temp_batch_dir = None
                    return

                # 2.
Create the batch zip file batch_timestamp = datetime.utcnow().strftime("%Y%m%d_%H%M%S") batch_uuid_short = str(uuid.uuid4())[:8] batch_zip_filename = f"{batch_timestamp}_batch_{batch_uuid_short}.zip" # Create batch zip in a standard temp location first local_batch_zip_base = os.path.join(tempfile.gettempdir(), batch_zip_filename.replace('.zip', '')) local_batch_zip_path = f"{local_batch_zip_base}.zip" app.logger.info( f"Creating batch zip: {local_batch_zip_path} with {len(individual_zips_in_batch)} individual zips.") shutil.make_archive(local_batch_zip_base, 'zip', temp_batch_dir) app.logger.info(f"Batch zip created successfully: {local_batch_zip_path}") # 3. Upload the batch zip file hf_repo_path = f"votes/{year}/{month}/{batch_zip_filename}" app.logger.info(f"Uploading batch zip to HF Hub: {preferences_repo_id}/{hf_repo_path}") api.upload_file( path_or_fileobj=local_batch_zip_path, path_in_repo=hf_repo_path, repo_id=preferences_repo_id, repo_type="dataset", commit_message=f"Add batch preference data {batch_zip_filename} ({len(individual_zips_in_batch)} votes)" ) app.logger.info(f"Successfully uploaded batch {batch_zip_filename} to {preferences_repo_id}") # 4. Cleanup after successful upload app.logger.info("Cleaning up local files after successful upload.") # Remove original vote directories that were successfully zipped and uploaded for dir_path in processed_vote_dirs: try: shutil.rmtree(dir_path) app.logger.debug(f"Removed original vote directory: {dir_path}") except OSError as e: app.logger.error(f"Error removing processed vote directory {dir_path}: {str(e)}") # Remove the temporary batch directory (containing the individual zips) shutil.rmtree(temp_batch_dir) temp_batch_dir = None app.logger.debug("Removed temporary batch directory.") # Remove the local batch zip file os.remove(local_batch_zip_path) local_batch_zip_path = None app.logger.debug("Removed local batch zip file.") app.logger.info(f"Finished preference data sync. 
Uploaded batch {batch_zip_filename}.") except Exception as e: app.logger.error(f"Error during preference data batch sync: {str(e)}", exc_info=True) # If upload failed, the local batch zip might exist, clean it up. if local_batch_zip_path and os.path.exists(local_batch_zip_path): try: os.remove(local_batch_zip_path) app.logger.debug("Cleaned up local batch zip after failed upload.") except OSError as clean_err: app.logger.error(f"Error cleaning up batch zip after failed upload: {clean_err}") # Do NOT remove temp_batch_dir if it exists; its contents will be retried next time. # Do NOT remove original vote directories if upload failed. finally: # Final cleanup for temporary directories in case of unexpected exits if temp_individual_zip_dir and os.path.exists(temp_individual_zip_dir): try: shutil.rmtree(temp_individual_zip_dir) except Exception as final_clean_err: app.logger.error(f"Error in final cleanup (indiv zips): {final_clean_err}") # Only clean up batch dir in finally block if it *wasn't* kept intentionally after upload failure if temp_batch_dir and os.path.exists(temp_batch_dir): # Check if an upload attempt happened and failed upload_failed = 'e' in locals() and isinstance(e, Exception) # Crude check if exception occurred if not upload_failed: # If no upload error or upload succeeded, clean up try: shutil.rmtree(temp_batch_dir) except Exception as final_clean_err: app.logger.error(f"Error in final cleanup (batch dir): {final_clean_err}") else: app.logger.warning("Keeping temporary batch directory due to upload failure for next attempt.") def sync_cache_audios_with_hf(): """ 将整个cache_audios目录打包上传到HF,并下载最新的cache_audios.zip解压覆盖本地(增量覆盖)。 """ os.makedirs(PRELOAD_CACHE_DIR, exist_ok=True) cache_zip_name = "cache_audios.zip" cache_zip_local_path = os.path.join(tempfile.gettempdir(), cache_zip_name) cache_zip_remote_path = os.path.join(CACHE_AUDIO_PATTERN, cache_zip_name) try: api = HfApi(token=os.getenv("HF_TOKEN")) # 1. 
下载远端cache_audios.zip并解压到临时目录 try: print("Downloading cache_audios.zip from HF...") remote_zip_path = hf_hub_download( repo_id=REFERENCE_AUDIO_DATASET, filename=cache_zip_remote_path, repo_type="dataset", local_dir=tempfile.gettempdir(), token=os.getenv("HF_TOKEN"), force_download=True ) # 解压到临时目录 temp_unzip_dir = tempfile.mkdtemp(prefix="cache_audios_unzip_") shutil.unpack_archive(remote_zip_path, temp_unzip_dir) print(f"Unzipped cache_audios.zip to {temp_unzip_dir}") # 增量覆盖:逐个文件对比md5,有diff才覆盖 for root, _, files in os.walk(temp_unzip_dir): for fname in files: if fname.startswith("._"): # 忽略macOS的资源分支文件 continue rel_path = os.path.relpath(os.path.join(root, fname), temp_unzip_dir) src_file = os.path.join(root, fname) dst_file = os.path.join(PRELOAD_CACHE_DIR, rel_path) need_copy = True if os.path.exists(dst_file): # 对比md5 def file_md5(path): hash_md5 = hashlib.md5() with open(path, "rb") as f: for chunk in iter(lambda: f.read(4096), b""): hash_md5.update(chunk) return hash_md5.hexdigest() if file_md5(src_file) == file_md5(dst_file): need_copy = False if need_copy: os.makedirs(os.path.dirname(dst_file), exist_ok=True) shutil.copy2(src_file, dst_file) # print(f"Updated cache audio: {dst_file}") shutil.rmtree(temp_unzip_dir) except Exception as e: print(f"Download/unzip/incremental update cache_audios.zip failed: {e}") # 2. 打包本地cache_audios目录为zip # if os.path.exists(PRELOAD_CACHE_DIR): # # 先删除旧的zip # if os.path.exists(cache_zip_local_path): # os.remove(cache_zip_local_path) # shutil.make_archive(cache_zip_local_path.replace('.zip', ''), 'zip', PRELOAD_CACHE_DIR) # print(f"Packed {PRELOAD_CACHE_DIR} to {cache_zip_local_path}") # # # 3. 
上传zip到HF # print(f"Uploading {cache_zip_local_path} to HF as {cache_zip_remote_path} ...") # api.upload_file( # path_or_fileobj=cache_zip_local_path, # path_in_repo=cache_zip_remote_path, # repo_id=REFERENCE_AUDIO_DATASET, # repo_type="dataset", # commit_message="Upload full cache_audios.zip" # ) # print("Upload cache_audios.zip done.") # # 上传后可删除本地zip # os.remove(cache_zip_local_path) except Exception as e: print(f"Error syncing cache_audios.zip with HF: {e}") app.logger.error(f"Error syncing cache_audios.zip with HF: {e}") # Schedule periodic tasks scheduler = BackgroundScheduler() # Sync database less frequently if needed, e.g., every 15 minutes scheduler.add_job(sync_database, "interval", minutes=15, id="sync_db_job") # Sync preferences more frequently scheduler.add_job(sync_preferences_data, "interval", minutes=5, id="sync_pref_job") scheduler.add_job(sync_cache_audios_with_hf, "interval", minutes=30, id="sync_cache_audio_job",next_run_time=datetime.now() + timedelta(seconds=30)) # Start after 30 seconds scheduler.start() print("Periodic tasks scheduler started (DB sync and Preferences upload)") # Use print for startup @app.cli.command("init-db") def init_db(): """Initialize the database.""" with app.app_context(): db.create_all() print("Database initialized!") @app.route("/api/toggle-leaderboard-visibility", methods=["POST"]) def toggle_leaderboard_visibility(): """Toggle whether the current user appears in the top voters leaderboard""" if not current_user.is_authenticated: return jsonify({"error": "You must be logged in to change this setting"}), 401 new_status = toggle_user_leaderboard_visibility(current_user.id) if new_status is None: return jsonify({"error": "User not found"}), 404 return jsonify({ "success": True, "visible": new_status, "message": "You are now visible in the voters leaderboard" if new_status else "You are now hidden from the voters leaderboard" }) @app.route("/api/tts/cached-sentences") def get_cached_sentences(): """Returns a list of 
sentences currently available in the TTS cache, with reference audio.""" with tts_cache_lock: cached = [ { "sentence": k, "model_a": v["model_a"], "model_b": v["model_b"], "ref_a": os.path.relpath(v["ref_a"], start=REFERENCE_AUDIO_DIR) if v.get("ref_a") else None, "ref_b": os.path.relpath(v["ref_b"], start=REFERENCE_AUDIO_DIR) if v.get("ref_b") else None, } for k, v in tts_cache.items() ] return jsonify(cached) @app.route("/api/tts/reference-audio/") def get_reference_audio(filename): """试听参考音频""" file_path = os.path.join(REFERENCE_AUDIO_DIR, filename) if not os.path.exists(file_path): return jsonify({"error": "Reference audio not found"}), 404 return send_file(file_path, mimetype="audio/wav") @app.route('/api/voice/random', methods=['GET']) def get_random_voice(): # 随机选择一个音频文件 random_voice = random.choice(reference_audio_files) voice_path = os.path.join(REFERENCE_AUDIO_DIR, random_voice) # 返回音频文件 return send_file(voice_path, mimetype='audio/' + voice_path.split('.')[-1]) def get_weighted_random_models( applicable_models: list[Model], num_to_select: int, model_type: ModelType ) -> list[Model]: """ Selects a specified number of models randomly from a list of applicable_models, weighting models with fewer votes higher. A smoothing factor is used to ensure the preference is slight and to prevent models with zero votes from being overwhelmingly favored. Models are selected without replacement. Assumes len(applicable_models) >= num_to_select, which should be checked by the caller. 
""" model_votes_counts = {} for model in applicable_models: votes = ( Vote.query.filter(Vote.model_type == model_type) .filter(or_(Vote.model_chosen == model.id, Vote.model_rejected == model.id)) .count() ) model_votes_counts[model.id] = votes weights = [ 1.0 / (model_votes_counts[model.id] + SMOOTHING_FACTOR_MODEL_SELECTION) for model in applicable_models ] selected_models_list = [] # Create copies to modify during selection process current_candidates = list(applicable_models) current_weights = list(weights) # Assumes num_to_select is positive and less than or equal to len(current_candidates) # Callers should ensure this (e.g., len(available_models) >= 2). for _ in range(num_to_select): if not current_candidates: # Safety break app.logger.warning("Not enough candidates left for weighted selection.") break chosen_model = random.choices(current_candidates, weights=current_weights, k=1)[0] selected_models_list.append(chosen_model) try: idx_to_remove = current_candidates.index(chosen_model) current_candidates.pop(idx_to_remove) current_weights.pop(idx_to_remove) except ValueError: # This should ideally not happen if chosen_model came from current_candidates. 
app.logger.error(f"Error removing model {chosen_model.id} from weighted selection candidates.") break # Avoid potential issues return selected_models_list def get_tts_cache_key(model_name, text, prompt_audio_path): """ 生成 TTS 缓存 key: md5(模型+文本+md5(prompt音频内容)) :param model_name: str :param text: str :param prompt_audio_path: str or None :return: str (md5 hash) """ prompt_md5 = '' if prompt_audio_path and os.path.exists(prompt_audio_path): with open(prompt_audio_path, 'rb') as f: prompt_content = f.read() prompt_md5 = hashlib.md5(prompt_content).hexdigest() key_str = f"{model_name}::{text}::{prompt_md5}" return hashlib.md5(key_str.encode('utf-8')).hexdigest() def has_long_silence(audio_path, min_silence_len_ms=10000, silence_thresh_db=-40): try: audio = AudioSegment.from_file(audio_path) silent_ranges = silence.detect_silence(audio, min_silence_len=min_silence_len_ms, silence_thresh=silence_thresh_db) return len(silent_ranges) > 0 except Exception as e: print(f"无法分析音频文件 {audio_path}: {e}") app.logger.error(f"无法分析音频文件 {audio_path}: {e}") return False if __name__ == "__main__": with app.app_context(): # Ensure ./instance and ./votes directories exist os.makedirs("instance", exist_ok=True) os.makedirs("./votes", exist_ok=True) # Create votes directory if it doesn't exist os.makedirs(CACHE_AUDIO_DIR, exist_ok=True) # Ensure cache audio dir exists # Clean up old cache audio files on startup try: app.logger.info(f"Clearing old cache audio files from {CACHE_AUDIO_DIR}") for filename in os.listdir(CACHE_AUDIO_DIR): file_path = os.path.join(CACHE_AUDIO_DIR, filename) try: if os.path.isfile(file_path) or os.path.islink(file_path): os.unlink(file_path) elif os.path.isdir(file_path): shutil.rmtree(file_path) except Exception as e: app.logger.error(f'Failed to delete {file_path}. 
Reason: {e}') except Exception as e: app.logger.error(f"Error clearing cache directory {CACHE_AUDIO_DIR}: {e}") # Download database if it doesn't exist (only on initial space start) if IS_SPACES and not os.path.exists(app.config["SQLALCHEMY_DATABASE_URI"].replace("sqlite:///", "")): try: print("Database not found, downloading from HF dataset...") hf_hub_download( repo_id="kemuriririn/database-arena", filename="tts_arena.db", repo_type="dataset", local_dir="instance", # download to instance/ token=os.getenv("HF_TOKEN"), ) print("Database downloaded successfully ✅") except Exception as e: print(f"Error downloading database from HF dataset: {str(e)} ⚠️") download_reference_audios() # 初始化预置文本和提示音频 init_predefined_texts_and_prompts() db.create_all() # Create tables if they don't exist insert_initial_models() # Setup background tasks # initialize_tts_cache() # Start populating the cache (关闭预生成缓存) setup_cleanup() setup_periodic_tasks() # Renamed function call # Configure Flask to recognize HTTPS when behind a reverse proxy from werkzeug.middleware.proxy_fix import ProxyFix # Apply ProxyFix middleware to handle reverse proxy headers # This ensures Flask generates correct URLs with https scheme # X-Forwarded-Proto header will be used to detect the original protocol app.wsgi_app = ProxyFix(app.wsgi_app, x_proto=1, x_host=1) # Force Flask to prefer HTTPS for generated URLs app.config["PREFERRED_URL_SCHEME"] = "https" from waitress import serve # Configuration for 2 vCPUs: # - threads: typically 4-8 threads per CPU core is a good balance # - connection_limit: maximum concurrent connections # - channel_timeout: prevent hanging connections threads = 12 # 6 threads per vCPU is a good balance for mixed IO/CPU workloads if IS_SPACES: serve( app, host="0.0.0.0", port=int(os.environ.get("PORT", 7860)), threads=threads, connection_limit=100, channel_timeout=30, url_scheme='https' ) else: print(f"Starting Waitress server with {threads} threads") serve( app, host="0.0.0.0", port=5000, 
threads=threads, connection_limit=100, channel_timeout=30, url_scheme='https' # Keep https for local dev if using proxy/tunnel )