"""Gradio front end for the WaveSpeedAI Wan-2.1-14B-VACE video generation API.

The app screens each prompt with a text-safety classifier, applies a
per-session hourly rate limit, submits the job to WaveSpeedAI, and polls the
prediction endpoint until the job completes, fails, or times out.
"""

import base64
import json
import logging
import os
import random
import tempfile
import threading
import time
import uuid
from pathlib import Path

import gradio as gr
import requests
import torch
from dotenv import load_dotenv
from PIL import Image, ImageDraw, ImageFont
from transformers import AutoModelForSequenceClassification, AutoTokenizer

load_dotenv()

logger = logging.getLogger(__name__)

MODEL_URL = "TostAI/nsfw-text-detection-large"
CLASS_NAMES = {0: "✅ SAFE", 1: "⚠️ QUESTIONABLE", 2: "🚫 UNSAFE"}

# Name of the environment variable (or .env entry) holding the API key.
API_KEY_ENV_VAR = "WAVESPEED_API_KEY"
SUBMIT_URL = "https://api.wavespeed.ai/api/v2/wavespeed-ai/wan-2.1-14b-vace"
RESULT_URL_TEMPLATE = "https://api.wavespeed.ai/api/v2/predictions/{request_id}/result"

SESSION_TTL_SECONDS = 3600
RATE_LIMIT_WINDOW_SECONDS = 3600
MAX_REQUESTS_PER_WINDOW = 10
CLEANUP_INTERVAL_SECONDS = 3600

SUBMIT_TIMEOUT_SECONDS = 120
POLL_REQUEST_TIMEOUT_SECONDS = 30
POLL_INTERVAL_SECONDS = 0.5
# Upper bound on how long one job may be polled; without it a job stuck in a
# non-terminal state would occupy a queue slot forever.
POLL_TIMEOUT_SECONDS = 1800

tokenizer = AutoTokenizer.from_pretrained(MODEL_URL)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_URL)


class SessionManager:
    """Process-wide registry of per-session usage data.

    All state lives on the class, so every instance shares the same registry.
    """

    _instances = {}
    _lock = threading.Lock()

    @classmethod
    def get_session(cls, session_id):
        """Return the record for ``session_id``, creating it on first use.

        The record is a dict with the keys ``count``, ``history`` and
        ``last_active``. The dict itself is returned, so callers mutate the
        stored record directly.
        """
        with cls._lock:
            if session_id not in cls._instances:
                cls._instances[session_id] = {
                    'count': 0,
                    'history': [],
                    'last_active': time.time(),
                }
            return cls._instances[session_id]

    @classmethod
    def cleanup_sessions(cls):
        """Drop every session idle for longer than ``SESSION_TTL_SECONDS``."""
        with cls._lock:
            now = time.time()
            expired = [
                key
                for key, record in cls._instances.items()
                if now - record['last_active'] > SESSION_TTL_SECONDS
            ]
            for key in expired:
                del cls._instances[key]


class RateLimiter:
    """Fixed-window request counter keyed by client id."""

    def __init__(
        self,
        max_requests=MAX_REQUESTS_PER_WINDOW,
        window_seconds=RATE_LIMIT_WINDOW_SECONDS,
    ):
        """Create a limiter allowing ``max_requests`` per ``window_seconds``."""
        self.max_requests = max_requests
        self.window_seconds = window_seconds
        self.clients = {}
        self.lock = threading.Lock()

    def check(self, client_id):
        """Record one request for ``client_id``.

        Returns:
            True if the request is within the limit (and has been counted),
            False if the client has used up the current window.
        """
        with self.lock:
            now = time.time()
            entry = self.clients.get(client_id)
            # Unknown client or expired window: start a fresh window.
            if entry is None or now > entry['reset']:
                self.clients[client_id] = {
                    'count': 1,
                    'reset': now + self.window_seconds,
                }
                return True
            if entry['count'] >= self.max_requests:
                return False
            entry['count'] += 1
            return True

    def cleanup(self):
        """Forget clients whose window has expired so the table cannot grow unbounded."""
        with self.lock:
            now = time.time()
            expired = [
                key for key, entry in self.clients.items() if now > entry['reset']
            ]
            for key in expired:
                del self.clients[key]


session_manager = SessionManager()
rate_limiter = RateLimiter()


def create_error_image(message):
    """Render ``message`` onto a pale-red 832x480 image and return its file path.

    Messages longer than 60 characters are cut to ``"Error: <first 60>..."``;
    shorter ones are drawn unchanged. Each call writes a new JPEG in the
    system temp directory so concurrent requests never overwrite each other.
    """
    message = str(message)
    img = Image.new("RGB", (832, 480), "#ffdddd")
    try:
        font = ImageFont.truetype("arial.ttf", 24)
    except OSError:
        # Arial is not installed (typical on Linux hosts).
        font = ImageFont.load_default()
    draw = ImageDraw.Draw(img)
    text = f"Error: {message[:60]}..." if len(message) > 60 else message
    try:
        draw.text((50, 200), text, fill="#ff0000", font=font)
    except UnicodeEncodeError:
        # A bitmap fallback font may be unable to encode emoji/CJK text;
        # draw a degraded ASCII version instead of failing the request.
        ascii_text = text.encode("ascii", "replace").decode("ascii")
        draw.text((50, 200), ascii_text, fill="#ff0000", font=font)
    with tempfile.NamedTemporaryFile(
        prefix="error_", suffix=".jpg", delete=False
    ) as tmp:
        img.save(tmp, format="JPEG")
    return tmp.name


def classify_prompt(prompt):
    """Return the index of the highest-scoring safety class for ``prompt``.

    The index is a key of ``CLASS_NAMES``; 0 is the class treated as safe by
    ``generate_video``. Input is truncated to 512 tokens.
    """
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
    with torch.no_grad():
        outputs = model(**inputs)
    return torch.argmax(outputs.logits).item()


def _file_to_base64(file_path):
    """Read ``file_path`` as bytes and return its standard Base64 text.

    Raises:
        ValueError: if the file cannot be read.
    """
    try:
        with open(file_path, "rb") as fh:
            # b64encode always emits correctly padded output.
            return base64.b64encode(fh.read()).decode("ascii")
    except (OSError, TypeError, ValueError) as e:
        raise ValueError(f"Base64编码失败: {str(e)}") from e


def image_to_base64(file_path):
    """Return the Base64 encoding of the image file at ``file_path``.

    Raises:
        ValueError: if the file cannot be read.
    """
    return _file_to_base64(file_path)


def video_to_base64(file_path):
    """Return the Base64 encoding of the video file at ``file_path``.

    Raises:
        ValueError: if the file cannot be read.
    """
    return _file_to_base64(file_path)


def generate_video(
    context_scale,
    enable_safety_checker,
    enable_fast_mode,
    flow_shift,
    guidance_scale,
    images,
    negative_prompt,
    num_inference_steps,
    prompt,
    seed,
    size,
    task,
    video,
    session_id,
):
    """Submit a generation job and stream its progress.

    This is a generator used as a Gradio event handler. Every yielded item is
    a ``(status_text, media_path_or_None)`` pair: ``None`` while the job is in
    progress, the result URL on success, or the path of an error image on
    failure.

    Args:
        images: list of image file paths, or None.
        video: a video file path, a list/tuple whose first item is one, or None.
        seed: numeric seed; ``-1`` or ``None`` selects a random seed.
        session_id: key used for rate limiting and session bookkeeping.
        The remaining arguments are forwarded unchanged in the API payload.
    """
    safety_level = classify_prompt(prompt)
    if safety_level != 0:
        label = CLASS_NAMES.get(safety_level, f"class {safety_level}")
        error_img = create_error_image(label)
        yield f"❌ Blocked: {label}", error_img
        return

    if not rate_limiter.check(session_id):
        # Built from the limiter's setting so the message cannot drift from
        # the limit that is actually enforced.
        error_img = create_error_image(f"每小时限制{rate_limiter.max_requests}次请求")
        yield "❌ rate limit exceeded", error_img
        return

    session = session_manager.get_session(session_id)
    session['last_active'] = time.time()
    session['count'] += 1

    # The key must come from the environment / .env file, never from source.
    # NOTE(review): a key was previously committed in this file; it should be
    # revoked and replaced.
    api_key = os.getenv(API_KEY_ENV_VAR)
    if not api_key:
        error_img = create_error_image("API key not found")
        yield "❌ Error: Missing API Key", error_img
        return

    try:
        base64_images = [image_to_base64(img_path) for img_path in (images or [])]
    except Exception as e:
        error_img = create_error_image(str(e))
        yield f"❌failed to upload images: {str(e)}", error_img
        return

    # The video input may arrive as a single path or as a sequence of paths.
    if isinstance(video, (list, tuple)):
        video_path = video[0] if video else ""
    else:
        video_path = video or ""

    video_payload = ""
    if video_path:
        try:
            video_payload = video_to_base64(video_path)
        except Exception as e:
            error_img = create_error_image(str(e))
            yield f"❌ Failed to encode video: {str(e)}", error_img
            return

    if seed is None or int(seed) == -1:
        resolved_seed = random.randint(0, 999999)
    else:
        resolved_seed = int(seed)

    payload = {
        "context_scale": context_scale,
        "enable_fast_mode": enable_fast_mode,
        "enable_safety_checker": enable_safety_checker,
        "flow_shift": flow_shift,
        "guidance_scale": guidance_scale,
        "images": base64_images,
        "negative_prompt": negative_prompt,
        "num_inference_steps": num_inference_steps,
        "prompt": prompt,
        "seed": resolved_seed,
        "size": size,
        "task": task,
        "video": video_payload,
    }

    # Log a summary only: the raw payload embeds whole media files as Base64.
    loggable_payload = dict(
        payload,
        images=f"<{len(base64_images)} base64 image(s)>",
        video=f"<base64 video, {len(video_payload)} chars>" if video_payload else "",
    )
    logger.debug(
        "API request payload: %s",
        json.dumps(loggable_payload, indent=2, ensure_ascii=False),
    )

    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {api_key}",
    }

    try:
        response = requests.post(
            SUBMIT_URL,
            headers=headers,
            data=json.dumps(payload),
            timeout=SUBMIT_TIMEOUT_SECONDS,
        )
        if response.status_code != 200:
            error_img = create_error_image(response.text)
            yield f"❌ API Error ({response.status_code}): {response.text}", error_img
            return
        request_id = response.json()["data"]["id"]
    except Exception as e:
        # Handler boundary: report network and response-shape errors to the UI.
        logger.exception("Submitting the generation request failed")
        error_img = create_error_image(str(e))
        yield f"❌ Connection Error: {str(e)}", error_img
        return

    yield f"✅ Task ID (ID: {request_id})", None

    result_url = RESULT_URL_TEMPLATE.format(request_id=request_id)
    start_time = time.time()
    while True:
        time.sleep(POLL_INTERVAL_SECONDS)

        waited = time.time() - start_time
        if waited > POLL_TIMEOUT_SECONDS:
            timeout_message = f"Timed out after {waited:.0f}s waiting for task {request_id}"
            error_img = create_error_image(timeout_message)
            yield f"❌ {timeout_message}", error_img
            return

        try:
            response = requests.get(
                result_url, headers=headers, timeout=POLL_REQUEST_TIMEOUT_SECONDS
            )
            if response.status_code != 200:
                error_img = create_error_image(response.text)
                yield f"❌ 轮询错误 ({response.status_code}): {response.text}", error_img
                return

            data = response.json()["data"]
            status = data["status"]

            if status == "completed":
                elapsed = time.time() - start_time
                video_url = data['outputs'][0]
                session["history"].append(video_url)
                yield (f"🎉 完成! 耗时 {elapsed:.1f}秒\n"
                       f"下载链接: {video_url}"), video_url
                return
            elif status == "failed":
                error_img = create_error_image(data.get('error', '未知错误'))
                yield f"❌ 任务失败: {data.get('error', '未知错误')}", error_img
                return
            else:
                yield f"⏳ 状态: {status.capitalize()}...", None
        except Exception as e:
            # Handler boundary: report network and response-shape errors to the UI.
            logger.exception("Polling task %s failed", request_id)
            error_img = create_error_image(str(e))
            yield f"❌ 轮询失败: {str(e)}", error_img
            return


def cleanup_task():
    """Run forever, pruning expired sessions and rate-limit entries once per interval."""
    while True:
        session_manager.cleanup_sessions()
        rate_limiter.cleanup()
        time.sleep(CLEANUP_INTERVAL_SECONDS)


with gr.Blocks(
    theme=gr.themes.Soft(),
    css="""
    .video-preview {
        max-width: 600px !important;
    }
    .status-box {
        padding: 10px;
        border-radius: 5px;
        margin: 5px;
    }
    .safe {
        background: #e8f5e9;
        border: 1px solid #a5d6a7;
    }
    .warning {
        background: #fff3e0;
        border: 1px solid #ffcc80;
    }
    .error {
        background: #ffebee;
        border: 1px solid #ef9a9a;
    }
    #centered_button {
        align-self: center !important;
        height: fit-content !important;
        margin-top: 22px !important; /* 根据输入框高度微调 */
    }
    """
) as app:

    # Pass a callable so a fresh id is generated on every page load. A plain
    # string would be computed once at build time and shared by all visitors,
    # which would also make them share one rate-limit bucket.
    session_id = gr.State(lambda: str(uuid.uuid4()))

    gr.Markdown("# 🌊Wan-2.1-14B-Vace Run On [WaveSpeedAI](https://wavespeed.ai/)")
    gr.Markdown("""VACE is an all-in-one model designed for video creation and editing. It encompasses various tasks, including reference-to-video generation (R2V), video-to-video editing (V2V), and masked video-to-video editing (MV2V), allowing users to compose these tasks freely.
This functionality enables users to explore diverse possibilities and streamlines their workflows effectively, offering a range of capabilities, such as Move-Anything, Swap-Anything, Reference-Anything, Expand-Anything, Animate-Anything, and more.""")

    with gr.Row():
        with gr.Column(scale=1):
            with gr.Row():
                images = gr.File(
                    label="upload image",
                    file_count="multiple",
                    file_types=["image"],
                    type="filepath",
                    elem_id="image-uploader",
                    scale=1,
                )
                video = gr.Video(
                    label="Input Video",
                    format="mp4",
                    sources=["upload"],
                    scale=1,
                )
            prompt = gr.Textbox(label="Prompt", lines=5, placeholder="Prompt")
            negative_prompt = gr.Textbox(label="Negative Prompt", lines=2)
            with gr.Row():
                size = gr.Dropdown(["832*480", "480*832"], value="832*480", label="Size")
                task = gr.Dropdown(["depth", "pose"], value="depth", label="Task")
            with gr.Row():
                num_inference_steps = gr.Slider(1, 100, value=30, step=1, label="Inference Steps")
                context_scale = gr.Slider(0, 2, value=1, step=0.1, label="Context Scale")
            with gr.Row():
                guidance = gr.Slider(1, 20, value=5, step=0.1, label="Guidance_Scale")
                flow_shift = gr.Slider(1, 20, value=16, step=1, label="Shift")
            with gr.Row():
                seed = gr.Number(-1, label="Seed")
                random_seed_btn = gr.Button("Random🎲Seed", variant="secondary", elem_id="centered_button")
            with gr.Row():
                enable_safety_checker = gr.Checkbox(True, label="Enable Safety Checker", interactive=True)
                enable_fast_mode = gr.Checkbox(True, label="To enable the fast mode, please visit Wave Speed AI", interactive=False)

        with gr.Column(scale=1):
            # NOTE(review): on failure generate_video yields a JPEG path into
            # this Video component; confirm the installed Gradio version
            # handles that as intended.
            video_output = gr.Video(label="Video Output", format="mp4", interactive=False, elem_classes=["video-preview"])
            generate_btn = gr.Button("Generate", variant="primary")
            status_output = gr.Textbox(label="status", interactive=False, lines=4)

    # TODO: disabled example gallery, kept for reference.
    # gr.Examples(
    #     examples=[
    #         [
    #             "The elegant lady carefully selects bags in the boutique, and she shows the charm of a mature woman in a black slim dress with a pearl necklace, as well as her pretty face. Holding a vintage-inspired blue leather half-moon handbag, she is carefully observing its craftsmanship and texture. The interior of the store is a haven of sophistication and luxury. Soft, ambient lighting casts a warm glow over the polished wooden floors",
    #             [
    #                 "https://d2g64w682n9w0w.cloudfront.net/media/ec44bbf6abac4c25998dd2c4af1a46a7/images/1747413751234102420_md9ywspl.png",
    #                 "https://d2g64w682n9w0w.cloudfront.net/media/ec44bbf6abac4c25998dd2c4af1a46a7/images/1747413586520964413_7bkgc9ol.png"
    #             ]
    #         ]
    #     ],
    #     inputs=[prompt, images],
    # )

    random_seed_btn.click(
        fn=lambda: random.randint(0, 999999),
        outputs=seed
    )

    generate_btn.click(
        generate_video,
        inputs=[
            context_scale,
            enable_safety_checker,
            enable_fast_mode,
            flow_shift,
            guidance,
            images,
            negative_prompt,
            num_inference_steps,
            prompt,
            seed,
            size,
            task,
            video,
            session_id,
        ],
        outputs=[status_output, video_output]
    )

# Configured after the model is loaded, as in the original ordering, so the
# DEBUG-level root logger does not capture model-loading chatter.
logging.basicConfig(
    level=logging.DEBUG,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler("gradio_app.log"),
        logging.StreamHandler()
    ]
)
gradio_logger = logging.getLogger("gradio")
gradio_logger.setLevel(logging.INFO)

if __name__ == "__main__":
    threading.Thread(target=cleanup_task, daemon=True).start()
    app.queue(max_size=2).launch(
        server_name="0.0.0.0",
        max_threads=10,
        share=False
    )