# Hugging Face Space: image-to-music generator (caption an image, then
# synthesize a short music clip from the caption with MusicGen).
import os
import subprocess
import sys
import tempfile
import types

import gradio as gr
import torch
from PIL import Image
from transformers import (
    AutoTokenizer,
    ViTImageProcessor,
    VisionEncoderDecoderModel,
)

# ─────────────────────────────────────────────────────────────
# 0. Environment variables
# ─────────────────────────────────────────────────────────────
os.environ["HF_FORCE_SAFE_SERIALIZATION"] = "1"
os.environ["XFORMERS_FORCE_DISABLE"] = "1"  # internal audiocraft flag
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
# 1. xformers λλ―Έ λͺ¨λ μ£Όμ (GPU μ’ μ μ κ±°) | |
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
dummy = types.ModuleType("xformers") | |
dummy.__version__ = "0.0.0" | |
ops = types.ModuleType("xformers.ops") | |
def _fake_mea(q, k, v, *_, dropout_p: float = 0.0, **__): | |
return torch.nn.functional.scaled_dot_product_attention( | |
q, k, v, dropout_p=dropout_p, is_causal=False | |
) | |
class _FakeLowerTriangularMask: # audiocraftκ° μ‘΄μ¬ μ¬λΆλ§ νμΈ | |
pass | |
ops.memory_efficient_attention = _fake_mea | |
ops.LowerTriangularMask = _FakeLowerTriangularMask | |
dummy.ops = ops | |
sys.modules["xformers"] = dummy | |
sys.modules["xformers.ops"] = ops | |
# ─────────────────────────────────────────────────────────────
# 2. (Optional) stub out only the modules that are NOT installed.
#    Modules already installed via requirements.txt (librosa, av, ...)
#    must keep their real implementations, so only missing names are
#    replaced with empty placeholder modules.
# ─────────────────────────────────────────────────────────────
_OPTIONAL_STUBS = ("pesq", "pystoi", "soxr")  # keep real module if present
for _stub in _OPTIONAL_STUBS:
    sys.modules.setdefault(_stub, types.ModuleType(_stub))
# ─────────────────────────────────────────────────────────────
# 3. Load audiocraft (MusicGen), installing it on first run if absent
# ─────────────────────────────────────────────────────────────
try:
    from audiocraft.models import MusicGen
    from audiocraft.data.audio import audio_write
except ImportError:
    # Catch ImportError (not just its subclass ModuleNotFoundError): a
    # partially-installed dependency can raise a plain ImportError, which
    # previously crashed instead of triggering the install path.
    # Install audiocraft without its heavy pinned dependencies
    # (--no-deps), then the runtime pieces it actually needs.
    subprocess.check_call([
        sys.executable, "-m", "pip", "install",
        "git+https://github.com/facebookresearch/audiocraft@main",
        "--no-deps", "--use-pep517",
    ])
    subprocess.check_call([
        sys.executable, "-m", "pip", "install",
        "encodec", "librosa", "av", "torchdiffeq",
        "torchmetrics", "num2words",
    ])
    from audiocraft.models import MusicGen
    from audiocraft.data.audio import audio_write
# ─────────────────────────────────────────────────────────────
# 4. Image-captioning model
# ─────────────────────────────────────────────────────────────
_CAPTION_CHECKPOINT = "nlpconnect/vit-gpt2-image-captioning"

caption_model = VisionEncoderDecoderModel.from_pretrained(
    _CAPTION_CHECKPOINT,
    use_safetensors=True,     # load safetensors weights
    low_cpu_mem_usage=False,  # disable meta-device ("lazy") loading
    device_map=None,          # keep Accelerate from auto-sharding
).eval()                      # evaluation mode
feature_extractor = ViTImageProcessor.from_pretrained(_CAPTION_CHECKPOINT)
tokenizer = AutoTokenizer.from_pretrained(_CAPTION_CHECKPOINT)

# ─────────────────────────────────────────────────────────────
# 5. MusicGen model
# ─────────────────────────────────────────────────────────────
musicgen = MusicGen.get_pretrained("facebook/musicgen-small")
musicgen.set_generation_params(duration=10)  # 10-second clips
# ─────────────────────────────────────────────────────────────
# 6. Pipeline functions
# ─────────────────────────────────────────────────────────────
def generate_caption(image: Image.Image) -> str:
    """Return a text description of *image* from the ViT-GPT2 captioner."""
    inputs = feature_extractor(images=image, return_tensors="pt")
    with torch.no_grad():
        generated_ids = caption_model.generate(inputs.pixel_values, max_length=50)
    return tokenizer.decode(generated_ids[0], skip_special_tokens=True)
def generate_music(prompt: str) -> str:
    """Generate a music clip for *prompt* and return the written WAV path.

    Note: audiocraft's ``audio_write()`` treats its first argument as a
    *stem* and appends the format extension itself.  The previous version
    passed ``"musicgen.wav"``, so the file on disk was
    ``musicgen.wav.wav`` while the returned path pointed at a
    non-existent ``musicgen.wav`` — breaking the Gradio audio output.
    We now pass an extension-less stem and return the path that
    ``audio_write`` actually wrote.
    """
    wav = musicgen.generate([prompt])  # batch of one prompt
    stem = os.path.join(tempfile.mkdtemp(), "musicgen")
    # .cpu(): audio_write saves via torchaudio, which needs a CPU tensor.
    out_path = audio_write(
        stem, wav[0].cpu(), musicgen.sample_rate, strategy="loudness"
    )
    return str(out_path)
def process(image: Image.Image):
    """Gradio handler: caption the uploaded image, then sonify the caption."""
    description = generate_caption(image)
    audio_path = generate_music(f"A cheerful melody inspired by: {description}")
    return description, audio_path
# ─────────────────────────────────────────────────────────────
# 7. Gradio UI
# ─────────────────────────────────────────────────────────────
_OUTPUT_COMPONENTS = [
    gr.Text(label="AIκ° μμ±ν κ·Έλ¦Ό μ€λͺ "),
    gr.Audio(label="μμ±λ AI μμ (MusicGen)"),
]

demo = gr.Interface(
    fn=process,
    inputs=gr.Image(type="pil"),
    outputs=_OUTPUT_COMPONENTS,
    title="π¨ AI κ·Έλ¦Ό-μμ μμ±κΈ°",
    description="κ·Έλ¦Όμ μ λ‘λνλ©΄ AIκ° μ€λͺ μ λ§λ€κ³ , μ€λͺ μ λ°νμΌλ‘ 10μ΄ κΈΈμ΄μ μμ μ μμ±ν΄ λ€λ €μ€λλ€.",
)

if __name__ == "__main__":
    demo.launch()