"""Gradio demo for MMAudio.

Exposes three tabs:

* Video Search   - queries the Pixabay video API and shows the results.
* Video-to-Audio - generates a soundtrack for an uploaded video.
* Text-to-Audio  - generates audio from a text prompt only.

Prompts containing Hangul are translated to English with a Helsinki-NLP
ko->en model before they are used, when that model could be loaded.
"""
import atexit
import contextlib
import json
import logging
import os
import subprocess
import sys
import tempfile
from datetime import datetime
from pathlib import Path
from typing import List, Optional, Tuple

# NOTE(review): `spaces` was the very first import in the original file and is
# deliberately kept ahead of torch here; presumably Hugging Face ZeroGPU needs
# it imported before torch - confirm before reordering.
import spaces
import gradio as gr
import torch
import torchaudio
import requests
from transformers import pipeline
import numpy as np
from einops import rearrange
import cv2
from scipy.io import wavfile
import librosa

# Temporary workaround carried over from the original file: set an environment
# flag intended to bypass the torch.load safety check.
# NOTE(review): this is set after `transformers` has been imported - confirm
# that the flag is actually honoured at this point.
os.environ["TRANSFORMERS_ALLOW_UNSAFE_DESERIALIZATION"] = "1"

try:
    import mmaudio
except ImportError:
    # Install the local package with the interpreter that is running this
    # script, so the package lands in the environment we then import from.
    # check=False: a failed install surfaces as the ImportError below.
    subprocess.run([sys.executable, "-m", "pip", "install", "-e", "."],
                   check=False)
    import mmaudio

from mmaudio.eval_utils import (ModelConfig, all_model_cfg, generate,
                                load_video, make_video, setup_eval_logging)
from mmaudio.model.flow_matching import FlowMatching
from mmaudio.model.networks import MMAudio, get_my_mmaudio
from mmaudio.model.sequence_config import SequenceConfig
from mmaudio.model.utils.features_utils import FeaturesUtils

# Logging setup.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
log = logging.getLogger()

# Device setup: prefer CUDA (with TF32 and cuDNN autotuning), else CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    torch.backends.cudnn.benchmark = True
else:
    device = torch.device("cpu")

dtype = torch.bfloat16

# Model setup.
model: ModelConfig = all_model_cfg['large_44k_v2']
model.download_if_needed()
output_dir = Path('./output/gradio')

setup_eval_logging()

# Translator setup. The first attempt asks for the fast tokenizer and forbids
# remote code; if that fails, a plain pipeline is tried; if that fails too,
# translation is disabled (translator is None).
try:
    translator = pipeline("translation",
                          model="Helsinki-NLP/opus-mt-ko-en",
                          device="cpu",
                          use_fast=True,
                          trust_remote_code=False)
except Exception as e:
    log.warning(f"Failed to load translation model with safetensors: {e}")
    try:
        translator = pipeline("translation",
                              model="Helsinki-NLP/opus-mt-ko-en",
                              device="cpu")
    except Exception as e2:
        log.error(f"Failed to load translation model: {e2}")
        translator = None

# NOTE(review): SECURITY - an API key is committed in source. The environment
# variable now takes precedence; the literal is kept only so existing
# deployments keep working. Rotate the key and drop the fallback.
PIXABAY_API_KEY = os.environ.get("PIXABAY_API_KEY",
                                 "33492762-a28a596ec4f286f84cd328b17")

# Seconds to wait for Pixabay before giving up, so a stalled request cannot
# hang the UI indefinitely.
PIXABAY_TIMEOUT_SEC = 15

# Generated files carry this prefix so cleanup can tell them apart from files
# that other programs keep in the shared system temp directory.
TEMP_FILE_PREFIX = "mmaudio_"
TEMP_FILE_SUFFIXES = ('.mp4', '.flac')


def _new_temp_path(suffix: str) -> str:
    """Create an empty, prefixed temp file and return its path.

    The file is closed before returning and is not deleted automatically;
    cleanup_temp_files removes it at interpreter exit.
    """
    with tempfile.NamedTemporaryFile(delete=False,
                                     suffix=suffix,
                                     prefix=TEMP_FILE_PREFIX) as tmp:
        return tmp.name


def cleanup_temp_files():
    """Delete the temp files this app created (registered with atexit).

    Only files in the system temp directory whose name starts with
    TEMP_FILE_PREFIX and ends with one of TEMP_FILE_SUFFIXES are removed.
    Failures are logged at debug level and otherwise ignored (best effort).
    """
    temp_dir = tempfile.gettempdir()
    try:
        names = os.listdir(temp_dir)
    except OSError as e:
        log.debug("Could not list temp directory %s: %s", temp_dir, e)
        return
    for name in names:
        if not name.startswith(TEMP_FILE_PREFIX):
            continue
        if not name.endswith(TEMP_FILE_SUFFIXES):
            continue
        try:
            os.remove(os.path.join(temp_dir, name))
        except OSError as e:
            log.debug("Could not remove temp file %s: %s", name, e)


atexit.register(cleanup_temp_files)


def get_model() -> tuple[MMAudio, FeaturesUtils, SequenceConfig]:
    """Build the network and feature utilities on the selected device.

    Returns:
        (net, feature_utils, seq_cfg): the MMAudio network with weights
        loaded and in eval mode, the FeaturesUtils instance in eval mode, and
        the sequence configuration taken from the model config.
    """
    # torch.cuda.device only makes sense for a CUDA device; on the CPU
    # fallback use a no-op context instead.
    if device.type == "cuda":
        device_ctx = torch.cuda.device(device)
    else:
        device_ctx = contextlib.nullcontext()

    with device_ctx:
        seq_cfg = model.seq_cfg
        net: MMAudio = get_my_mmaudio(model.model_name).to(device, dtype).eval()
        net.load_weights(torch.load(model.model_path,
                                    map_location=device,
                                    weights_only=True))
        log.info(f'Loaded weights from {model.model_path}')

        feature_utils = FeaturesUtils(
            tod_vae_ckpt=model.vae_path,
            synchformer_ckpt=model.synchformer_ckpt,
            enable_conditions=True,
            mode=model.mode,
            bigvgan_vocoder_ckpt=model.bigvgan_16k_path,
            need_vae_encoder=False
        ).to(device, dtype).eval()

    return net, feature_utils, seq_cfg


net, feature_utils, seq_cfg = get_model()


def _contains_hangul(text: str) -> bool:
    """Return True if text has a character from one of the Hangul blocks."""
    for char in text:
        code = ord(char)
        if (0x1100 <= code <= 0x11FF        # Hangul Jamo
                or 0x3131 <= code <= 0x318E  # Hangul Compatibility Jamo
                or 0xAC00 <= code <= 0xD7A3):  # Hangul Syllables
            return True
    return False


def translate_prompt(text):
    """Translate text to English when it contains Hangul.

    Returns the input unchanged when no translator is loaded, when the text
    is empty or has no Hangul characters, or when translation fails (the
    failure is logged).
    """
    try:
        if translator is None:
            return text
        if text and _contains_hangul(text):
            with torch.no_grad():
                return translator(text)[0]['translation_text']
        return text
    except Exception as e:
        log.error("Translation error: %s", e)
        return text


@torch.no_grad()
def search_videos(query):
    """Gradio handler: translate the query, then search Pixabay.

    Returns a list of video URLs, or an empty list on any error.
    """
    try:
        query = translate_prompt(query)
        return search_pixabay_videos(query, PIXABAY_API_KEY)
    except Exception as e:
        log.error("Video search error: %s", e)
        return []


def _pick_video_url(hit):
    """Return the URL of the largest available rendition of a hit, or None."""
    renditions = hit.get('videos') if isinstance(hit, dict) else None
    if not isinstance(renditions, dict):
        return None
    # Prefer 'large' as before, but fall back to smaller renditions when the
    # larger one is missing or has an empty URL.
    for size in ('large', 'medium', 'small', 'tiny'):
        rendition = renditions.get(size)
        if isinstance(rendition, dict) and rendition.get('url'):
            return rendition['url']
    return None


def search_pixabay_videos(query, api_key):
    """Query the Pixabay video API and return a list of video URLs.

    Args:
        query: search text sent as the ``q`` parameter.
        api_key: Pixabay API key sent as the ``key`` parameter.

    Returns:
        A list with one URL per usable hit (at most 40 hits are requested).
        An empty list is returned on a non-200 response or on any error;
        this function does not raise.
    """
    base_url = "https://pixabay.com/api/videos/"
    params = {
        "key": api_key,
        "q": query,
        "per_page": 40
    }
    # Broad catch on purpose: this is the boundary to an external service and
    # callers rely on getting a list back rather than an exception.
    try:
        response = requests.get(base_url, params=params,
                                timeout=PIXABAY_TIMEOUT_SEC)
        if response.status_code != 200:
            log.warning("Pixabay API returned HTTP %s", response.status_code)
            return []
        data = response.json()
        urls = []
        for hit in data.get('hits', []):
            url = _pick_video_url(hit)
            if url:
                urls.append(url)
        return urls
    except Exception as e:
        log.error("Pixabay API error: %s", e)
        return []


def _generate_audio(clip_frames, sync_frames, prompt, negative_prompt,
                    seed, num_steps, cfg_strength, duration):
    """Run one generation pass and return the first generated audio tensor.

    Shared by video_to_audio and text_to_audio. Translates both prompts,
    seeds a generator on the module-level device, updates the shared seq_cfg
    and network sequence lengths for ``duration``, and calls ``generate``.
    The result is converted to float and moved to the CPU.
    """
    prompt = translate_prompt(prompt)
    negative_prompt = translate_prompt(negative_prompt)

    # gr.Number may deliver floats; the generator seed and the step count
    # must be integers.
    rng = torch.Generator(device=device)
    rng.manual_seed(int(seed))
    fm = FlowMatching(min_sigma=0, inference_mode='euler',
                      num_steps=int(num_steps))

    seq_cfg.duration = duration
    net.update_seq_lengths(seq_cfg.latent_seq_len,
                           seq_cfg.clip_seq_len,
                           seq_cfg.sync_seq_len)

    audios = generate(clip_frames,
                      sync_frames,
                      [prompt],
                      negative_text=[negative_prompt],
                      feature_utils=feature_utils,
                      net=net,
                      fm=fm,
                      rng=rng,
                      cfg_strength=cfg_strength)
    return audios.float().cpu()[0]


@spaces.GPU
@torch.inference_mode()
def video_to_audio(video: gr.Video, prompt: str, negative_prompt: str,
                   seed: int, num_steps: int, cfg_strength: float,
                   duration: float):
    """Gradio handler: generate audio for a video and mux it into an .mp4.

    Args:
        video: the uploaded video as passed in by the gr.Video input.
        prompt: text prompt (translated if it contains Hangul).
        negative_prompt: negative text prompt (translated likewise).
        seed: seed for the random generator.
        num_steps: number of flow-matching steps.
        cfg_strength: guidance strength passed to ``generate``.
        duration: requested duration in seconds; the value returned by
            ``load_video`` is the one actually used.

    Returns:
        Path of the generated .mp4 file in the system temp directory.

    Raises:
        gr.Error: if no video was provided.
    """
    if video is None:
        raise gr.Error("Please provide an input video.")

    clip_frames, sync_frames, duration = load_video(video, duration)
    # Add a leading dimension of size 1 to both frame tensors.
    clip_frames = clip_frames.unsqueeze(0)
    sync_frames = sync_frames.unsqueeze(0)

    audio = _generate_audio(clip_frames, sync_frames, prompt, negative_prompt,
                            seed, num_steps, cfg_strength, duration)

    video_save_path = _new_temp_path('.mp4')
    make_video(video,
               video_save_path,
               audio,
               sampling_rate=seq_cfg.sampling_rate,
               duration_sec=seq_cfg.duration)
    return video_save_path


@spaces.GPU
@torch.inference_mode()
def text_to_audio(prompt: str, negative_prompt: str, seed: int,
                  num_steps: int, cfg_strength: float, duration: float):
    """Gradio handler: generate audio from text only and save it as .flac.

    Args:
        prompt: text prompt (translated if it contains Hangul).
        negative_prompt: negative text prompt (translated likewise).
        seed: seed for the random generator.
        num_steps: number of flow-matching steps.
        cfg_strength: guidance strength passed to ``generate``.
        duration: duration in seconds.

    Returns:
        Path of the generated .flac file in the system temp directory.
    """
    # No visual conditioning: both frame inputs are None.
    audio = _generate_audio(None, None, prompt, negative_prompt,
                            seed, num_steps, cfg_strength, duration)

    audio_save_path = _new_temp_path('.flac')
    torchaudio.save(audio_save_path, audio, seq_cfg.sampling_rate)
    return audio_save_path


# CSS styles.
custom_css = """
.gradio-container {
    background: linear-gradient(45deg, #1a1a1a, #2a2a2a);
    border-radius: 15px;
    box-shadow: 0 8px 32px rgba(0,0,0,0.3);
    color: #e0e0e0;
}
.input-container, .output-container {
    background: rgba(40, 40, 40, 0.95);
    backdrop-filter: blur(10px);
    border-radius: 10px;
    padding: 20px;
    transform-style: preserve-3d;
    transition: transform 0.3s ease;
    border: 1px solid rgba(255, 255, 255, 0.1);
}
.input-container:hover {
    transform: translateZ(20px);
    box-shadow: 0 8px 32px rgba(0,0,0,0.5);
}
.gallery-item {
    transition: transform 0.3s ease;
    border-radius: 8px;
    overflow: hidden;
    background: #2a2a2a;
}
.gallery-item:hover {
    transform: scale(1.05);
    box-shadow: 0 4px 15px rgba(0,0,0,0.4);
}
.tabs {
    background: rgba(30, 30, 30, 0.95);
    border-radius: 10px;
    padding: 10px;
    border: 1px solid rgba(255, 255, 255, 0.05);
}
button {
    background: linear-gradient(45deg, #2196F3, #1976D2);
    border: none;
    border-radius: 5px;
    transition: all 0.3s ease;
    color: white;
}
button:hover {
    transform: translateY(-2px);
    box-shadow: 0 4px 15px rgba(33,150,243,0.3);
}
textarea, input[type="text"], input[type="number"] {
    background: rgba(30, 30, 30, 0.95) !important;
    border: 1px solid rgba(255, 255, 255, 0.1) !important;
    color: #e0e0e0 !important;
    border-radius: 5px !important;
}
label {
    color: #e0e0e0 !important;
}
.gallery {
    background: rgba(30, 30, 30, 0.95);
    padding: 15px;
    border-radius: 10px;
    border: 1px solid rgba(255, 255, 255, 0.05);
}
"""

css = """
footer {
    visibility: hidden;
}
""" + custom_css

# Gradio interface definitions. Seed and Steps use precision=0 so the
# handlers receive integers.
text_to_audio_tab = gr.Interface(
    fn=text_to_audio,
    inputs=[
        gr.Textbox(label="Prompt(한글지원)" if translator else "Prompt"),
        gr.Textbox(label="Negative Prompt"),
        gr.Number(label="Seed", value=0, precision=0),
        gr.Number(label="Steps", value=25, precision=0),
        gr.Number(label="Guidance Scale", value=4.5),
        gr.Number(label="Duration (sec)", value=8),
    ],
    outputs=gr.Audio(label="Generated Audio"),
    css=custom_css
)

video_to_audio_tab = gr.Interface(
    fn=video_to_audio,
    inputs=[
        gr.Video(label="Input Video"),
        gr.Textbox(label="Prompt(한글지원)" if translator else "Prompt"),
        gr.Textbox(label="Negative Prompt", value="music"),
        gr.Number(label="Seed", value=0, precision=0),
        gr.Number(label="Steps", value=25, precision=0),
        gr.Number(label="Guidance Scale", value=4.5),
        gr.Number(label="Duration (sec)", value=8),
    ],
    outputs=gr.Video(label="Generated Result"),
    css=custom_css
)

video_search_tab = gr.Interface(
    fn=search_videos,
    inputs=gr.Textbox(label="Search Query(한글지원)" if translator else "Search Query"),
    outputs=gr.Gallery(label="Search Results", columns=4, rows=20),
    css=custom_css,
    api_name=False
)

# Main entry point.
if __name__ == "__main__":
    # Warn when the translator could not be loaded.
    if translator is None:
        log.warning("Translation model failed to load. Korean translation will be disabled.")

    gr.TabbedInterface(
        [video_search_tab, video_to_audio_tab, text_to_audio_tab],
        ["Video Search", "Video-to-Audio", "Text-to-Audio"],
        theme="soft",
        css=css
    ).launch(allowed_paths=[output_dir])