File size: 2,770 Bytes
87e6f23
 
 
 
 
 
 
 
 
 
 
 
c3cf7db
4b414b1
 
20017db
 
4b414b1
 
20017db
8e74b09
 
ad5c75b
9472531
 
 
ad5c75b
4b414b1
 
 
20017db
 
 
 
8e74b09
4b414b1
 
8e74b09
4b414b1
 
 
20017db
8e74b09
20017db
 
 
 
 
4b414b1
8e74b09
4b414b1
 
8e74b09
20017db
 
4b414b1
8e74b09
4b414b1
 
 
8e74b09
 
 
 
 
 
4b414b1
 
8e74b09
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import os
import subprocess
import sys

# Force safetensors serialization for Hugging Face model loads.
os.environ["HF_FORCE_SAFE_SERIALIZATION"] = "1"

# ─── Dynamic audiocraft install (runs only once, on first launch) ───
try:
    from audiocraft.models import MusicGen
except ModuleNotFoundError:
    install_cmd = [
        sys.executable, "-m", "pip", "install",
        "git+https://github.com/facebookresearch/audiocraft@main",
        "--no-deps",
    ]
    subprocess.check_call(install_cmd)
    from audiocraft.models import MusicGen

import gradio as gr
from transformers import VisionEncoderDecoderModel, ViTImageProcessor, AutoTokenizer
from audiocraft.models import MusicGen
from audiocraft.data.audio import audio_write
from PIL import Image
import torch
import tempfile

# ───── Load the image-captioning model ─────
# ViT encoder + GPT-2 decoder; downloads weights on first run.
caption_model = VisionEncoderDecoderModel.from_pretrained(
    "nlpconnect/vit-gpt2-image-captioning",
    use_safetensors=True,                 # load .safetensors weights (matches env var above)
    low_cpu_mem_usage=True               # (optional) reduce peak RAM during load
)
feature_extractor = ViTImageProcessor.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
tokenizer = AutoTokenizer.from_pretrained("nlpconnect/vit-gpt2-image-captioning")

# ───── Load the MusicGen model ─────
musicgen = MusicGen.get_pretrained("facebook/musicgen-small")
musicgen.set_generation_params(duration=10)  # length of generated music in seconds

# ───── Image → caption generation ─────
def generate_caption(image):
    """Return an English caption for *image* (a PIL image) via ViT-GPT2."""
    inputs = feature_extractor(images=image, return_tensors="pt")
    generated_ids = caption_model.generate(inputs.pixel_values, max_length=50)
    return tokenizer.decode(generated_ids[0], skip_special_tokens=True)

# ───── μ„€λͺ… β†’ μŒμ•… 생성 ─────
def generate_music(prompt):
    wav = musicgen.generate([prompt])  # batch size 1
    tmp_dir = tempfile.mkdtemp()
    audio_path = os.path.join(tmp_dir, "musicgen_output.wav")
    audio_write(audio_path, wav[0], musicgen.sample_rate, strategy="loudness")
    return audio_path

# ───── 전체 νŒŒμ΄ν”„λΌμΈ μ—°κ²° ─────
def process(image):
    caption = generate_caption(image)
    prompt = f"A cheerful melody inspired by: {caption}"
    audio_path = generate_music(prompt)
    return caption, audio_path

# ───── Gradio μΈν„°νŽ˜μ΄μŠ€ ꡬ성 ─────
demo = gr.Interface(
    fn=process,
    inputs=gr.Image(type="pil"),
    outputs=[
        gr.Text(label="AIκ°€ μƒμ„±ν•œ κ·Έλ¦Ό μ„€λͺ…"),
        gr.Audio(label="μƒμ„±λœ AI μŒμ•… (MusicGen)")
    ],
    title="🎨 AI κ·Έλ¦Ό μŒμ•… 생성기",
    description="그림을 μ—…λ‘œλ“œν•˜λ©΄ AIκ°€ μ„€λͺ…을 λ§Œλ“€κ³ , μ„€λͺ…을 λ°”νƒ•μœΌλ‘œ μŒμ•…μ„ λ§Œλ“€μ–΄ λ“€λ €μ€λ‹ˆλ‹€."
)

if __name__ == "__main__":
    demo.launch()