# NOTE(review): the three lines "Spaces:" / "Runtime error" / "Runtime error"
# were a Hugging Face Spaces error-page header accidentally pasted into this
# file; they are not valid Python and are kept only as this comment.
# Standard library
import os
import subprocess
import sys
import tempfile
import types

# Third-party
import gradio as gr
import torch
from PIL import Image
from transformers import (
    AutoTokenizer,
    ViTImageProcessor,
    VisionEncoderDecoderModel,
)
# ─────────────────────────────────────────────────────────────
# 0. Environment variables
# ─────────────────────────────────────────────────────────────
os.environ["HF_FORCE_SAFE_SERIALIZATION"] = "1"   # prefer safetensors serialization
os.environ["XFORMERS_FORCE_DISABLE"] = "1"        # audiocraft-internal flag: skip xformers
# ─────────────────────────────────────────────────────────────
# 1. Inject a dummy xformers module (removes the GPU dependency)
# ─────────────────────────────────────────────────────────────
dummy = types.ModuleType("xformers")
dummy.__version__ = "0.0.0"
ops = types.ModuleType("xformers.ops")

def _fake_mea(q, k, v, *_, dropout_p: float = 0.0, **__):
    """Stand-in for ``xformers.ops.memory_efficient_attention``.

    Delegates to the standard PyTorch 2.x scaled-dot-product attention,
    which also works on CPU.  Extra positional/keyword arguments
    (attn_bias, scale, ...) are accepted and ignored.
    """
    return torch.nn.functional.scaled_dot_product_attention(
        q, k, v, dropout_p=dropout_p, is_causal=False
    )

class _FakeLowerTriangularMask:
    """audiocraft only checks that this type exists, so an empty class suffices."""
    pass

ops.memory_efficient_attention = _fake_mea
ops.LowerTriangularMask = _FakeLowerTriangularMask
dummy.ops = ops
sys.modules["xformers"] = dummy
sys.modules["xformers.ops"] = ops
# ─────────────────────────────────────────────────────────────
# 2. Safety-net stubs for optional audio dependencies
#    (requirements.txt installs them, but if one is missing the
#     import-time existence checks still pass at runtime)
# ─────────────────────────────────────────────────────────────
for name in (
    "av", "librosa", "torchdiffeq", "torchmetrics",
    "pesq", "pystoi", "soxr",
):
    if name not in sys.modules:
        sys.modules[name] = types.ModuleType(name)
# ─────────────────────────────────────────────────────────────
# 3. Load audiocraft (MusicGen)
# ─────────────────────────────────────────────────────────────
try:
    from audiocraft.models import MusicGen
    from audiocraft.data.audio import audio_write
except ModuleNotFoundError:
    # Not installed (e.g. a local run): install without its pinned deps.
    subprocess.check_call([
        sys.executable, "-m", "pip", "install",
        "git+https://github.com/facebookresearch/audiocraft@main",
        "--no-deps", "--use-pep517",
    ])
    # Immediately install the minimal runtime deps
    # (the stubs above cover most of them, but be safe).
    subprocess.check_call([sys.executable, "-m", "pip", "install",
                           "encodec", "torchdiffeq", "torchmetrics",
                           "librosa", "soxr", "av"])
    from audiocraft.models import MusicGen
    from audiocraft.data.audio import audio_write
# ─────────────────────────────────────────────────────────────
# 4. Image-captioning model (ViT encoder + GPT-2 decoder)
# ─────────────────────────────────────────────────────────────
caption_model = VisionEncoderDecoderModel.from_pretrained(
    "nlpconnect/vit-gpt2-image-captioning",
    use_safetensors=True,       # avoid pickle-based .bin weights
    low_cpu_mem_usage=True,     # stream weights to keep peak RAM down
)
feature_extractor = ViTImageProcessor.from_pretrained(
    "nlpconnect/vit-gpt2-image-captioning"
)
tokenizer = AutoTokenizer.from_pretrained(
    "nlpconnect/vit-gpt2-image-captioning"
)
# ─────────────────────────────────────────────────────────────
# 5. MusicGen model (CPU only)
# ─────────────────────────────────────────────────────────────
musicgen = MusicGen.get_pretrained("facebook/musicgen-small")
musicgen.set_generation_params(duration=10)  # generate 10-second clips
# ─────────────────────────────────────────────────────────────
# 6. Pipeline functions
# ─────────────────────────────────────────────────────────────
def generate_caption(image: Image.Image) -> str:
    """Return a text caption for *image* using the ViT-GPT2 captioning model.

    Args:
        image: a PIL image (any mode the processor accepts).

    Returns:
        The decoded caption string, special tokens stripped.
    """
    pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values
    ids = caption_model.generate(pixel_values, max_length=50)
    return tokenizer.decode(ids[0], skip_special_tokens=True)
def generate_music(prompt: str) -> str:
    """Generate a short music clip for *prompt* and return the WAV file path.

    Bug fix: ``audio_write`` appends the ``.wav`` suffix itself, so the stem
    passed to it must not already carry one — the original code passed
    ``musicgen.wav`` which wrote ``musicgen.wav.wav`` on disk while returning
    a path that did not exist.
    """
    wav = musicgen.generate([prompt])  # batch size = 1
    stem = os.path.join(tempfile.mkdtemp(), "musicgen")
    # .cpu() is a no-op on CPU and keeps audio_write happy if a GPU sneaks in
    audio_write(stem, wav[0].cpu(), musicgen.sample_rate, strategy="loudness")
    return stem + ".wav"
def process(image: Image.Image):
    """Gradio handler: caption the uploaded image, then score it with music.

    Returns:
        (caption, audio_path) — the generated description and the path of the
        generated WAV file, matching the two Interface outputs.
    """
    caption = generate_caption(image)
    audio_path = generate_music(f"A cheerful melody inspired by: {caption}")
    return caption, audio_path
# ─────────────────────────────────────────────────────────────
# 7. Gradio UI
# ─────────────────────────────────────────────────────────────
# NOTE(review): the Korean UI strings below were mojibake in the original
# file; they are reconstructed from context — confirm the exact wording.
demo = gr.Interface(
    fn=process,
    inputs=gr.Image(type="pil"),
    outputs=[
        gr.Text(label="AI가 생성한 그림 설명"),
        gr.Audio(label="생성된 AI 음악 (MusicGen)"),
    ],
    title="🎨 AI 그림-음악 생성기",
    description="그림을 업로드하면 AI가 설명을 만들고, 설명을 바탕으로 10초 길이의 음악을 생성해 들려줍니다.",
)

if __name__ == "__main__":
    demo.launch()