# Captured from a Hugging Face Spaces page; the Space status shown was "Runtime error".
import os
import tempfile

# Set this before importing the Hugging Face libraries below so the value is
# already in the environment when they are imported. Per the original author,
# the intent is to force safetensors-only (de)serialization.
# NOTE(review): nothing in this file reads HF_FORCE_SAFE_SERIALIZATION; confirm
# which library actually consumes it. The caption-model load also passes
# use_safetensors=True explicitly.
os.environ["HF_FORCE_SAFE_SERIALIZATION"] = "1"

import gradio as gr
from transformers import VisionEncoderDecoderModel, ViTImageProcessor, AutoTokenizer
from audiocraft.models import MusicGen
from audiocraft.data.audio import audio_write
from PIL import Image
import torch
# ───── Image-captioning model loading ─────
# All three captioning components are loaded from the same Hub repository.
_CAPTION_REPO = "nlpconnect/vit-gpt2-image-captioning"

caption_model = VisionEncoderDecoderModel.from_pretrained(
    _CAPTION_REPO,
    # Request safetensors weights rather than a pickled checkpoint.
    # NOTE(review): with use_safetensors=True, from_pretrained fails if the
    # repo has no safetensors file — confirm the repo provides one.
    use_safetensors=True,
    # Optional: lowers peak memory while the weights are being loaded.
    low_cpu_mem_usage=True,
)
feature_extractor = ViTImageProcessor.from_pretrained(_CAPTION_REPO)
tokenizer = AutoTokenizer.from_pretrained(_CAPTION_REPO)
# ───── MusicGen model loading ─────
_MUSICGEN_CHECKPOINT = "facebook/musicgen-small"
_CLIP_SECONDS = 10  # length of every generated clip, in seconds

musicgen = MusicGen.get_pretrained(_MUSICGEN_CHECKPOINT)
musicgen.set_generation_params(duration=_CLIP_SECONDS)
# ───── Image → caption sentence ─────
def generate_caption(image, max_length=50):
    """Generate a one-sentence caption for *image* with the ViT-GPT2 captioner.

    Args:
        image: Image to describe. The Gradio input supplies a PIL image
            (``gr.Image(type="pil")``); non-PIL inputs are passed to the
            processor unchanged.
        max_length: ``max_length`` forwarded to ``caption_model.generate``.
            Defaults to 50, the value previously hard-coded.

    Returns:
        The decoded caption with special tokens removed and surrounding
        whitespace stripped.
    """
    # PIL uploads can be RGBA (PNG with alpha), grayscale "L" or palette "P".
    # The processor presumably expects 3-channel input, and the model's usage
    # example converts non-RGB images first, so do the same here.
    if isinstance(image, Image.Image) and image.mode != "RGB":
        image = image.convert("RGB")
    pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values
    output_ids = caption_model.generate(pixel_values, max_length=max_length)
    # Strip stray spaces so the music prompt built from the caption is clean.
    return tokenizer.decode(output_ids[0], skip_special_tokens=True).strip()
# ───── Caption → music ─────
def generate_music(prompt):
    """Render *prompt* to a WAV file with MusicGen and return the file path.

    Clip length comes from the module-level ``set_generation_params`` call.

    Args:
        prompt: Text description used to condition MusicGen.

    Returns:
        Path, as ``str``, of the written ``musicgen_output.wav`` inside a fresh
        temporary directory. Gradio reads the file after this function
        returns, so the directory is intentionally not removed here.
    """
    wav = musicgen.generate([prompt])  # batch of one prompt
    out_dir = tempfile.mkdtemp()
    # audio_write treats its first argument as a *stem* and appends the
    # format suffix itself. Passing "musicgen_output.wav" would write
    # "musicgen_output.wav.wav" while the old code returned the suffix-less
    # name, which does not exist. Pass the stem and return the path that
    # audio_write reports.
    written_path = audio_write(
        os.path.join(out_dir, "musicgen_output"),
        wav[0].cpu(),  # no-op on CPU; needed if the model runs on GPU
        musicgen.sample_rate,
        strategy="loudness",
    )
    return str(written_path)
# ───── Full pipeline: image → caption → music ─────
def process(image):
    """Caption *image*, then compose music inspired by that caption.

    Args:
        image: PIL image from the Gradio input, or ``None`` when the user
            submits without uploading anything.

    Returns:
        ``(caption, audio_path)``, matching the interface's two outputs.

    Raises:
        gr.Error: If no image was provided. Gradio shows this message to the
            user instead of a stack trace from the image processor.
    """
    if image is None:
        raise gr.Error("Please upload an image first.")
    caption = generate_caption(image)
    prompt = f"A cheerful melody inspired by: {caption}"
    return caption, generate_music(prompt)
# ───── Gradio interface ─────
# NOTE(review): the Korean UI strings below appear mis-encoded in this copy
# (UTF-8 bytes shown as ISO-8859-7); confirm them against the original file.
_image_input = gr.Image(type="pil")
_output_components = [
    gr.Text(label="AIκ° μμ±ν κ·Έλ¦Ό μ€λͺ "),  # caption text
    gr.Audio(label="μμ±λ AI μμ (MusicGen)"),  # generated clip
]

demo = gr.Interface(
    fn=process,
    inputs=_image_input,
    outputs=_output_components,
    title="π¨ AI κ·Έλ¦Ό μμ μμ±κΈ°",
    description="κ·Έλ¦Όμ μ λ‘λνλ©΄ AIκ° μ€λͺ μ λ§λ€κ³ , μ€λͺ μ λ°νμΌλ‘ μμ μ λ§λ€μ΄ λ€λ €μ€λλ€.",
)
def _main():
    """Start the Gradio server for ``demo``."""
    demo.launch()


if __name__ == "__main__":
    _main()