File size: 4,058 Bytes
0725a88
 
 
 
 
78ea8dc
bfbdf81
78ea8dc
4ae4657
0725a88
bfbdf81
0725a88
4ae4657
78ea8dc
 
 
0725a88
78ea8dc
 
 
0725a88
78ea8dc
0725a88
78ea8dc
 
4ae4657
78ea8dc
 
87e6f23
0725a88
87e6f23
 
0725a88
 
4ae4657
 
06f6c9e
0725a88
4ae4657
87e6f23
0725a88
8e74b09
78ea8dc
ad5c75b
9472531
0725a88
 
 
 
 
 
 
ad5c75b
4b414b1
0725a88
20017db
0725a88
20017db
0725a88
06f6c9e
4b414b1
0725a88
 
4b414b1
06f6c9e
0725a88
 
 
 
 
4b414b1
06f6c9e
4b414b1
0725a88
 
4b414b1
78ea8dc
4b414b1
 
 
8e74b09
 
 
 
78ea8dc
0725a88
4b414b1
 
8e74b09
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import os, sys, types, subprocess, tempfile
import torch, gradio as gr
from transformers import (
    VisionEncoderDecoderModel, ViTImageProcessor, AutoTokenizer
)
from PIL import Image

# ── Environment variables ────────────────────────────────────
# Both switches are exported in one call, ahead of the audiocraft import
# that happens further down in this file.
os.environ.update(
    {
        "HF_FORCE_SAFE_SERIALIZATION": "1",
        # Flag read inside audiocraft (per the original author's note).
        "XFORMERS_FORCE_DISABLE": "1",
    }
)

# ── xformers stand-in module ─────────────────────────────────
# A minimal fake ``xformers`` package is registered in ``sys.modules`` so that
# ``import xformers`` / ``from xformers import ops`` succeed even though the
# real library is not installed.
dummy = types.ModuleType("xformers")
dummy.__version__ = "0.0.0"
ops = types.ModuleType("xformers.ops")


class _FakeLowerTriangularMask:
    """Marker standing in for ``xformers.ops.LowerTriangularMask``.

    Passing an instance (or the class itself) as ``attn_bias`` requests
    causal attention from :func:`_fake_mem_eff_attn`.
    """


def _fake_mem_eff_attn(q, k, v, attn_bias=None, *_, dropout_p: float = 0.0, p=None, **__):
    """Emulate ``xformers.ops.memory_efficient_attention`` with torch SDPA.

    Args:
        q, k, v: Query/key/value tensors. Per the xformers API, 4-D inputs are
            laid out as ``[batch, seq, heads, dim]``; 3-D inputs are
            ``[batch, seq, dim]``.
        attn_bias: ``None``, a ``_FakeLowerTriangularMask`` (causal attention)
            or an additive mask tensor. Any other object is ignored, as
            before.
        dropout_p: Dropout probability (name used by the original stub).
        p: Dropout probability under the name xformers itself uses; takes
            precedence over ``dropout_p`` when given.
        *_, **__: Remaining xformers arguments are accepted and ignored.

    Returns:
        The attention output in the same layout as ``q``.
    """
    if p is not None:
        dropout_p = p

    is_causal = attn_bias is _FakeLowerTriangularMask or isinstance(
        attn_bias, _FakeLowerTriangularMask
    )
    # torch rejects an explicit mask combined with is_causal=True, so only a
    # real tensor bias is forwarded, and only for the non-causal case.
    attn_mask = None
    if not is_causal and isinstance(attn_bias, torch.Tensor):
        attn_mask = attn_bias

    # torch SDPA attends over the second-to-last axis, whereas xformers keeps
    # the sequence on axis 1 with heads on axis 2; swap them around the call.
    has_head_axis = q.dim() == 4
    if has_head_axis:
        q, k, v = (t.transpose(1, 2) for t in (q, k, v))

    out = torch.nn.functional.scaled_dot_product_attention(
        q, k, v, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal
    )
    return out.transpose(1, 2) if has_head_axis else out


ops.memory_efficient_attention = _fake_mem_eff_attn
ops.LowerTriangularMask = _FakeLowerTriangularMask
dummy.ops = ops
sys.modules["xformers"] = dummy
sys.modules["xformers.ops"] = ops
# ────────────────────────────────────────────────────────────

# ── Load audiocraft (normally already installed by the postInstall step,
#    per the original author's note) ─────────────────────────
try:
    from audiocraft.models import MusicGen
    from audiocraft.data.audio import audio_write
except ModuleNotFoundError:  # fallback for an unusual local run
    # Install audiocraft from GitHub (without its dependencies), then retry.
    subprocess.check_call([
        sys.executable, "-m", "pip", "install",
        "git+https://github.com/facebookresearch/audiocraft@main",
        "--no-deps", "--use-pep517"
    ])
    # A package installed while the interpreter is running may be hidden by
    # the import system's cached directory listings; refresh them first.
    import importlib

    importlib.invalidate_caches()
    from audiocraft.models import MusicGen
    from audiocraft.data.audio import audio_write

# ── Image captioning model ──────────────────────────────────
# One checkpoint id supplies the model weights, the image preprocessor and
# the tokenizer.
_CAPTION_CHECKPOINT = "nlpconnect/vit-gpt2-image-captioning"

caption_model = VisionEncoderDecoderModel.from_pretrained(
    _CAPTION_CHECKPOINT,
    low_cpu_mem_usage=True,
    use_safetensors=True,
)
feature_extractor = ViTImageProcessor.from_pretrained(_CAPTION_CHECKPOINT)
tokenizer = AutoTokenizer.from_pretrained(_CAPTION_CHECKPOINT)

# ── MusicGen model ──────────────────────────────────────────
_MUSICGEN_CHECKPOINT = "facebook/musicgen-small"
_CLIP_DURATION = 10  # passed as ``duration`` to set_generation_params

musicgen = MusicGen.get_pretrained(_MUSICGEN_CHECKPOINT)
musicgen.set_generation_params(duration=_CLIP_DURATION)

# ── Pipeline functions ──────────────────────────────────────
def generate_caption(image: Image.Image) -> str:
    """Return a text caption for ``image``.

    Args:
        image: A PIL image in any mode. Non-RGB images (for example grayscale,
            palette or RGBA) are converted to RGB before preprocessing, so
            callers that bypass the Gradio input do not have to do it.

    Returns:
        The decoded caption with special tokens removed.
    """
    if image.mode != "RGB":
        image = image.convert("RGB")
    pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values
    # Generation is capped at 50 tokens; only the first sequence is decoded.
    ids = caption_model.generate(pixel_values, max_length=50)
    return tokenizer.decode(ids[0], skip_special_tokens=True)

def generate_music(prompt: str) -> str:
    """Generate an audio clip for ``prompt`` and return the path of the file.

    Args:
        prompt: Text description passed to MusicGen as a one-element batch.

    Returns:
        Path (as ``str``) of the audio file written inside a fresh temporary
        directory. The directory is not removed here, because the file has to
        outlive this call so the caller can read it.
    """
    wav = musicgen.generate([prompt])
    tmpdir = tempfile.mkdtemp()
    # audio_write() treats its first argument as a stem and appends the format
    # suffix itself, so pass the name without ".wav" and use the path it
    # returns. Passing "musicgen.wav" would produce "musicgen.wav.wav" while
    # this function reported a file that does not exist.
    written_path = audio_write(
        os.path.join(tmpdir, "musicgen"),
        wav[0].cpu(),  # first (only) item of the batch, moved to host memory
        musicgen.sample_rate,
        strategy="loudness",
    )
    return str(written_path)

def process(image: Image.Image):
    """Caption ``image`` and generate music from that caption.

    Args:
        image: The uploaded PIL image, or ``None`` when nothing was uploaded.

    Returns:
        A ``(caption, audio_path)`` tuple, matching the two interface outputs.

    Raises:
        gr.Error: If no image was provided; Gradio shows the message in the UI.
    """
    if image is None:
        # Message is Korean like the rest of the UI: "Please upload an image first."
        raise gr.Error("이미지를 먼저 업로드해 주세요.")
    caption = generate_caption(image)
    path = generate_music(f"A cheerful melody inspired by: {caption}")
    return caption, path

# ── Gradio UI ──────────────────────────────────────────────
# The user-facing strings are Korean; they had been stored as mis-decoded
# UTF-8 and are restored here. English glosses are given in the comments.
demo = gr.Interface(
    fn=process,
    inputs=gr.Image(type="pil"),
    outputs=[
        # "AI-generated description of the picture"
        gr.Text(label="AI가 생성한 그림 설명"),
        # "Generated AI music (MusicGen)"
        gr.Audio(label="생성된 AI 음악 (MusicGen)"),
    ],
    # "AI picture-to-music generator"
    title="🎨 AI 그림‑음악 생성기",
    # "Upload a picture and the AI writes a description, then generates and
    # plays 10 seconds of music based on that description."
    description="그림을 업로드하면 AI가 설명을 만들고, 설명을 바탕으로 음악을 10초간 생성해 들려줍니다.",
)

# Start the web server only when this file is run as a script.
if __name__ == "__main__":
    demo.launch()