# NOTE(review): the original file header here was Hugging Face Spaces page
# residue (build status, revision hashes, a line-number gutter) — not program
# source. It has been replaced with this comment so the file parses.
import os, sys, types, subprocess, tempfile
import torch, gradio as gr
from transformers import (
VisionEncoderDecoderModel, ViTImageProcessor, AutoTokenizer
)
from PIL import Image
# ─────────────────────────────────────────────────────────────
# 0. Environment variables
# ─────────────────────────────────────────────────────────────
os.environ["HF_FORCE_SAFE_SERIALIZATION"] = "1"
os.environ["XFORMERS_FORCE_DISABLE"] = "1"   # flag read internally by audiocraft

# ─────────────────────────────────────────────────────────────
# 1. Inject a dummy ``xformers`` module (drops the GPU-only dependency;
#    audiocraft only touches the two symbols stubbed below).
# ─────────────────────────────────────────────────────────────
dummy = types.ModuleType("xformers")
dummy.__version__ = "0.0.0"
ops = types.ModuleType("xformers.ops")


def _fake_mea(q, k, v, *_, dropout_p: float = 0.0, **__):
    """Drop-in stand-in for xformers' ``memory_efficient_attention``.

    Delegates to PyTorch's native scaled-dot-product attention. Extra
    positional/keyword arguments (e.g. ``attn_bias``) are accepted and
    ignored, matching how the original block swallowed them.
    """
    return torch.nn.functional.scaled_dot_product_attention(
        q, k, v, dropout_p=dropout_p, is_causal=False
    )


class _FakeLowerTriangularMask:  # audiocraft only checks that this name exists
    pass


ops.memory_efficient_attention = _fake_mea
ops.LowerTriangularMask = _FakeLowerTriangularMask
dummy.ops = ops
# Register the stubs so ``import xformers`` / ``import xformers.ops`` succeed.
sys.modules["xformers"] = dummy
sys.modules["xformers.ops"] = ops
# ─────────────────────────────────────────────────────────────
# 2. (Optional) stub out only the modules that are NOT installed.
#    Modules already installed via requirements.txt (librosa, av, …)
#    must stay real, so they are left alone.
# ─────────────────────────────────────────────────────────────
import importlib.util

for name in ("pesq", "pystoi", "soxr"):
    # Bug fix: checking ``sys.modules`` alone would also stub packages that
    # are installed but simply not imported yet; ``find_spec`` confirms the
    # package is genuinely absent before masking it with an empty module.
    if name not in sys.modules and importlib.util.find_spec(name) is None:
        sys.modules[name] = types.ModuleType(name)
# ─────────────────────────────────────────────────────────────
# 3. Load audiocraft (MusicGen), installing it on first run.
# ─────────────────────────────────────────────────────────────
try:
    from audiocraft.models import MusicGen
    from audiocraft.data.audio import audio_write
except ModuleNotFoundError:
    # First run on a fresh Space: install audiocraft without its pinned
    # dependency set (``--no-deps`` — presumably to avoid pulling the real
    # xformers; NOTE(review): confirm intent), then install the runtime
    # dependencies it actually needs, and retry the imports.
    subprocess.check_call([
        sys.executable, "-m", "pip", "install",
        "git+https://github.com/facebookresearch/audiocraft@main",
        "--no-deps", "--use-pep517"
    ])
    subprocess.check_call([sys.executable, "-m", "pip", "install",
                           "encodec", "librosa", "av", "torchdiffeq",
                           "torchmetrics", "num2words"])
    from audiocraft.models import MusicGen
    from audiocraft.data.audio import audio_write
# ─────────────────────────────────────────────────────────────
# 4. Image-captioning model (ViT encoder + GPT-2 decoder)
# ─────────────────────────────────────────────────────────────
# One constant instead of three copies of the checkpoint id.
_CAPTION_MODEL_ID = "nlpconnect/vit-gpt2-image-captioning"

caption_model = VisionEncoderDecoderModel.from_pretrained(
    _CAPTION_MODEL_ID,
    use_safetensors=True,      # load .safetensors weights
    low_cpu_mem_usage=False,   # disable meta-device lazy loading
    device_map=None            # keep Accelerate from auto-sharding the model
).eval()                       # inference mode
feature_extractor = ViTImageProcessor.from_pretrained(_CAPTION_MODEL_ID)
tokenizer = AutoTokenizer.from_pretrained(_CAPTION_MODEL_ID)
# ─────────────────────────────────────────────────────────────
# 5. MusicGen model (small checkpoint; clips fixed at 10 seconds)
# ─────────────────────────────────────────────────────────────
musicgen = MusicGen.get_pretrained("facebook/musicgen-small")
musicgen.set_generation_params(duration=10)
# ─────────────────────────────────────────────────────────────
# 6. Pipeline functions
# ─────────────────────────────────────────────────────────────
def generate_caption(image: Image.Image) -> str:
    """Return a text caption for *image* using the ViT-GPT2 model.

    Uses the module-level ``feature_extractor``, ``caption_model`` and
    ``tokenizer`` loaded above; runs under ``no_grad`` since this is
    inference only.
    """
    with torch.no_grad():
        pixel_values = feature_extractor(
            images=image, return_tensors="pt"
        ).pixel_values
        output_ids = caption_model.generate(pixel_values, max_length=50)
        return tokenizer.decode(output_ids[0], skip_special_tokens=True)
def generate_music(prompt: str) -> str:
    """Generate a ~10 s music clip for *prompt* and return the WAV file path.

    Bug fix: audiocraft's ``audio_write`` treats its first argument as a
    *stem* and appends ".wav" itself. The original code passed
    ``.../musicgen.wav`` (producing ``musicgen.wav.wav`` on disk) and then
    returned the nonexistent ``musicgen.wav`` path. We now pass an
    extension-less stem and return the path with the extension added.
    """
    wav = musicgen.generate([prompt])        # batch of one prompt
    stem = os.path.join(tempfile.mkdtemp(), "musicgen")
    audio_write(stem, wav[0].cpu(), musicgen.sample_rate, strategy="loudness")
    return stem + ".wav"
def process(image: Image.Image):
    """Full pipeline: image → caption → music.

    Returns ``(caption, wav_path)`` for the two Gradio outputs; the caption
    is folded into a fixed prompt template for MusicGen.
    """
    caption = generate_caption(image)
    music_path = generate_music(f"A cheerful melody inspired by: {caption}")
    return caption, music_path
# ─────────────────────────────────────────────────────────────
# 7. Gradio UI
# ─────────────────────────────────────────────────────────────
# NOTE(review): the original (Korean) UI strings were mojibake-damaged and
# broken mid-string in this file; they are restored here as English
# equivalents of the same meaning.
demo = gr.Interface(
    fn=process,
    inputs=gr.Image(type="pil"),
    outputs=[
        gr.Text(label="AI-generated image caption"),
        gr.Audio(label="Generated AI music (MusicGen)"),
    ],
    title="🎨 AI Image-to-Music Generator",
    description=(
        "Upload an image and the AI writes a caption, then composes a "
        "10-second music clip based on it."
    ),
)

if __name__ == "__main__":
    demo.launch()