Spaces:
Runtime error
Runtime error
File size: 4,058 Bytes
0725a88 78ea8dc bfbdf81 78ea8dc 4ae4657 0725a88 bfbdf81 0725a88 4ae4657 78ea8dc 0725a88 78ea8dc 0725a88 78ea8dc 0725a88 78ea8dc 4ae4657 78ea8dc 87e6f23 0725a88 87e6f23 0725a88 4ae4657 06f6c9e 0725a88 4ae4657 87e6f23 0725a88 8e74b09 78ea8dc ad5c75b 9472531 0725a88 ad5c75b 4b414b1 0725a88 20017db 0725a88 20017db 0725a88 06f6c9e 4b414b1 0725a88 4b414b1 06f6c9e 0725a88 4b414b1 06f6c9e 4b414b1 0725a88 4b414b1 78ea8dc 4b414b1 8e74b09 78ea8dc 0725a88 4b414b1 8e74b09 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 |
import os, sys, types, subprocess, tempfile
import torch, gradio as gr
from transformers import (
VisionEncoderDecoderModel, ViTImageProcessor, AutoTokenizer
)
from PIL import Image
# ── Environment variables ───────────────────────────────────
# Both flags are exported before audiocraft is imported further down.
os.environ.update(
    {
        "HF_FORCE_SAFE_SERIALIZATION": "1",
        "XFORMERS_FORCE_DISABLE": "1",  # audiocraft internal flag
    }
)
# ── xformers dummy module ───────────────────────────────────
# Register a minimal stand-in for ``xformers`` / ``xformers.ops`` in
# ``sys.modules`` so that later ``import xformers`` statements resolve to it
# instead of requiring the real package.
dummy = types.ModuleType("xformers")
dummy.__version__ = "0.0.0"
ops = types.ModuleType("xformers.ops")


def _fake_mem_eff_attn(q, k, v, attn_bias=None, p: float = 0.0, *_,
                       dropout_p=None, **__):
    """Emulate ``xformers.ops.memory_efficient_attention`` with torch SDPA.

    Args:
        q, k, v: Attention inputs in the xformers layout ``[B, M, H, K]``
            (batch, sequence, heads, head dim), or ``[B, M, K]`` for a
            single head.
        attn_bias: ``None``, an instance of the stub ``LowerTriangularMask``
            (causal attention), or a tensor forwarded as ``attn_mask``.
            Any other object is ignored.
        p: Dropout probability (the keyword name xformers uses).
        dropout_p: Alias for ``p`` kept for callers that used the previous
            keyword of this stub; takes precedence when given.

    Returns:
        The attention output in the same layout as the inputs.

    Extra positional/keyword arguments (e.g. ``scale``, ``op``) are accepted
    and ignored.
    """
    if dropout_p is not None:
        p = dropout_p

    causal = isinstance(attn_bias, _FakeLowerTriangularMask)
    mask = attn_bias if torch.is_tensor(attn_bias) else None

    # xformers puts the sequence axis before the heads axis, while torch's
    # scaled_dot_product_attention expects heads first; swap for 4-D inputs.
    swap_axes = q.dim() == 4
    if swap_axes:
        q, k, v = (t.transpose(1, 2) for t in (q, k, v))

    out = torch.nn.functional.scaled_dot_product_attention(
        q, k, v, attn_mask=mask, dropout_p=p, is_causal=causal
    )
    return out.transpose(1, 2) if swap_axes else out


class _FakeLowerTriangularMask:
    """Marker type standing in for ``xformers.ops.LowerTriangularMask``."""


ops.memory_efficient_attention = _fake_mem_eff_attn
ops.LowerTriangularMask = _FakeLowerTriangularMask
dummy.ops = ops
sys.modules["xformers"] = dummy
sys.modules["xformers.ops"] = ops
# ────────────────────────────────────────────────────────────
# ── Load audiocraft (expected to be installed already by postInstall) ──
try:
    from audiocraft.models import MusicGen
    from audiocraft.data.audio import audio_write
except ModuleNotFoundError:  # fallback for the exceptional local run
    # Install audiocraft from GitHub without its dependencies, then retry.
    subprocess.check_call([
        sys.executable, "-m", "pip", "install",
        "git+https://github.com/facebookresearch/audiocraft@main",
        "--no-deps", "--use-pep517",
    ])
    # A package installed after interpreter start may be hidden by the import
    # system's cached directory listings; drop them before importing again.
    import importlib

    importlib.invalidate_caches()
    from audiocraft.models import MusicGen
    from audiocraft.data.audio import audio_write
# ── Image captioning model ──────────────────────────────────
# The model, image processor and tokenizer all come from the same checkpoint.
_CAPTION_CHECKPOINT = "nlpconnect/vit-gpt2-image-captioning"

caption_model = VisionEncoderDecoderModel.from_pretrained(
    _CAPTION_CHECKPOINT,
    low_cpu_mem_usage=True,
    use_safetensors=True,
)
feature_extractor = ViTImageProcessor.from_pretrained(_CAPTION_CHECKPOINT)
tokenizer = AutoTokenizer.from_pretrained(_CAPTION_CHECKPOINT)
# ── MusicGen model ──────────────────────────────────────────
_MUSICGEN_CHECKPOINT = "facebook/musicgen-small"
_CLIP_DURATION = 10  # passed as ``duration``; the UI text promises 10 seconds

musicgen = MusicGen.get_pretrained(_MUSICGEN_CHECKPOINT)
musicgen.set_generation_params(duration=_CLIP_DURATION)
# ── Pipeline functions ──────────────────────────────────────
def generate_caption(image: Image.Image) -> str:
    """Return a text caption for *image* produced by the captioning model.

    The image is preprocessed by ``feature_extractor``, a token sequence is
    generated by ``caption_model`` (``max_length=50``), and the first
    sequence is decoded by ``tokenizer`` with special tokens removed.
    """
    encoded = feature_extractor(images=image, return_tensors="pt")
    token_ids = caption_model.generate(encoded.pixel_values, max_length=50)
    first_sequence = token_ids[0]
    return tokenizer.decode(first_sequence, skip_special_tokens=True)
def generate_music(prompt: str) -> str:
    """Generate a music clip for *prompt* and return the audio file's path.

    Args:
        prompt: Text description handed to MusicGen.

    Returns:
        Path of the audio file written inside a fresh temporary directory.
        The directory is not removed here, so the file stays available to
        the caller after this function returns.
    """
    wav = musicgen.generate([prompt])
    out_dir = tempfile.mkdtemp()
    # audio_write takes a file *stem* and appends the format extension itself,
    # so passing "musicgen.wav" would create "musicgen.wav.wav". Pass the bare
    # stem and return the path audio_write reports it actually wrote.
    stem = os.path.join(out_dir, "musicgen")
    written = audio_write(
        stem,
        wav[0].cpu(),  # make sure the samples are on the CPU before saving
        musicgen.sample_rate,
        strategy="loudness",
    )
    return str(written)
def process(image: Image.Image):
    """Caption *image*, then generate music from a prompt built on the caption.

    Returns:
        A ``(caption, audio_path)`` tuple, matching the two outputs of the
        Gradio interface.
    """
    description = generate_caption(image)
    music_prompt = f"A cheerful melody inspired by: {description}"
    return description, generate_music(music_prompt)
# ── Gradio UI ───────────────────────────────────────────────
# The user-facing strings are Korean. In English:
#   text label  – "AI-generated description of the picture"
#   audio label – "Generated AI music (MusicGen)"
#   title       – "AI picture → music generator"
#   description – "Upload a picture and the AI writes a description, then
#                  generates and plays 10 seconds of music based on it."
demo = gr.Interface(
    fn=process,
    inputs=gr.Image(type="pil"),
    outputs=[
        gr.Text(label="AI가 생성한 그림 설명"),
        gr.Audio(label="생성된 AI 음악 (MusicGen)"),
    ],
    title="🎨 AI 그림→음악 생성기",
    description=(
        "그림을 업로드하면 AI가 설명을 만들고, "
        "설명을 바탕으로 음악을 10초간 생성해 들려줍니다."
    ),
)

if __name__ == "__main__":
    demo.launch()
|