# Hugging Face Space status at time of capture: "Runtime error".
import os, sys, types, subprocess, tempfile | |
import torch, gradio as gr | |
from transformers import ( | |
VisionEncoderDecoderModel, ViTImageProcessor, AutoTokenizer | |
) | |
from PIL import Image | |
# -- Environment variables -------------------------------------------------
# NOTE(review): these are assigned after torch/gradio/transformers have
# already been imported above, so they can only affect code that reads them
# later (e.g. the audiocraft import below) -- confirm each consumer honours them.
os.environ.update({
    "HF_FORCE_SAFE_SERIALIZATION": "1",
    # Described by the original author as an audiocraft-internal flag (unverified).
    "XFORMERS_FORCE_DISABLE": "1",
})
# -- Stand-in `xformers` module ----------------------------------------------
# Registers a minimal fake `xformers` / `xformers.ops` in sys.modules so that
# `import xformers` / `from xformers import ops` succeed without the real
# package. Attention is delegated to torch's scaled_dot_product_attention.
dummy = types.ModuleType("xformers")
dummy.__version__ = "0.0.0"
ops = types.ModuleType("xformers.ops")


class _FakeLowerTriangularMask:
    """Marker object standing in for `xformers.ops.LowerTriangularMask` (causal)."""


def _fake_mem_eff_attn(q, k, v, attn_bias=None, p=0.0, scale=None, *_,
                       dropout_p=None, **__):
    """Emulate `xformers.ops.memory_efficient_attention` using torch SDPA.

    Follows the xformers calling convention: tensors are laid out as
    ``[B, M, H, K]`` (or 3-D ``[B, M, K]``), the 4th positional argument is
    the attention bias, and dropout is passed as ``p``. The ``dropout_p``
    keyword of the earlier version of this stub is still accepted.

    Args:
        q, k, v: Query/key/value tensors in xformers layout.
        attn_bias: ``None``, a ``_FakeLowerTriangularMask`` instance (causal
            attention), or a tensor added to the attention scores
            (SDPA ``attn_mask`` semantics).
        p: Dropout probability; applied whenever > 0.
        scale: Softmax scale; ``None`` uses SDPA's default ``1/sqrt(K)``.

    Returns:
        Attention output in the same layout as ``q``.
    """
    if dropout_p is not None:  # backward compatibility with the old stub
        p = dropout_p
    is_causal = isinstance(attn_bias, _FakeLowerTriangularMask)
    mask = attn_bias if torch.is_tensor(attn_bias) else None
    # Only forward `scale` when given, so older torch SDPA without it still works.
    extra = {} if scale is None else {"scale": scale}
    four_d = q.dim() == 4
    if four_d:
        # xformers uses [B, M, H, K]; SDPA wants heads before the sequence axis.
        q, k, v = (t.transpose(1, 2) for t in (q, k, v))
    out = torch.nn.functional.scaled_dot_product_attention(
        q, k, v, attn_mask=mask, dropout_p=p, is_causal=is_causal, **extra
    )
    return out.transpose(1, 2) if four_d else out


ops.memory_efficient_attention = _fake_mem_eff_attn
ops.LowerTriangularMask = _FakeLowerTriangularMask
dummy.ops = ops
sys.modules["xformers"] = dummy
sys.modules["xformers.ops"] = ops
# -----------------------------------------------------------------------------
# -- Load audiocraft (normally already installed by the Space's postInstall) --
try:
    from audiocraft.models import MusicGen
    from audiocraft.data.audio import audio_write
except ModuleNotFoundError as err:
    # Only fall back to installing when audiocraft itself is missing. If it is
    # installed but one of *its* dependencies is absent, a --no-deps reinstall
    # cannot fix that, so surface the real error instead of masking it.
    if err.name is not None and err.name.split(".")[0] != "audiocraft":
        raise
    # Fallback for local runs outside the Space.
    # NOTE(review): installs the unpinned `main` branch without dependencies;
    # consider pinning a release for reproducibility.
    subprocess.check_call([
        sys.executable, "-m", "pip", "install",
        "git+https://github.com/facebookresearch/audiocraft@main",
        "--no-deps", "--use-pep517",
    ])
    import importlib

    # Refresh the import system's cached directory listings so the package
    # pip just installed is visible to this already-running interpreter.
    importlib.invalidate_caches()
    from audiocraft.models import MusicGen
    from audiocraft.data.audio import audio_write
# -- Image-captioning model -------------------------------------------------
_CAPTION_REPO = "nlpconnect/vit-gpt2-image-captioning"

# Model, image processor and tokenizer all come from the same Hub repository.
# NOTE(review): with `use_safetensors=True` transformers will not fall back to
# `.bin` weights -- confirm this repo actually ships safetensors files.
caption_model = VisionEncoderDecoderModel.from_pretrained(
    _CAPTION_REPO,
    low_cpu_mem_usage=True,
    use_safetensors=True,
)
feature_extractor = ViTImageProcessor.from_pretrained(_CAPTION_REPO)
tokenizer = AutoTokenizer.from_pretrained(_CAPTION_REPO)

# -- MusicGen model -----------------------------------------------------------
musicgen = MusicGen.get_pretrained("facebook/musicgen-small")
musicgen.set_generation_params(duration=10)  # generated clip length (seconds)
# -- Pipeline functions -------------------------------------------------------
def generate_caption(image: Image.Image) -> str:
    """Caption *image* with the module-level ViT-GPT2 captioning model.

    Args:
        image: Picture to describe; any PIL mode is accepted and converted
            to RGB first.

    Returns:
        The decoded caption, without special tokens or surrounding whitespace.
    """
    # The ViT processor normalises exactly three channels, so RGBA, palette and
    # greyscale uploads are converted first (as in the model card's example).
    if image.mode != "RGB":
        image = image.convert("RGB")
    pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values
    ids = caption_model.generate(pixel_values, max_length=50)
    return tokenizer.decode(ids[0], skip_special_tokens=True).strip()
def generate_music(prompt: str) -> str:
    """Generate audio for a text *prompt* with MusicGen and save it as WAV.

    The clip length follows the generation parameters set on ``musicgen``
    when the model was loaded.

    Args:
        prompt: Text description that conditions the generation.

    Returns:
        Filesystem path of the written ``.wav`` file.
    """
    wav = musicgen.generate([prompt])
    # NOTE(review): a fresh temp directory is created per call and never
    # removed; the file must outlive this call for Gradio, so clean up elsewhere.
    tmpdir = tempfile.mkdtemp()
    # audio_write takes a *stem* and appends the format suffix itself, so
    # passing "musicgen.wav" would write "musicgen.wav.wav". Pass the stem and
    # return the path audio_write reports. Move the tensor to CPU before saving.
    path = audio_write(
        os.path.join(tmpdir, "musicgen"),
        wav[0].cpu(),
        musicgen.sample_rate,
        strategy="loudness",
    )
    return str(path)
def process(image: Image.Image) -> tuple[str, str]:
    """Run the full pipeline: caption the picture, then compose music from it.

    Args:
        image: Uploaded picture from the Gradio input; may be ``None`` if the
            user submitted without uploading anything.

    Returns:
        ``(caption, wav_path)``, matching the interface's two outputs.

    Raises:
        gr.Error: If no image was provided (displayed to the user by Gradio).
    """
    if image is None:
        # Message: "Please upload a picture first."
        raise gr.Error("그림을 먼저 업로드해 주세요.")
    caption = generate_caption(image)
    path = generate_music(f"A cheerful melody inspired by: {caption}")
    return caption, path
# -- Gradio UI ----------------------------------------------------------------
# NOTE(review): the Korean UI strings were restored from a mis-encoded copy of
# this file (UTF-8 read as ISO-8859-7); confirm the emoji and arrow glyphs.
demo = gr.Interface(
    fn=process,
    inputs=gr.Image(type="pil"),
    outputs=[
        # "Picture description generated by AI"
        gr.Text(label="AI가 생성한 그림 설명"),
        # "Generated AI music (MusicGen)"
        gr.Audio(label="생성된 AI 음악 (MusicGen)"),
    ],
    # "AI picture-to-music generator"
    title="🎨 AI 그림→음악 생성기",
    # "Upload a picture and the AI writes a description, then generates and
    # plays 10 seconds of music based on it."
    description="그림을 업로드하면 AI가 설명을 만들고, 설명을 바탕으로 음악을 10초간 생성해 들려줍니다.",
)

if __name__ == "__main__":
    demo.launch()