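"""Hugging Face Spaces demo: caption an uploaded image with ViT-GPT2, then turn the
caption into a short music clip with MusicGen, wired together through a Gradio UI."""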
import os
import sys
import types
import subprocess
import tempfile
import torch
import gradio as gr
from transformers import VisionEncoderDecoderModel, ViTImageProcessor, AutoTokenizer
from PIL import Image
# ── Environment variables ───────────────────────────────────
os.environ["HF_FORCE_SAFE_SERIALIZATION"] = "1"
os.environ["XFORMERS_FORCE_DISABLE"] = "1" # μ€μ xformers λΉνμ±ν
# ── Inject a dummy xformers module ──────────────────────────
dummy = types.ModuleType("xformers")
dummy.__version__ = "0.0.0"
# Submodule xformers.ops
ops = types.ModuleType("xformers.ops")
def _fake_memory_efficient_attention(q, k, v, *_, dropout_p: float = 0.0, **__):
"""
xformers.memory_efficient_attention λ체 ꡬν.
PyTorch 2.x κΈ°λ³Έ S-DPAλ‘ μ²λ¦¬ν΄ μλλ λ리μ§λ§ CPUμμλ λμν©λλ€.
"""
return torch.nn.functional.scaled_dot_product_attention(
q, k, v, dropout_p=dropout_p, is_causal=False
)
class _FakeLowerTriangularMask:  # dummy for audiocraft's internal type checks
    pass
ops.memory_efficient_attention = _fake_memory_efficient_attention
ops.LowerTriangularMask = _FakeLowerTriangularMask
dummy.ops = ops
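# Register the stubs so any later "import xformers" / "import xformers.ops" resolves to them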
sys.modules["xformers"] = dummy
sys.modules["xformers.ops"] = ops
# ────────────────────────────────────────────────────────────
# ── Install audiocraft on the fly if missing ────────────────
try:
from audiocraft.models import MusicGen
except ModuleNotFoundError:
subprocess.check_call([
sys.executable, "-m", "pip", "install",
"git+https://github.com/facebookresearch/audiocraft@main",
"--use-pep517"
])
from audiocraft.models import MusicGen
from audiocraft.data.audio import audio_write
# ── Image captioning model ──────────────────────────────────
caption_model = VisionEncoderDecoderModel.from_pretrained(
"nlpconnect/vit-gpt2-image-captioning",
use_safetensors=True,
low_cpu_mem_usage=True
)
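# Image processor and tokenizer matching the captioning checkpoint above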
feature_extractor = ViTImageProcessor.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
tokenizer = AutoTokenizer.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
# ── MusicGen ────────────────────────────────────────────────
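# Smallest MusicGen variant; the xformers stub above lets audiocraft import without the real library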
musicgen = MusicGen.get_pretrained("facebook/musicgen-small")
musicgen.set_generation_params(duration=10)  # 10-second clips
# ── Helper functions ────────────────────────────────────────
def generate_caption(image: Image.Image) -> str:
pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values
output_ids = caption_model.generate(pixel_values, max_length=50)
return tokenizer.decode(output_ids[0], skip_special_tokens=True)
def generate_music(prompt: str) -> str:
    wav = musicgen.generate([prompt])  # batch size = 1
    tmp_dir = tempfile.mkdtemp()
    # audio_write takes a stem (no extension), appends the ".wav" suffix, and returns the final path
    audio_stem = os.path.join(tmp_dir, "musicgen_output")
    audio_path = audio_write(audio_stem, wav[0], musicgen.sample_rate, strategy="loudness")
    return str(audio_path)
def process(image: Image.Image):
caption = generate_caption(image)
prompt = f"A cheerful melody inspired by: {caption}"
audio_path = generate_music(prompt)
return caption, audio_path
# ── Gradio UI ───────────────────────────────────────────────
demo = gr.Interface(
fn=process,
inputs=gr.Image(type="pil"),
outputs=[
        gr.Text(label="AI-generated image caption"),
        gr.Audio(label="AI-generated music (MusicGen)")
],
    title="🎨 AI Image-to-Music Generator",
    description="Upload an image: the AI writes a caption, then generates music based on it."
)
if __name__ == "__main__":
demo.launch()