# Hugging Face Spaces app: upload an image, get an AI caption, then a
# MusicGen audio clip generated from that caption.
import os | |
import sys | |
import types | |
import subprocess | |
import tempfile | |
import torch | |
import gradio as gr | |
from transformers import VisionEncoderDecoderModel, ViTImageProcessor, AutoTokenizer | |
from PIL import Image | |
# ── Environment configuration ──────────────────────────────
# Force safetensors serialization and keep xformers disabled for CPU hosts.
os.environ.update({
    "HF_FORCE_SAFE_SERIALIZATION": "1",
    "XFORMERS_FORCE_DISABLE": "1",  # really disable xformers
})
# ββ β¨ xformers λλ―Έ λͺ¨λ μ½μ βββββββββββββββββββββββββββββββββ | |
dummy = types.ModuleType("xformers") | |
dummy.__version__ = "0.0.0" | |
# νμ λͺ¨λ xformers.ops | |
ops = types.ModuleType("xformers.ops") | |
def _fake_memory_efficient_attention(q, k, v, *_, dropout_p: float = 0.0, **__): | |
""" | |
xformers.memory_efficient_attention λ체 ꡬν. | |
PyTorch 2.x κΈ°λ³Έ S-DPAλ‘ μ²λ¦¬ν΄ μλλ λ리μ§λ§ CPUμμλ λμν©λλ€. | |
""" | |
return torch.nn.functional.scaled_dot_product_attention( | |
q, k, v, dropout_p=dropout_p, is_causal=False | |
) | |
class _FakeLowerTriangularMask: # audiocraft λ΄λΆ νμ 체ν¬μ© λλ―Έ | |
pass | |
ops.memory_efficient_attention = _fake_memory_efficient_attention | |
ops.LowerTriangularMask = _FakeLowerTriangularMask | |
dummy.ops = ops | |
sys.modules["xformers"] = dummy | |
sys.modules["xformers.ops"] = ops | |
# ── Dynamic audiocraft installation ────────────────────────
# Try the normal import first; if audiocraft is missing, pip-install it
# from GitHub at import time and retry.
# NOTE(review): runtime pip install needs network access and can take
# minutes on first launch; consider pinning it in requirements instead.
try:
    from audiocraft.models import MusicGen
except ModuleNotFoundError:
    # --use-pep517 avoids the legacy setup.py build path for the VCS install.
    subprocess.check_call([
        sys.executable, "-m", "pip", "install",
        "git+https://github.com/facebookresearch/audiocraft@main",
        "--use-pep517"
    ])
    from audiocraft.models import MusicGen
from audiocraft.data.audio import audio_write
# ── Image captioning model ─────────────────────────────────
# ViT encoder + GPT-2 decoder checkpoint; safetensors weights only
# (HF_FORCE_SAFE_SERIALIZATION above) and lazy CPU memory loading.
caption_model = VisionEncoderDecoderModel.from_pretrained(
    "nlpconnect/vit-gpt2-image-captioning",
    use_safetensors=True,
    low_cpu_mem_usage=True
)
feature_extractor = ViTImageProcessor.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
tokenizer = AutoTokenizer.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
# ── MusicGen ───────────────────────────────────────────────
musicgen = MusicGen.get_pretrained("facebook/musicgen-small")
musicgen.set_generation_params(duration=10)  # 10-second clips
# ββ μ νΈ ν¨μλ€ βββββββββββββββββββββββββββββββββββββββββββββ | |
def generate_caption(image: Image.Image) -> str:
    """Return an English caption for *image* using the ViT-GPT2 model."""
    inputs = feature_extractor(images=image, return_tensors="pt")
    generated_ids = caption_model.generate(inputs.pixel_values, max_length=50)
    caption = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
    return caption
def generate_music(prompt: str) -> str:
    """
    Generate a short audio clip from *prompt* with MusicGen.

    Returns the filesystem path of the written WAV file.
    """
    wav = musicgen.generate([prompt])  # batch size = 1
    tmp_dir = tempfile.mkdtemp()
    # BUG FIX: audiocraft's audio_write() takes a file *stem* and appends
    # the ".wav" suffix itself. The original passed ".../musicgen_output.wav",
    # which wrote ".../musicgen_output.wav.wav" while returning a path that
    # did not exist — Gradio's audio player then failed at runtime.
    stem = os.path.join(tmp_dir, "musicgen_output")
    # .cpu() is a no-op on CPU hosts but required if the model ran on GPU.
    audio_write(stem, wav[0].cpu(), musicgen.sample_rate, strategy="loudness")
    return stem + ".wav"
def process(image: Image.Image):
    """Full pipeline: caption the uploaded image, then set the caption to music."""
    caption = generate_caption(image)
    music_prompt = f"A cheerful melody inspired by: {caption}"
    return caption, generate_music(music_prompt)
# ── Gradio UI ──────────────────────────────────────────────
# Single image input; outputs the generated caption and the audio clip.
# NOTE(review): the label/title strings below are mojibake-encoded Korean
# from a bad transcoding — left byte-identical; re-encode from the original
# UTF-8 source to restore them.
demo = gr.Interface(
    fn=process,
    inputs=gr.Image(type="pil"),
    outputs=[
        gr.Text(label="AIκ°€ μƒμ„±ν•œ κ·Έλ¦Ό μ„€λͺ…"),
        gr.Audio(label="μƒμ„±λœ AI μŒμ•… (MusicGen)")
    ],
    title="🎨 AI κ·Έλ¦Όβ†’μŒμ•… μƒμ„±κΈ°",
    description="그림을 μ—…λ‘œλ“œν•˜λ©΄ AIκ°€ μ„€λͺ…을 λ§Œλ“€κ³ , μ„€λͺ…을 λ°”νƒ•μœΌλ‘œ μŒμ•…μ„ μƒμ„±ν•΄ λ“€λ €μ€λ‹ˆλ‹€."
)
if __name__ == "__main__":
    demo.launch()