# imagetoaudio / app.py
# Hugging Face Space: upload an image -> AI caption -> 10-second MusicGen clip.
import os, sys, types, subprocess, tempfile
import torch, gradio as gr
from transformers import (
VisionEncoderDecoderModel, ViTImageProcessor, AutoTokenizer
)
from PIL import Image
# ─────────────────────────────────────────────────────────────
# 0. Environment variables
# ─────────────────────────────────────────────────────────────
# Prefer safetensors serialization when loading HF checkpoints.
os.environ["HF_FORCE_SAFE_SERIALIZATION"] = "1"
os.environ["XFORMERS_FORCE_DISABLE"] = "1"  # flag read inside audiocraft
# ─────────────────────────────────────────────────────────────
# 1. Inject a dummy `xformers` module (removes the GPU dependency)
# ─────────────────────────────────────────────────────────────
dummy = types.ModuleType("xformers")
dummy.__version__ = "0.0.0"
ops = types.ModuleType("xformers.ops")

class _FakeLowerTriangularMask:
    """Stand-in for xformers' causal-mask marker type.

    audiocraft mostly checks for the type's existence; `_fake_mea` below
    additionally maps an instance onto SDPA's `is_causal=True`.
    """
    pass

def _fake_mea(q, k, v, attn_bias=None, p: float = 0.0, *_, dropout_p=None, **__):
    """Drop-in replacement for ``xformers.ops.memory_efficient_attention``.

    Delegates to PyTorch 2.x's standard scaled_dot_product_attention
    (works on CPU).  Unlike a bare stub it honors the dropout
    probability (``p`` is xformers' parameter name; ``dropout_p`` is kept
    for backward compatibility) and the attention bias instead of
    silently discarding them: a LowerTriangularMask becomes
    ``is_causal=True``, a tensor bias is forwarded as ``attn_mask``.
    """
    if dropout_p is not None:  # legacy keyword used by earlier callers
        p = dropout_p
    causal = isinstance(attn_bias, _FakeLowerTriangularMask)
    mask = None if causal else attn_bias
    return torch.nn.functional.scaled_dot_product_attention(
        q, k, v, attn_mask=mask, dropout_p=p, is_causal=causal
    )

ops.memory_efficient_attention = _fake_mea
ops.LowerTriangularMask = _FakeLowerTriangularMask
dummy.ops = ops
sys.modules["xformers"] = dummy
sys.modules["xformers.ops"] = ops
# ─────────────────────────────────────────────────────────────
# 2. Stub safety net for other optional modules
#    (requirements.txt installs them, but survive a missing one at runtime)
# ─────────────────────────────────────────────────────────────
_OPTIONAL_DEPS = (
    "av", "librosa", "torchdiffeq", "torchmetrics",
    "pesq", "pystoi", "soxr",
)
for _dep in _OPTIONAL_DEPS:
    # Register an empty placeholder module only when none is present yet.
    sys.modules.setdefault(_dep, types.ModuleType(_dep))
# ─────────────────────────────────────────────────────────────
# 3. Import audiocraft (MusicGen)
# ─────────────────────────────────────────────────────────────
try:
    from audiocraft.models import MusicGen
    from audiocraft.data.audio import audio_write
except ModuleNotFoundError:
    # Not installed (e.g. local run): install from git without pulling
    # the heavy dependency tree (--no-deps).
    subprocess.check_call([
        sys.executable, "-m", "pip", "install",
        "git+https://github.com/facebookresearch/audiocraft@main",
        "--no-deps", "--use-pep517"
    ])
    # Install only the minimal runtime dependencies on the spot
    # (the stubs above get past most import checks, but be safe).
    subprocess.check_call([sys.executable, "-m", "pip", "install",
                           "encodec", "torchdiffeq", "torchmetrics",
                           "librosa", "soxr", "av"])
    from audiocraft.models import MusicGen
    from audiocraft.data.audio import audio_write
# ─────────────────────────────────────────────────────────────
# 4. Image-captioning model (ViT encoder + GPT-2 decoder)
# ─────────────────────────────────────────────────────────────
caption_model = VisionEncoderDecoderModel.from_pretrained(
    "nlpconnect/vit-gpt2-image-captioning",
    use_safetensors=True,       # avoid pickle-based .bin weights
    low_cpu_mem_usage=True      # keep peak RAM low while loading
)
feature_extractor = ViTImageProcessor.from_pretrained(
    "nlpconnect/vit-gpt2-image-captioning"
)
tokenizer = AutoTokenizer.from_pretrained(
    "nlpconnect/vit-gpt2-image-captioning"
)
# ─────────────────────────────────────────────────────────────
# 5. MusicGen model (CPU-only)
# ─────────────────────────────────────────────────────────────
musicgen = MusicGen.get_pretrained("facebook/musicgen-small")
musicgen.set_generation_params(duration=10)  # generate 10-second clips
# ─────────────────────────────────────────────────────────────
# 6. Pipeline functions
# ─────────────────────────────────────────────────────────────
def generate_caption(image: Image.Image) -> str:
    """Return a short text caption describing *image* (ViT-GPT2, max 50 tokens)."""
    inputs = feature_extractor(images=image, return_tensors="pt")
    token_ids = caption_model.generate(inputs.pixel_values, max_length=50)
    caption = tokenizer.decode(token_ids[0], skip_special_tokens=True)
    return caption
def generate_music(prompt: str) -> str:
    """Generate a ~10 s music clip for *prompt* and return the .wav file path.

    Bug fix: audiocraft's ``audio_write`` treats its first argument as a
    *stem* and appends the format extension itself.  The previous code
    passed ".../musicgen.wav" (writing "musicgen.wav.wav") and returned a
    path that did not exist.  We now pass an extension-less stem and
    return the path ``audio_write`` reports having written.
    """
    wav = musicgen.generate([prompt])  # batch size = 1 -> tensor [1, C, T]
    tmpdir = tempfile.mkdtemp()
    stem = os.path.join(tmpdir, "musicgen")
    out_path = audio_write(stem, wav[0].cpu(), musicgen.sample_rate,
                           strategy="loudness")
    return str(out_path)
def process(image: Image.Image):
    """Full pipeline: uploaded image -> caption text -> generated audio path."""
    description = generate_caption(image)
    audio_path = generate_music(f"A cheerful melody inspired by: {description}")
    return description, audio_path
# ─────────────────────────────────────────────────────────────
# 7. Gradio UI
# ─────────────────────────────────────────────────────────────
# Single-function interface: one image in, caption text + audio clip out.
_outputs = [
    gr.Text(label="AIκ°€ μƒμ„±ν•œ κ·Έλ¦Ό μ„€λͺ…"),
    gr.Audio(label="μƒμ„±λœ AI μŒμ•… (MusicGen)"),
]
demo = gr.Interface(
    fn=process,
    inputs=gr.Image(type="pil"),
    outputs=_outputs,
    title="🎨 AI κ·Έλ¦Ό-μŒμ•… 생성기",
    description="그림을 μ—…λ‘œλ“œν•˜λ©΄ AIκ°€ μ„€λͺ…을 λ§Œλ“€κ³ , μ„€λͺ…을 λ°”νƒ•μœΌλ‘œ 10초 길이의 μŒμ•…μ„ 생성해 λ“€λ €μ€λ‹ˆλ‹€."
)

if __name__ == "__main__":
    demo.launch()