# 1. Set up logging first
import logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# 2. Import spaces first (ZeroGPU requires it before CUDA is initialized)
import spaces
# 3. Remaining imports
import os
import time
from datetime import datetime
import gradio as gr
import torch
import requests
from pathlib import Path
import cv2
from PIL import Image
import json
import torchaudio
import tempfile
# 4. GPU initialization
if torch.cuda.is_available():
device = torch.device('cuda')
logger.info(f"Using GPU: {torch.cuda.get_device_name(0)}")
else:
device = torch.device('cpu')
logger.warning("GPU not available, using CPU")
try:
import mmaudio
except ImportError:
os.system("pip install -e .")
import mmaudio
# Remaining mmaudio imports
from mmaudio.eval_utils import (ModelConfig, all_model_cfg, generate, load_video, make_video,
setup_eval_logging)
from mmaudio.model.flow_matching import FlowMatching
from mmaudio.model.networks import MMAudio, get_my_mmaudio
from mmaudio.model.sequence_config import SequenceConfig
from mmaudio.model.utils.features_utils import FeaturesUtils
# Korean-to-English translation model
from transformers import pipeline
translator = pipeline("translation", model="Helsinki-NLP/opus-mt-ko-en")
# API configuration
CATBOX_USER_HASH = "30f52c895fd9d9cb387eee489"
MINIMAX_API_KEY = os.getenv("API_KEY")  # MiniMax video API key (read again inside generate_video)
# Audio model configuration
dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
# 5. Define get_model
def get_model() -> tuple[MMAudio, FeaturesUtils, SequenceConfig]:
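    """Load the MMAudio network and feature utilities for inference.

    Uses the module-level `model` config (assigned below, before the first call)
    and returns (net, feature_utils, seq_cfg) on the active device.
    """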
seq_cfg = model.seq_cfg
net: MMAudio = get_my_mmaudio(model.model_name).to(device)
if torch.cuda.is_available():
net = net.to(dtype)
net.eval()
net.load_weights(torch.load(model.model_path, map_location=device, weights_only=True))
logger.info(f'Loaded weights from {model.model_path}')
feature_utils = FeaturesUtils(
tod_vae_ckpt=model.vae_path,
synchformer_ckpt=model.synchformer_ckpt,
enable_conditions=True,
mode=model.mode,
bigvgan_vocoder_ckpt=model.bigvgan_16k_path,
need_vae_encoder=False
).to(device)
if torch.cuda.is_available():
feature_utils = feature_utils.to(dtype)
feature_utils.eval()
return net, feature_utils, seq_cfg
# 6. Model initialization
model: ModelConfig = all_model_cfg['large_44k_v2']
model.download_if_needed()
output_dir = Path('./output/gradio')
setup_eval_logging()
net, feature_utils, seq_cfg = get_model()
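# On ZeroGPU Spaces, @spaces.GPU allocates a GPU for up to `duration` seconds per call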
@spaces.GPU(duration=30)
@torch.inference_mode()
def video_to_audio(video_path: str, prompt: str, negative_prompt: str = "music",
                   seed: int = -1, num_steps: int = 15,
                   cfg_strength: float = 4.0, target_duration: float | None = None):
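    """Generate a soundtrack for `video_path` with MMAudio and mux it into a new MP4.

    Audio is conditioned on `prompt` and steered away from `negative_prompt`;
    `target_duration` is ignored and replaced by the video's own length.
    Returns the path of the new video, or the original path on failure.
    """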
try:
logger.info("Starting audio generation process")
if torch.cuda.is_available():
torch.cuda.empty_cache()
        # Check the video duration
cap = cv2.VideoCapture(video_path)
fps = cap.get(cv2.CAP_PROP_FPS)
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
video_duration = total_frames / fps
cap.release()
        # Use the actual video duration as target_duration
target_duration = video_duration
logger.info(f"Video duration: {target_duration} seconds")
rng = torch.Generator(device=device)
if seed >= 0:
rng.manual_seed(seed)
else:
rng.seed()
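        # Euler sampler for the flow-matching ODE, integrated over num_steps steps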
fm = FlowMatching(min_sigma=0, inference_mode='euler', num_steps=num_steps)
        # Load the video at its own duration
video_info = load_video(video_path, duration_sec=target_duration)
if video_info is None:
logger.error("Failed to load video")
return video_path
clip_frames = video_info.clip_frames
sync_frames = video_info.sync_frames
actual_duration = video_info.duration_sec
if clip_frames is None or sync_frames is None:
logger.error("Failed to extract frames from video")
return video_path
        # Trim to the actual number of frames in the video
clip_frames = clip_frames[:int(actual_duration * video_info.fps)]
sync_frames = sync_frames[:int(actual_duration * video_info.fps)]
        clip_frames = clip_frames.unsqueeze(0).to(device, dtype=dtype)  # cast to the same dtype as the model weights (bfloat16 on GPU)
        sync_frames = sync_frames.unsqueeze(0).to(device, dtype=dtype)
        # Update the sequence config so the latent/clip/sync lengths match the clip duration
seq_cfg.duration = actual_duration
net.update_seq_lengths(seq_cfg.latent_seq_len, seq_cfg.clip_seq_len, seq_cfg.sync_seq_len)
logger.info(f"Generating audio for {actual_duration} seconds...")
logger.info("Generating audio...")
        with torch.autocast(device_type='cuda', dtype=dtype):  # torch.cuda.amp.autocast() is deprecated
audios = generate(clip_frames,
sync_frames,
[prompt],
negative_text=[negative_prompt],
feature_utils=feature_utils,
net=net,
fm=fm,
rng=rng,
cfg_strength=cfg_strength)
if audios is None:
logger.error("Failed to generate audio")
return video_path
audio = audios.float().cpu()[0]
        with tempfile.NamedTemporaryFile(delete=False, suffix='.mp4') as tmp:
            output_path = tmp.name
logger.info(f"Creating final video with audio at {output_path}")
make_video(video_info, output_path, audio, sampling_rate=seq_cfg.sampling_rate)
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
if not os.path.exists(output_path):
logger.error("Failed to create output video")
return video_path
logger.info(f'Successfully saved video with audio to {output_path}')
return output_path
    except Exception as e:
        logger.error(f"Error in video_to_audio: {str(e)}")
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        return video_path
def upload_to_catbox(file_path):
"""catbox.moe API를 사용하여 파일 업로드"""
try:
logger.info(f"Preparing to upload file: {file_path}")
url = "https://catbox.moe/user/api.php"
mime_types = {
'.jpg': 'image/jpeg',
'.jpeg': 'image/jpeg',
'.png': 'image/png',
'.gif': 'image/gif',
'.webp': 'image/webp',
'.jfif': 'image/jpeg'
}
file_extension = Path(file_path).suffix.lower()
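        # Convert unrecognized formats to PNG so catbox accepts them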
if file_extension not in mime_types:
try:
img = Image.open(file_path)
if img.mode != 'RGB':
img = img.convert('RGB')
new_path = file_path.rsplit('.', 1)[0] + '.png'
img.save(new_path, 'PNG')
file_path = new_path
file_extension = '.png'
logger.info(f"Converted image to PNG: {file_path}")
except Exception as e:
logger.error(f"Failed to convert image: {str(e)}")
return None
        data = {
            'reqtype': 'fileupload',
            'userhash': CATBOX_USER_HASH
        }
        # Open inside a context manager so the file handle is closed after the upload
        with open(file_path, 'rb') as file_obj:
            files = {
                'fileToUpload': (
                    os.path.basename(file_path),
                    file_obj,
                    mime_types.get(file_extension, 'application/octet-stream')
                )
            }
            response = requests.post(url, files=files, data=data)
if response.status_code == 200 and response.text.startswith('http'):
file_url = response.text
logger.info(f"File uploaded successfully: {file_url}")
return file_url
else:
raise Exception(f"Upload failed: {response.text}")
except Exception as e:
logger.error(f"File upload error: {str(e)}")
return None
    finally:
        # Clean up the temporary PNG created during conversion, if any
        if 'new_path' in locals() and os.path.exists(new_path):
            try:
                os.remove(new_path)
            except OSError:
                pass
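# Hypothetical standalone usage (illustrative only, not called directly in this app):
#   url = upload_to_catbox("frame.png")  # returns a files.catbox.moe URL, or None on failure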
def add_watermark(video_path):
"""OpenCV를 사용하여 비디오에 워터마크 추가"""
try:
cap = cv2.VideoCapture(video_path)
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        fps = cap.get(cv2.CAP_PROP_FPS)  # keep fractional frame rates (e.g. 29.97)
text = "GiniGEN.AI"
font = cv2.FONT_HERSHEY_SIMPLEX
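        # Heuristic: scale the text to roughly 5% of the frame height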
font_scale = height * 0.05 / 30
thickness = 2
color = (255, 255, 255)
(text_width, text_height), _ = cv2.getTextSize(text, font, font_scale, thickness)
margin = int(height * 0.02)
x_pos = width - text_width - margin
y_pos = height - margin
output_path = "watermarked_output.mp4"
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
while cap.isOpened():
ret, frame = cap.read()
if not ret:
break
cv2.putText(frame, text, (x_pos, y_pos), font, font_scale, color, thickness)
out.write(frame)
cap.release()
out.release()
return output_path
except Exception as e:
logger.error(f"Error adding watermark: {str(e)}")
return video_path
def generate_video(image, prompt):
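    """Create a video from an image and prompt via the MiniMax video-01 API.

    Uploads the image to catbox, submits a generation task, polls until it
    finishes, then watermarks the download and adds a generated soundtrack.
    Returns a local file path on success or an error-message string on failure.
    """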
logger.info("Starting video generation with API")
try:
API_KEY = os.getenv("API_KEY", "").strip()
if not API_KEY:
return "API key not properly configured"
temp_dir = "temp_videos"
os.makedirs(temp_dir, exist_ok=True)
image_url = None
if image:
image_url = upload_to_catbox(image)
if not image_url:
return "Failed to upload image"
logger.info(f"Input image URL: {image_url}")
generation_url = "https://api.minimaxi.chat/v1/video_generation"
headers = {
'authorization': f'Bearer {API_KEY}',
'Content-Type': 'application/json'
}
payload = {
"model": "video-01",
"prompt": prompt if prompt else "",
"prompt_optimizer": True
}
if image_url:
payload["first_frame_image"] = image_url
logger.info(f"Sending request with payload: {payload}")
response = requests.post(generation_url, headers=headers, json=payload)
if not response.ok:
error_msg = f"Failed to create video generation task: {response.text}"
logger.error(error_msg)
return error_msg
response_data = response.json()
task_id = response_data.get('task_id')
if not task_id:
return "Failed to get task ID from response"
query_url = "https://api.minimaxi.chat/v1/query/video_generation"
max_attempts = 30
attempt = 0
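        # Poll the task status every 10 seconds, up to ~5 minutes total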
while attempt < max_attempts:
time.sleep(10)
query_response = requests.get(
f"{query_url}?task_id={task_id}",
headers={'authorization': f'Bearer {API_KEY}'}
)
if not query_response.ok:
attempt += 1
continue
status_data = query_response.json()
status = status_data.get('status')
if status == 'Success':
file_id = status_data.get('file_id')
if not file_id:
return "Failed to get file ID"
retrieve_url = "https://api.minimaxi.chat/v1/files/retrieve"
params = {'file_id': file_id}
file_response = requests.get(
retrieve_url,
headers={'authorization': f'Bearer {API_KEY}'},
params=params
)
if not file_response.ok:
return "Failed to retrieve video file"
try:
file_data = file_response.json()
download_url = file_data.get('file', {}).get('download_url')
if not download_url:
return "Failed to get download URL"
result_info = {
"timestamp": datetime.now().strftime("%Y%m%d_%H%M%S"),
"input_image": image_url,
"output_video_url": download_url,
"prompt": prompt
}
logger.info(f"Video generation result: {json.dumps(result_info, indent=2)}")
video_response = requests.get(download_url)
if not video_response.ok:
return "Failed to download video"
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
output_path = os.path.join(temp_dir, f"output_{timestamp}.mp4")
with open(output_path, 'wb') as f:
f.write(video_response.content)
final_path = add_watermark(output_path)
                        # Check the final video duration
cap = cv2.VideoCapture(final_path)
fps = cap.get(cv2.CAP_PROP_FPS)
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
video_duration = total_frames / fps
cap.release()
logger.info(f"Original video duration: {video_duration} seconds")
                        # Add a generated soundtrack
try:
logger.info("Starting audio generation process")
final_path_with_audio = video_to_audio(
final_path,
prompt=prompt,
negative_prompt="music",
seed=-1,
num_steps=20,
cfg_strength=4.5
                                # target_duration omitted - video_to_audio uses the clip's own length
)
if final_path_with_audio != final_path:
logger.info("Audio generation successful")
try:
if output_path != final_path:
os.remove(output_path)
if final_path != final_path_with_audio:
os.remove(final_path)
except Exception as e:
logger.warning(f"Error cleaning up temporary files: {str(e)}")
return final_path_with_audio
else:
logger.warning("Audio generation skipped, using original video")
return final_path
except Exception as e:
logger.error(f"Error in audio processing: {str(e)}")
                            return final_path  # fall back to the watermarked-only video if audio processing fails
except Exception as e:
logger.error(f"Error processing video file: {str(e)}")
return "Error processing video file"
elif status == 'Fail':
return "Video generation failed"
attempt += 1
return "Timeout waiting for video generation"
except Exception as e:
logger.error(f"Error in video generation: {str(e)}")
return f"Error in video generation process: {str(e)}"
css = """
footer {
visibility: hidden;
}
.gradio-container {max-width: 1200px !important}
"""
with gr.Blocks(theme="Yntec/HaleyCH_Theme_Orange", css=css) as demo:
gr.HTML('<div class="title">🎥 Dokdo Multimodal✨ "Prompt guide for automated video and sound synthesis from images" </div>')
gr.HTML('<div class="title">😄 Explore: <a href="https://huggingface.co/spaces/ginigen/theater" target="_blank">https://huggingface.co/spaces/ginigen/theater</a></div>')
with gr.Row():
with gr.Column(scale=3):
video_prompt = gr.Textbox(
label="Video Description",
placeholder="Enter video description...",
lines=3
)
upload_image = gr.Image(type="filepath", label="Upload First Frame Image")
video_generate_btn = gr.Button("🎬 Generate Video")
with gr.Column(scale=4):
video_output = gr.Video(label="Generated Video")
def process_and_generate_video(image, prompt):
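        """Gradio click handler: translate Korean prompts to English, normalize
        the upload to an RGB PNG, then run the full generation pipeline."""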
if image is None:
return "Please upload an image"
try:
            # Detect a Korean prompt and translate it to English
contains_korean = any(ord('가') <= ord(char) <= ord('힣') for char in prompt)
if contains_korean:
translated = translator(prompt)[0]['translation_text']
logger.info(f"Translated prompt from '{prompt}' to '{translated}'")
prompt = translated
img = Image.open(image)
if img.mode != 'RGB':
img = img.convert('RGB')
temp_path = f"temp_{int(time.time())}.png"
img.save(temp_path, 'PNG')
result = generate_video(temp_path, prompt)
            try:
                os.remove(temp_path)
            except OSError:
                pass
return result
except Exception as e:
logger.error(f"Error processing image: {str(e)}")
return "Error processing image"
video_generate_btn.click(
process_and_generate_video,
inputs=[upload_image, video_prompt],
outputs=video_output
)
if __name__ == "__main__":
    # Confirm GPU availability at launch
if torch.cuda.is_available():
logger.info(f"Using GPU: {torch.cuda.get_device_name(0)}")
else:
logger.warning("GPU not available, using CPU")
demo.launch(server_name="0.0.0.0", server_port=7860, share=False)