File size: 4,399 Bytes
06f6c9e
 
 
 
 
78ea8dc
 
 
 
bfbdf81
78ea8dc
4ae4657
78ea8dc
bfbdf81
78ea8dc
4ae4657
78ea8dc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4ae4657
78ea8dc
 
87e6f23
78ea8dc
87e6f23
 
 
4ae4657
 
06f6c9e
78ea8dc
4ae4657
87e6f23
c3cf7db
20017db
8e74b09
78ea8dc
ad5c75b
9472531
06f6c9e
 
ad5c75b
4b414b1
 
 
78ea8dc
20017db
78ea8dc
20017db
78ea8dc
06f6c9e
4b414b1
8e74b09
06f6c9e
4b414b1
06f6c9e
78ea8dc
20017db
 
 
 
4b414b1
06f6c9e
4b414b1
8e74b09
20017db
 
4b414b1
78ea8dc
4b414b1
 
 
8e74b09
 
 
 
78ea8dc
06f6c9e
4b414b1
 
8e74b09
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import os
import sys
import types
import subprocess
import tempfile
import torch
import gradio as gr
from transformers import VisionEncoderDecoderModel, ViTImageProcessor, AutoTokenizer
from PIL import Image

# ── Environment variables ────────────────────────────────────
os.environ["HF_FORCE_SAFE_SERIALIZATION"] = "1"
os.environ["XFORMERS_FORCE_DISABLE"] = "1"      # actually disable xformers

# ── ✨ Inject a dummy xformers module ─────────────────────────
# audiocraft imports xformers unconditionally; on machines without it
# (e.g. CPU-only) we satisfy the import with a stub module instead.
dummy = types.ModuleType("xformers")
dummy.__version__ = "0.0.0"

# Submodule xformers.ops (the only part audiocraft actually uses)
ops = types.ModuleType("xformers.ops")

def _fake_memory_efficient_attention(q, k, v, *_, dropout_p: float = 0.0, **__):
    """
    xformers.memory_efficient_attention λŒ€μ²΄ κ΅¬ν˜„.
    PyTorch 2.x κΈ°λ³Έ S-DPA둜 μ²˜λ¦¬ν•΄ μ†λ„λŠ” λŠλ¦¬μ§€λ§Œ CPUμ—μ„œλ„ λ™μž‘ν•©λ‹ˆλ‹€.
    """
    return torch.nn.functional.scaled_dot_product_attention(
        q, k, v, dropout_p=dropout_p, is_causal=False
    )

class _FakeLowerTriangularMask:  # audiocraft λ‚΄λΆ€ νƒ€μž… 체크용 더미
    pass

# Wire the fake implementations into the stub package and register both
# module objects so any later `import xformers` / `import xformers.ops`
# resolves to the stubs instead of the real (absent) library.
ops.memory_efficient_attention = _fake_memory_efficient_attention
ops.LowerTriangularMask = _FakeLowerTriangularMask
dummy.ops = ops
sys.modules.update({"xformers": dummy, "xformers.ops": ops})
# ────────────────────────────────────────────────────────────

# ── Dynamic audiocraft installation ──────────────────────────
# If audiocraft is missing, pip-install it from GitHub at runtime
# (first launch only) and retry the import.
try:
    from audiocraft.models import MusicGen
except ModuleNotFoundError:
    subprocess.check_call([
        sys.executable, "-m", "pip", "install",
        "git+https://github.com/facebookresearch/audiocraft@main",
        "--use-pep517"
    ])
    from audiocraft.models import MusicGen

from audiocraft.data.audio import audio_write

# ── Image-captioning model (ViT encoder + GPT-2 decoder) ────
caption_model = VisionEncoderDecoderModel.from_pretrained(
    "nlpconnect/vit-gpt2-image-captioning",
    use_safetensors=True,
    low_cpu_mem_usage=True
)
feature_extractor = ViTImageProcessor.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
tokenizer = AutoTokenizer.from_pretrained("nlpconnect/vit-gpt2-image-captioning")

# ── MusicGen ───────────────────────────────────────────────
musicgen = MusicGen.get_pretrained("facebook/musicgen-small")
musicgen.set_generation_params(duration=10)  # 10-second clips

# ── Utility functions ──────────────────────────────────────
def generate_caption(image: Image.Image) -> str:
    """Return an English caption for *image* from the ViT-GPT2 model.

    Args:
        image: a PIL image as delivered by the Gradio input component.

    Returns:
        The decoded caption string (special tokens stripped).
    """
    pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values
    # Inference only — no_grad avoids building an unused autograd graph.
    with torch.no_grad():
        output_ids = caption_model.generate(pixel_values, max_length=50)
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)

def generate_music(prompt: str) -> str:
    """Generate a ~10 s music clip from *prompt* and return the wav file path.

    Bug fix: audiocraft's ``audio_write`` takes a *stem* (no extension) and
    appends the format suffix itself, so passing "musicgen_output.wav"
    wrote "musicgen_output.wav.wav" while this function returned a path to
    a file that does not exist. We now pass the stem and return the path
    that ``audio_write`` actually wrote.
    """
    wav = musicgen.generate([prompt])          # batch size = 1
    tmp_dir = tempfile.mkdtemp()
    stem = os.path.join(tmp_dir, "musicgen_output")
    # audio_write returns the final Path including the ".wav" suffix;
    # .cpu() is a no-op when the tensor is already on CPU.
    out_path = audio_write(stem, wav[0].cpu(), musicgen.sample_rate, strategy="loudness")
    return str(out_path)

def process(image: Image.Image):
    """Gradio handler: caption the uploaded image, then compose music from the caption.

    Returns a (caption, audio_path) pair matching the two output components.
    """
    caption = generate_caption(image)
    prompt = f"A cheerful melody inspired by: {caption}"
    return caption, generate_music(prompt)

# ── Gradio UI ──────────────────────────────────────────────
# Labels/title/description are intentionally Korean user-facing text.
demo = gr.Interface(
    fn=process,
    inputs=gr.Image(type="pil"),
    outputs=[
        gr.Text(label="AIκ°€ μƒμ„±ν•œ κ·Έλ¦Ό μ„€λͺ…"),
        gr.Audio(label="μƒμ„±λœ AI μŒμ•… (MusicGen)")
    ],
    title="🎨 AI κ·Έλ¦Όβ€‘μŒμ•… 생성기",
    description="그림을 μ—…λ‘œλ“œν•˜λ©΄ AIκ°€ μ„€λͺ…을 λ§Œλ“€κ³ , μ„€λͺ…을 λ°”νƒ•μœΌλ‘œ μŒμ•…μ„ 생성해 λ“€λ €μ€λ‹ˆλ‹€."
)

# Launch the app only when run as a script (not on import).
if __name__ == "__main__":
    demo.launch()