File size: 2,217 Bytes
4b414b1
 
20017db
 
4b414b1
 
7d967cc
20017db
8e74b09
 
4b414b1
 
 
 
20017db
 
 
 
8e74b09
4b414b1
 
8e74b09
4b414b1
 
 
20017db
8e74b09
20017db
 
 
 
 
4b414b1
8e74b09
4b414b1
 
8e74b09
20017db
 
4b414b1
8e74b09
4b414b1
 
 
8e74b09
 
 
 
 
 
4b414b1
 
8e74b09
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import gradio as gr
from transformers import VisionEncoderDecoderModel, ViTImageProcessor, AutoTokenizer
from audiocraft.models import MusicGen
from audiocraft.data.audio import audio_write
from PIL import Image
import torch
import os
import tempfile

# ───── Load the image-captioning model ─────
# The model, its image processor and its tokenizer are all fetched from the
# same checkpoint, so the identifier is spelled out only once.
_CAPTION_CHECKPOINT = "nlpconnect/vit-gpt2-image-captioning"
caption_model = VisionEncoderDecoderModel.from_pretrained(_CAPTION_CHECKPOINT)
feature_extractor = ViTImageProcessor.from_pretrained(_CAPTION_CHECKPOINT)
tokenizer = AutoTokenizer.from_pretrained(_CAPTION_CHECKPOINT)

# ───── Load the MusicGen model ─────
musicgen = MusicGen.get_pretrained("facebook/musicgen-small")
# Length of the music to generate, in seconds.
musicgen.set_generation_params(duration=10)

# ───── Image → caption sentence ─────
def generate_caption(image):
    """Return a caption sentence describing ``image``.

    Args:
        image: The picture to describe. The Gradio input in this file
            delivers a PIL image; anything else is passed to the image
            processor unchanged.

    Returns:
        The decoded caption string, with special tokens removed.

    Raises:
        ValueError: If ``image`` is ``None`` (for example, the form was
            submitted without uploading a picture).
    """
    if image is None:
        raise ValueError("generate_caption() requires an image, got None")

    # Grayscale, RGBA or palette PIL images are converted so the processor
    # always receives a three-channel RGB picture.
    if isinstance(image, Image.Image) and image.mode != "RGB":
        image = image.convert("RGB")

    pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values
    # max_length caps the generated caption at 50 tokens.
    output_ids = caption_model.generate(pixel_values, max_length=50)
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)

# ───── Caption → music ─────
def generate_music(prompt):
    """Generate a music clip for ``prompt`` and return the path of the file.

    Args:
        prompt: Text description that conditions MusicGen.

    Returns:
        Path (``str``) of the audio file that was written inside a freshly
        created temporary directory. The directory is not removed here
        because the caller still needs to read the file afterwards.
    """
    wav = musicgen.generate([prompt])  # batch size 1
    tmp_dir = tempfile.mkdtemp()
    # audio_write() expects a stem name and appends the format suffix itself,
    # so the extension must not be included here; otherwise the file on disk
    # ends up as "*.wav.wav" and the returned path points at nothing.
    stem = os.path.join(tmp_dir, "musicgen_output")
    # The model may have generated on an accelerator; move the single clip
    # to CPU before it is encoded and written.
    written_path = audio_write(
        stem,
        wav[0].cpu(),
        musicgen.sample_rate,
        strategy="loudness",
    )
    # Return the path audio_write() reports, i.e. the file that really exists.
    return str(written_path)

# ───── Full pipeline: image → caption → music ─────
def process(image):
    """Caption ``image`` and turn that caption into a music clip.

    Args:
        image: The uploaded picture, forwarded to ``generate_caption``.

    Returns:
        A ``(caption, audio_path)`` tuple: the generated caption text and the
        path returned by ``generate_music`` for the prompt built from it.
    """
    description = generate_caption(image)
    # The caption is embedded in a fixed prompt template for the music model.
    music_file = generate_music(f"A cheerful melody inspired by: {description}")
    return description, music_file

# ───── Gradio interface setup ─────
# The user-facing texts are Korean. They are written here as proper UTF-8
# so the UI does not display mis-encoded characters.
demo = gr.Interface(
    fn=process,
    inputs=gr.Image(type="pil"),
    outputs=[
        # "AI-generated picture description"
        gr.Text(label="AI가 생성한 그림 설명"),
        # "Generated AI music (MusicGen)"
        gr.Audio(label="생성된 AI 음악 (MusicGen)"),
    ],
    # "AI picture-to-music generator"
    title="🎨 AI 그림 음악 생성기",
    # "Upload a picture and the AI writes a description, then composes and
    # plays music based on that description."
    description="그림을 업로드하면 AI가 설명을 만들고, 설명을 바탕으로 음악을 만들어 들려줍니다.",
)

# Start the web UI only when the file is executed as a script.
if __name__ == "__main__":
    demo.launch()