Spaces:
Paused
Paused
import os | |
import requests | |
import json | |
import time | |
import random | |
import base64 | |
import uuid | |
import threading | |
from pathlib import Path | |
from dotenv import load_dotenv | |
import gradio as gr | |
import torch | |
import logging | |
from PIL import Image, ImageDraw, ImageFont | |
from transformers import AutoTokenizer, AutoModelForSequenceClassification | |
# Pull variables from a local .env file into os.environ before anything reads them.
load_dotenv()

# Hugging Face checkpoint used to screen prompts before they are sent to the video API.
MODEL_URL = "TostAI/nsfw-text-detection-large"

# Classifier output index -> user-facing label (generate_video only lets index 0 through).
CLASS_NAMES = dict(enumerate(("✅ SAFE", "⚠️ QUESTIONABLE", "🚫 UNSAFE")))

# Tokenizer and classifier are loaded once at import time and shared by every request.
tokenizer, model = (
    AutoTokenizer.from_pretrained(MODEL_URL),
    AutoModelForSequenceClassification.from_pretrained(MODEL_URL),
)
class SessionManager:
    """Process-wide registry of per-session state, keyed by session id.

    All state lives in class attributes, so the class itself and every
    ``SessionManager()`` instance see the same sessions. Each session is a dict
    with ``count`` (requests made), ``history`` (result URLs appended by the
    caller) and ``last_active`` (epoch seconds; the caller refreshes it).
    """

    _instances = {}
    _lock = threading.Lock()

    @classmethod
    def get_session(cls, session_id):
        """Return the state dict for *session_id*, creating it on first use.

        The same dict object is returned on every call, so callers may mutate it.
        """
        with cls._lock:
            session = cls._instances.get(session_id)
            if session is None:
                session = cls._instances[session_id] = {
                    'count': 0,
                    'history': [],
                    'last_active': time.time(),
                }
            return session

    @classmethod
    def cleanup_sessions(cls, max_idle=3600):
        """Drop every session idle for more than *max_idle* seconds (default: one hour)."""
        with cls._lock:
            cutoff = time.time() - max_idle
            # Collect first: deleting while iterating the dict would raise RuntimeError.
            expired = [key for key, state in cls._instances.items() if state['last_active'] < cutoff]
            for key in expired:
                del cls._instances[key]
class RateLimiter:
    """Thread-safe fixed-window request limiter keyed by client id.

    Each client may make ``max_requests`` calls per ``window`` seconds. The
    window opens at the client's first call and resets once it has elapsed.
    The defaults (10 calls per hour) match the previously hard-coded limits.
    """

    def __init__(self, max_requests=10, window=3600):
        self.max_requests = max_requests
        self.window = window
        # client_id -> {'count': calls in the current window, 'reset': window end (epoch s)}
        self.clients = {}
        self.lock = threading.Lock()

    def check(self, client_id):
        """Record a call for *client_id*.

        Returns True if the call is allowed, or False if the client has used up
        its quota for the current window.
        """
        with self.lock:
            now = time.time()
            # Forget clients whose window has ended, so the dict cannot grow without
            # bound as new session ids keep arriving.
            expired = [cid for cid, entry in self.clients.items() if now > entry['reset']]
            for cid in expired:
                del self.clients[cid]

            entry = self.clients.get(client_id)
            if entry is None:
                self.clients[client_id] = {'count': 1, 'reset': now + self.window}
                return True
            if entry['count'] >= self.max_requests:
                return False
            entry['count'] += 1
            return True
# Module-wide singletons shared by every Gradio request handled by this process.
session_manager, rate_limiter = SessionManager(), RateLimiter()
def create_error_image(message):
    """Render *message* onto an 832x480 light-red JPEG and return the file's path.

    Each call writes a new, uniquely named file in the system temp directory.
    Concurrent requests therefore no longer overwrite one shared ``error.jpg``.
    Messages longer than 60 characters are truncated with an ellipsis.
    *message* is converted with ``str()``, so ``None`` or other non-string
    values do not crash the renderer.
    """
    import tempfile

    message = str(message)
    img = Image.new("RGB", (832, 480), "#ffdddd")
    try:
        font = ImageFont.truetype("arial.ttf", 24)
    except OSError:
        # Arial could not be loaded on this host; fall back to PIL's built-in font.
        font = ImageFont.load_default()
    # Prefix consistently; previously short messages were drawn without "Error:".
    text = f"Error: {message[:60]}..." if len(message) > 60 else f"Error: {message}"
    ImageDraw.Draw(img).text((50, 200), text, fill="#ff0000", font=font)

    fd, path = tempfile.mkstemp(prefix="error_", suffix=".jpg")
    with os.fdopen(fd, "wb") as handle:
        img.save(handle, format="JPEG")
    return path
def classify_prompt(prompt):
    """Score *prompt* with the NSFW text classifier and return the winning class index.

    The input is truncated to 512 tokens. The returned int is a key of
    CLASS_NAMES.
    """
    encoded = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        logits = model(**encoded).logits
    return torch.argmax(logits).item()
def image_to_base64(file_path):
    """Read the file at *file_path* and return its bytes as a base64 string.

    Raises:
        ValueError: if the file cannot be opened or read. The underlying error
            is chained as ``__cause__``.
    """
    try:
        with open(file_path, "rb") as image_file:
            raw_data = image_file.read()
    except Exception as e:
        raise ValueError(f"Base64编码失败: {str(e)}") from e
    # b64encode always emits '='-padded output (length % 4 == 0), so no manual padding is needed.
    return base64.b64encode(raw_data).decode('utf-8')
def video_to_base64(file_path):
    """Convert the video file at *file_path* to a base64 string.

    Raises:
        ValueError: if the file cannot be opened or read. The underlying error
            is chained as ``__cause__``.
    """
    try:
        with open(file_path, "rb") as video_file:
            raw_data = video_file.read()
    except Exception as e:
        raise ValueError(f"Base64编码失败: {str(e)}") from e
    # b64encode always emits '='-padded output (length % 4 == 0), so no manual padding is needed.
    return base64.b64encode(raw_data).decode('utf-8')
def generate_video(
    context_scale,
    enable_safety_checker,
    enable_fast_mode,
    flow_shift,
    guidance_scale,
    images,
    negative_prompt,
    num_inference_steps,
    prompt,
    seed,
    size,
    task,
    video,
    session_id,
):
    """Submit a WaveSpeed Wan-2.1-14B VACE job and stream its progress.

    This is a generator used as a Gradio event handler. Every ``yield`` is a
    ``(status_text, media)`` pair: media is None while the job is running, the
    result URL on success, and an error-image path (from create_error_image)
    on failure.

    Steps:
      1. Screen the prompt with the local classifier.
      2. Check the API key.
      3. Apply the per-session rate limit.
      4. Base64-encode the uploaded images and video.
      5. POST the job.
      6. Poll the result endpoint until it completes, fails, or max_wait elapses.

    The arguments mirror the Gradio inputs wired up in the UI. session_id keys
    both the rate limiter and the session registry. A seed of -1 (or empty)
    means "pick a random seed".

    NOTE(review): error media is a JPEG path sent to a gr.Video output;
    confirm it renders as intended.
    """
    poll_interval = 0.5        # seconds between result polls
    max_wait = 1800            # NOTE(review): assumed upper bound on render time; tune as needed
    post_timeout = (10, 120)   # (connect, read) seconds for the job submission
    poll_timeout = (10, 30)    # (connect, read) seconds for each status poll

    # 1. Screen the prompt locally before spending any API quota.
    safety_level = classify_prompt(prompt)
    if safety_level != 0:
        error_img = create_error_image(CLASS_NAMES[safety_level])
        yield f"❌ Blocked: {CLASS_NAMES[safety_level]}", error_img
        return

    # 2. Read the key from the environment (load_dotenv() runs at import, so a
    #    .env file works too). A literal key in source is published with the code;
    #    the previously committed key should be revoked.
    #    NOTE(review): the variable name WAVESPEED_API_KEY is an assumption.
    API_KEY = os.getenv("WAVESPEED_API_KEY")
    if not API_KEY:
        error_img = create_error_image("API key not found")
        yield "❌ Error: Missing API Key", error_img
        return

    # 3. Per-session quota; RateLimiter's default is 10 requests per hour.
    if not rate_limiter.check(session_id):
        error_img = create_error_image("每小时限制10次请求")
        yield "❌ rate limit exceeded", error_img
        return

    session = session_manager.get_session(session_id)
    session['last_active'] = time.time()
    session['count'] += 1

    # 4. Encode uploads; `images` is None when nothing was uploaded.
    try:
        base64_images = [image_to_base64(img_path) for img_path in images or []]
    except Exception as e:
        error_img = create_error_image(str(e))
        yield f"❌failed to upload images: {str(e)}", error_img
        return

    # Accept either a single path or a sequence whose first item is the path.
    video_path = video
    if isinstance(video_path, (list, tuple)):
        video_path = video_path[0] if video_path else None
    video_payload = ""
    if video_path:
        try:
            video_payload = video_to_base64(video_path)
        except Exception as e:
            error_img = create_error_image(str(e))
            yield f"❌ Failed to encode video: {str(e)}", error_img
            return

    # gr.Number may deliver a float, or None when cleared; -1 / empty means random.
    if seed is None or seed == -1:
        seed = random.randint(0, 999999)

    payload = {
        "context_scale": context_scale,
        "enable_fast_mode": enable_fast_mode,
        "enable_safety_checker": enable_safety_checker,
        "flow_shift": flow_shift,
        "guidance_scale": guidance_scale,
        "images": base64_images,
        "negative_prompt": negative_prompt,
        "num_inference_steps": num_inference_steps,
        "prompt": prompt,
        "seed": int(seed),
        "size": size,
        "task": task,
        "video": video_payload,
    }
    # Log the request without the base64 blobs, which can be many megabytes each.
    loggable = dict(
        payload,
        images=[f"<{len(b)} base64 chars>" for b in base64_images],
        video=f"<{len(video_payload)} base64 chars>" if video_payload else "",
    )
    logging.debug("API request payload: %s", json.dumps(loggable, indent=2, ensure_ascii=False))

    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {API_KEY}",
    }

    # 5. Submit the job.
    try:
        response = requests.post(
            "https://api.wavespeed.ai/api/v2/wavespeed-ai/wan-2.1-14b-vace",
            headers=headers,
            data=json.dumps(payload),
            timeout=post_timeout,
        )
        if response.status_code != 200:
            error_img = create_error_image(response.text)
            yield f"❌ API Error ({response.status_code}): {response.text}", error_img
            return
        request_id = response.json()["data"]["id"]
    except Exception as e:
        error_img = create_error_image(str(e))
        yield f"❌ Connection Error: {str(e)}", error_img
        return
    yield f"✅ Task ID (ID: {request_id})", None

    # 6. Poll until done, but never longer than max_wait, so a stuck job cannot
    #    occupy a queue slot forever.
    result_url = f"https://api.wavespeed.ai/api/v2/predictions/{request_id}/result"
    start_time = time.time()
    while time.time() - start_time < max_wait:
        time.sleep(poll_interval)
        try:
            response = requests.get(result_url, headers=headers, timeout=poll_timeout)
            if response.status_code != 200:
                error_img = create_error_image(response.text)
                yield f"❌ 轮询错误 ({response.status_code}): {response.text}", error_img
                return
            data = response.json()["data"]
            status = data["status"]
            if status == "completed":
                elapsed = time.time() - start_time
                video_url = data['outputs'][0]
                session["history"].append(video_url)
                yield (f"🎉 完成! 耗时 {elapsed:.1f}秒\n"
                       f"下载链接: {video_url}"), video_url
                return
            if status == "failed":
                # `or` also covers an explicit null/empty error field from the API.
                error_msg = data.get('error') or '未知错误'
                error_img = create_error_image(str(error_msg))
                yield f"❌ 任务失败: {error_msg}", error_img
                return
            yield f"⏳ 状态: {status.capitalize()}...", None
        except Exception as e:
            error_img = create_error_image(str(e))
            yield f"❌ 轮询失败: {str(e)}", error_img
            return

    error_img = create_error_image(f"Timed out after {max_wait}s")
    yield f"❌ Timed out after {max_wait}s waiting for task {request_id}", error_img
def cleanup_task():
    """Purge idle sessions forever: once immediately, then once per hour.

    Intended to run on a background thread; it never returns.
    """
    purge = session_manager.cleanup_sessions
    interval_seconds = 3600
    running = True
    while running:
        purge()
        time.sleep(interval_seconds)
with gr.Blocks(
    theme=gr.themes.Soft(),
    # CSS only supports /* */ comments; a "#" there is parsed as a bogus declaration.
    css="""
    .video-preview { max-width: 600px !important; }
    .status-box { padding: 10px; border-radius: 5px; margin: 5px; }
    .safe { background: #e8f5e9; border: 1px solid #a5d6a7; }
    .warning { background: #fff3e0; border: 1px solid #ffcc80; }
    .error { background: #ffebee; border: 1px solid #ef9a9a; }
    #centered_button {
        align-self: center !important;
        height: fit-content !important;
        margin-top: 22px !important; /* 根据输入框高度微调 */
    }
    """
) as app:
    # A callable initial value is evaluated on every page load, so each browser
    # session gets its own id. str(uuid.uuid4()) ran once at import time, which
    # made all visitors share one id, one rate-limit bucket and one history.
    session_id = gr.State(lambda: str(uuid.uuid4()))

    gr.Markdown("# 🌊Wan-2.1-14B-Vace Run On [WaveSpeedAI](https://wavespeed.ai/)")
    gr.Markdown("""VACE is an all-in-one model designed for video creation and editing. It encompasses various tasks, including reference-to-video generation (R2V), video-to-video editing (V2V), and masked video-to-video editing (MV2V), allowing users to compose these tasks freely. This functionality enables users to explore diverse possibilities and streamlines their workflows effectively, offering a range of capabilities, such as Move-Anything, Swap-Anything, Reference-Anything, Expand-Anything, Animate-Anything, and more.""")

    with gr.Row():
        # Left column: inputs and generation parameters.
        with gr.Column(scale=1):
            with gr.Row():
                images = gr.File(label="upload image", file_count="multiple", file_types=["image"], type="filepath", elem_id="image-uploader",
                                 scale=1)
                video = gr.Video(label="Input Video", format="mp4", sources=["upload"],
                                 scale=1)
            prompt = gr.Textbox(label="Prompt", lines=5, placeholder="Prompt")
            negative_prompt = gr.Textbox(label="Negative Prompt", lines=2)
            with gr.Row():
                size = gr.Dropdown(["832*480", "480*832"], value="832*480", label="Size")
                task = gr.Dropdown(["depth", "pose"], value="depth", label="Task")
            with gr.Row():
                num_inference_steps = gr.Slider(1, 100, value=30, step=1, label="Inference Steps")
                context_scale = gr.Slider(0, 2, value=1, step=0.1, label="Context Scale")
            with gr.Row():
                guidance = gr.Slider(1, 20, value=5, step=0.1, label="Guidance_Scale")
                flow_shift = gr.Slider(1, 20, value=16, step=1, label="Shift")
            with gr.Row():
                seed = gr.Number(-1, label="Seed")
                random_seed_btn = gr.Button("Random🎲Seed", variant="secondary", elem_id="centered_button")
            with gr.Row():
                enable_safety_checker = gr.Checkbox(True, label="Enable Safety Checker", interactive=True)
                enable_fast_mode = gr.Checkbox(True, label="To enable the fast mode, please visit Wave Speed AI", interactive=False)
        # Right column: result preview, trigger button and status text.
        with gr.Column(scale=1):
            video_output = gr.Video(label="Video Output", format="mp4", interactive=False, elem_classes=["video-preview"])
            generate_btn = gr.Button("Generate", variant="primary")
            status_output = gr.Textbox(label="status", interactive=False, lines=4)

    # gr.Examples(
    #     examples=[
    #         [
    #             "The elegant lady carefully selects bags in the boutique, and she shows the charm of a mature woman in a black slim dress with a pearl necklace, as well as her pretty face. Holding a vintage-inspired blue leather half-moon handbag, she is carefully observing its craftsmanship and texture. The interior of the store is a haven of sophistication and luxury. Soft, ambient lighting casts a warm glow over the polished wooden floors",
    #             [
    #                 "https://d2g64w682n9w0w.cloudfront.net/media/ec44bbf6abac4c25998dd2c4af1a46a7/images/1747413751234102420_md9ywspl.png",
    #                 "https://d2g64w682n9w0w.cloudfront.net/media/ec44bbf6abac4c25998dd2c4af1a46a7/images/1747413586520964413_7bkgc9ol.png"
    #             ]
    #         ]
    #     ],
    #     inputs=[prompt, images],
    # )

    random_seed_btn.click(
        fn=lambda: random.randint(0, 999999),
        outputs=seed
    )
    # Input order must match generate_video's positional parameters.
    generate_btn.click(
        generate_video,
        inputs=[
            context_scale,
            enable_safety_checker,
            enable_fast_mode,
            flow_shift,
            guidance,
            images,
            negative_prompt,
            num_inference_steps,
            prompt,
            seed,
            size,
            task,
            video,
            session_id,
        ],
        outputs=[status_output, video_output]
    )
# Root logging: everything at DEBUG and above goes to gradio_app.log and to stderr.
logging.basicConfig(
    handlers=(logging.FileHandler("gradio_app.log"), logging.StreamHandler()),
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.DEBUG,
)

# Keep Gradio's own logger quieter than the app's DEBUG default.
gradio_logger = logging.getLogger("gradio")
gradio_logger.setLevel(logging.INFO)
if __name__ == "__main__":
    # Background session purge; daemon=True so it never blocks interpreter exit.
    cleanup_thread = threading.Thread(target=cleanup_task, daemon=True)
    cleanup_thread.start()
    # Listen on all interfaces; the request queue is capped at 2 pending events.
    app.queue(max_size=2).launch(server_name="0.0.0.0", max_threads=10, share=False)