# Hugging Face Space: upload an image, get an AI caption and a 10-second
# MusicGen audio clip generated from that caption.
# (Removed non-code residue scraped from the Spaces file viewer: runtime-error
# banners, commit hashes, file size, and the line-number gutter.)
import os, sys, types, subprocess, tempfile
import torch, gradio as gr
from transformers import (
VisionEncoderDecoderModel, ViTImageProcessor, AutoTokenizer
)
from PIL import Image
# ─────────────────────────────────────────────────────────────
# 0. Environment variables
# ─────────────────────────────────────────────────────────────
# Force safetensors serialization and tell audiocraft to skip xformers
# (its internal feature flag) before any model code is imported.
os.environ["HF_FORCE_SAFE_SERIALIZATION"] = "1"
os.environ["XFORMERS_FORCE_DISABLE"] = "1"
# ─────────────────────────────────────────────────────────────
# 1. Inject a dummy xformers module (removes the GPU dependency)
# ─────────────────────────────────────────────────────────────
_xformers_stub = types.ModuleType("xformers")
_xformers_stub.__version__ = "0.0.0"
_ops_stub = types.ModuleType("xformers.ops")


def _fake_mea(q, k, v, *_, dropout_p: float = 0.0, **__):
    """Stand-in for xformers' memory_efficient_attention.

    Delegates to PyTorch 2.x standard scaled-dot-product attention,
    which also works on CPU.
    """
    return torch.nn.functional.scaled_dot_product_attention(
        q, k, v, dropout_p=dropout_p, is_causal=False
    )


class _FakeLowerTriangularMask:
    """audiocraft only checks that this type exists, so an empty class suffices."""
    pass


_ops_stub.memory_efficient_attention = _fake_mea
_ops_stub.LowerTriangularMask = _FakeLowerTriangularMask
_xformers_stub.ops = _ops_stub
# Register both module paths so `import xformers.ops` resolves to the stubs.
sys.modules["xformers"] = _xformers_stub
sys.modules["xformers.ops"] = _ops_stub
# ─────────────────────────────────────────────────────────────
# 2. Safety-net stubs for other possibly-missing modules
#    (requirements.txt installs them, but this keeps the runtime
#    importable if one happens to be absent)
# ─────────────────────────────────────────────────────────────
_OPTIONAL_DEPS = (
    "av", "librosa", "torchdiffeq", "torchmetrics",
    "pesq", "pystoi", "soxr",
)
for _dep in _OPTIONAL_DEPS:
    # setdefault leaves any already-imported real module untouched.
    sys.modules.setdefault(_dep, types.ModuleType(_dep))
# ─────────────────────────────────────────────────────────────
# 3. Load audiocraft (MusicGen)
# ─────────────────────────────────────────────────────────────
try:
    from audiocraft.models import MusicGen
    from audiocraft.data.audio import audio_write
except ModuleNotFoundError:
    # Not installed (e.g. a local run): install it without its declared
    # dependencies, since several of them are stubbed above.
    subprocess.check_call([
        sys.executable, "-m", "pip", "install",
        "git+https://github.com/facebookresearch/audiocraft@main",
        "--no-deps", "--use-pep517"
    ])
    # Then install the minimal hard dependencies right away.
    # (The stubs cover most import-time checks, but install them to be safe.)
    subprocess.check_call([sys.executable, "-m", "pip", "install",
                           "encodec", "torchdiffeq", "torchmetrics",
                           "librosa", "soxr", "av"])
    # Retry the imports now that the package is present.
    from audiocraft.models import MusicGen
    from audiocraft.data.audio import audio_write
# ─────────────────────────────────────────────────────────────
# 4. Image captioning model (ViT encoder + GPT-2 decoder)
# ─────────────────────────────────────────────────────────────
# Downloads weights from the Hugging Face Hub on first run.
caption_model = VisionEncoderDecoderModel.from_pretrained(
    "nlpconnect/vit-gpt2-image-captioning",
    use_safetensors=True,      # matches HF_FORCE_SAFE_SERIALIZATION above
    low_cpu_mem_usage=True     # stream weights to keep peak RAM low
)
feature_extractor = ViTImageProcessor.from_pretrained(
    "nlpconnect/vit-gpt2-image-captioning"
)
tokenizer = AutoTokenizer.from_pretrained(
    "nlpconnect/vit-gpt2-image-captioning"
)
# ─────────────────────────────────────────────────────────────
# 5. MusicGen model (CPU only)
# ─────────────────────────────────────────────────────────────
musicgen = MusicGen.get_pretrained("facebook/musicgen-small")
musicgen.set_generation_params(duration=10)  # 10-second clips
# ─────────────────────────────────────────────────────────────
# 6. Pipeline functions
# ─────────────────────────────────────────────────────────────
def generate_caption(image: Image.Image) -> str:
    """Return an English caption for *image* using the ViT-GPT2 model."""
    inputs = feature_extractor(images=image, return_tensors="pt")
    token_ids = caption_model.generate(inputs.pixel_values, max_length=50)
    return tokenizer.decode(token_ids[0], skip_special_tokens=True)
def generate_music(prompt: str) -> str:
    """Generate a ~10 s music clip from *prompt* and return the WAV file path.

    Bug fix: audiocraft's ``audio_write`` takes a *stem* (path without
    extension) and appends the format suffix itself. The original code
    passed ``"musicgen.wav"``, which wrote ``musicgen.wav.wav`` on disk
    while returning a path that did not exist.
    """
    wav = musicgen.generate([prompt])        # batch size = 1
    tmpdir = tempfile.mkdtemp()
    stem = os.path.join(tmpdir, "musicgen")  # audio_write appends ".wav"
    # audio_write returns the actual Path written (with suffix).
    saved = audio_write(stem, wav[0].cpu(), musicgen.sample_rate,
                        strategy="loudness")
    return str(saved)
def process(image: Image.Image):
    """Full pipeline: image → caption → music file path.

    Returns a ``(caption, wav_path)`` tuple for the Gradio outputs.
    """
    description = generate_caption(image)
    audio_path = generate_music(f"A cheerful melody inspired by: {description}")
    return description, audio_path
# ─────────────────────────────────────────────────────────────
# 7. Gradio UI
# ─────────────────────────────────────────────────────────────
# NOTE(review): the original Korean UI strings were mojibake-corrupted in
# this file (multibyte text split mid-literal); they are replaced here with
# intact English equivalents.
demo = gr.Interface(
    fn=process,
    inputs=gr.Image(type="pil"),
    outputs=[
        gr.Text(label="AI-generated image caption"),
        gr.Audio(label="Generated AI music (MusicGen)"),
    ],
    title="🎨 AI Image-to-Music Generator",
    description=(
        "Upload an image: the AI writes a caption and generates a "
        "10-second music clip inspired by it."
    ),
)

if __name__ == "__main__":
    demo.launch()