yongyeol committed on
Commit
0725a88
·
verified ·
1 Parent(s): 8604551

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +35 -45
app.py CHANGED
@@ -1,88 +1,78 @@
1
- import os
2
- import sys
3
- import types
4
- import subprocess
5
- import tempfile
6
- import torch
7
- import gradio as gr
8
- from transformers import VisionEncoderDecoderModel, ViTImageProcessor, AutoTokenizer
9
  from PIL import Image
10
 
11
  # ── 환경 변수 ────────────────────────────────────────────────
12
  os.environ["HF_FORCE_SAFE_SERIALIZATION"] = "1"
13
- os.environ["XFORMERS_FORCE_DISABLE"] = "1" # 실제 xformers 비활성화
14
 
15
- # ── xformers 더미 모듈 삽입 ─────────────────────────────────
16
  dummy = types.ModuleType("xformers")
17
  dummy.__version__ = "0.0.0"
18
-
19
- # 하위 모듈 xformers.ops
20
  ops = types.ModuleType("xformers.ops")
21
 
22
- def _fake_memory_efficient_attention(q, k, v, *_, dropout_p: float = 0.0, **__):
23
- """
24
- xformers.memory_efficient_attention 대체 구현.
25
- PyTorch 2.x 기본 S-DPA로 처리해 속도는 느리지만 CPU에서도 동작합니다.
26
- """
27
  return torch.nn.functional.scaled_dot_product_attention(
28
  q, k, v, dropout_p=dropout_p, is_causal=False
29
  )
 
30
 
31
- class _FakeLowerTriangularMask: # audiocraft 내부 타입 체크용 더미
32
- pass
33
-
34
- ops.memory_efficient_attention = _fake_memory_efficient_attention
35
  ops.LowerTriangularMask = _FakeLowerTriangularMask
36
-
37
  dummy.ops = ops
38
  sys.modules["xformers"] = dummy
39
  sys.modules["xformers.ops"] = ops
40
  # ────────────────────────────────────────────────────────────
41
 
42
- # ── audiocraft 동적 설치 ─────────────────────────────────────
43
  try:
44
  from audiocraft.models import MusicGen
45
- except ModuleNotFoundError:
 
46
  subprocess.check_call([
47
  sys.executable, "-m", "pip", "install",
48
  "git+https://github.com/facebookresearch/audiocraft@main",
49
- "--use-pep517"
50
  ])
51
  from audiocraft.models import MusicGen
52
-
53
- from audiocraft.data.audio import audio_write
54
 
55
  # ── 이미지 캡셔닝 모델 ─────────────────────────────────────
56
  caption_model = VisionEncoderDecoderModel.from_pretrained(
57
  "nlpconnect/vit-gpt2-image-captioning",
58
- use_safetensors=True,
59
- low_cpu_mem_usage=True
 
 
 
 
 
60
  )
61
- feature_extractor = ViTImageProcessor.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
62
- tokenizer = AutoTokenizer.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
63
 
64
- # ── MusicGen ───────────────────────────────────────────────
65
  musicgen = MusicGen.get_pretrained("facebook/musicgen-small")
66
- musicgen.set_generation_params(duration=10) # 10초 음악
67
 
68
- # ── 유틸 함수들 ─────────────────────────────────────────────
69
  def generate_caption(image: Image.Image) -> str:
70
  pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values
71
- output_ids = caption_model.generate(pixel_values, max_length=50)
72
- return tokenizer.decode(output_ids[0], skip_special_tokens=True)
73
 
74
  def generate_music(prompt: str) -> str:
75
- wav = musicgen.generate([prompt]) # batch size = 1
76
- tmp_dir = tempfile.mkdtemp()
77
- audio_path = os.path.join(tmp_dir, "musicgen_output.wav")
78
- audio_write(audio_path, wav[0], musicgen.sample_rate, strategy="loudness")
79
- return audio_path
80
 
81
  def process(image: Image.Image):
82
  caption = generate_caption(image)
83
- prompt = f"A cheerful melody inspired by: {caption}"
84
- audio_path = generate_music(prompt)
85
- return caption, audio_path
86
 
87
  # ── Gradio UI ──────────────────────────────────────────────
88
  demo = gr.Interface(
@@ -93,7 +83,7 @@ demo = gr.Interface(
93
  gr.Audio(label="생성된 AI 음악 (MusicGen)")
94
  ],
95
  title="🎨 AI 그림‑음악 생성기",
96
- description="그림을 업로드하면 AI가 설명을 만들고, 설명을 바탕으로 음악을 생성해 들려줍니다."
97
  )
98
 
99
  if __name__ == "__main__":
 
1
+ import os, sys, types, subprocess, tempfile
2
+ import torch, gradio as gr
3
+ from transformers import (
4
+ VisionEncoderDecoderModel, ViTImageProcessor, AutoTokenizer
5
+ )
 
 
 
6
  from PIL import Image
7
 
8
  # ── 환경 변수 ────────────────────────────────────────────────
9
  os.environ["HF_FORCE_SAFE_SERIALIZATION"] = "1"
10
+ os.environ["XFORMERS_FORCE_DISABLE"] = "1" # audiocraft 내부 플래그
11
 
12
+ # ── xformers 더미 모듈 ───────────────────────────────────────
13
  dummy = types.ModuleType("xformers")
14
  dummy.__version__ = "0.0.0"
 
 
15
  ops = types.ModuleType("xformers.ops")
16
 
17
+ def _fake_mem_eff_attn(q, k, v, *_, dropout_p: float = 0.0, **__):
 
 
 
 
18
  return torch.nn.functional.scaled_dot_product_attention(
19
  q, k, v, dropout_p=dropout_p, is_causal=False
20
  )
21
+ class _FakeLowerTriangularMask: pass
22
 
23
+ ops.memory_efficient_attention = _fake_mem_eff_attn
 
 
 
24
  ops.LowerTriangularMask = _FakeLowerTriangularMask
 
25
  dummy.ops = ops
26
  sys.modules["xformers"] = dummy
27
  sys.modules["xformers.ops"] = ops
28
  # ────────────────────────────────────────────────────────────
29
 
30
# ── Load audiocraft (normally pre-installed via postInstall) ─
try:
    from audiocraft.models import MusicGen
    from audiocraft.data.audio import audio_write
except ModuleNotFoundError:
    # Fallback for ad-hoc local runs: install straight from GitHub,
    # then retry the imports.
    install_cmd = [
        sys.executable, "-m", "pip", "install",
        "git+https://github.com/facebookresearch/audiocraft@main",
        "--no-deps", "--use-pep517",
    ]
    subprocess.check_call(install_cmd)
    from audiocraft.models import MusicGen
    from audiocraft.data.audio import audio_write
 
42
 
43
# ── Image-captioning model (ViT encoder + GPT-2 decoder) ────
# Single source of truth for the checkpoint id: the same repo supplies the
# model, the image processor, and the tokenizer.
CAPTION_MODEL_ID = "nlpconnect/vit-gpt2-image-captioning"

caption_model = VisionEncoderDecoderModel.from_pretrained(
    CAPTION_MODEL_ID,
    use_safetensors=True,    # avoid pickle-based weight files
    low_cpu_mem_usage=True,  # keep peak RAM low while loading
)
feature_extractor = ViTImageProcessor.from_pretrained(CAPTION_MODEL_ID)
tokenizer = AutoTokenizer.from_pretrained(CAPTION_MODEL_ID)

# ── MusicGen (text-to-music) model ──────────────────────────
musicgen = MusicGen.get_pretrained("facebook/musicgen-small")
musicgen.set_generation_params(duration=10)  # generate 10-second clips
58
 
59
# ── Pipeline helpers ────────────────────────────────────────
def generate_caption(image: Image.Image) -> str:
    """Return an English caption for *image* via the ViT-GPT2 model."""
    pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values
    # Inference only — skip autograd bookkeeping to save time and memory.
    with torch.no_grad():
        ids = caption_model.generate(pixel_values, max_length=50)
    return tokenizer.decode(ids[0], skip_special_tokens=True)
64
 
65
def generate_music(prompt: str) -> str:
    """Generate a short music clip for *prompt*; return the WAV file path."""
    wav = musicgen.generate([prompt])  # batch of one prompt
    # audio_write takes a *stem* and appends the format suffix itself;
    # passing "musicgen.wav" would have produced "musicgen.wav.wav" while
    # the old code returned the non-existent "musicgen.wav" path.  Use the
    # path audio_write actually reports.
    stem = os.path.join(tempfile.mkdtemp(), "musicgen")
    out_path = audio_write(stem, wav[0], musicgen.sample_rate, strategy="loudness")
    return str(out_path)
71
 
72
def process(image: Image.Image):
    """Gradio callback: caption the uploaded image, then turn that caption
    into a music prompt and return (caption, audio file path)."""
    caption = generate_caption(image)
    prompt = f"A cheerful melody inspired by: {caption}"
    return caption, generate_music(prompt)
 
76
 
77
  # ── Gradio UI ──────────────────────────────────────────────
78
  demo = gr.Interface(
 
83
  gr.Audio(label="생성된 AI 음악 (MusicGen)")
84
  ],
85
  title="🎨 AI 그림‑음악 생성기",
86
+ description="그림을 업로드하면 AI가 설명을 만들고, 설명을 바탕으로 음악을 10초간 생성해 들려줍니다."
87
  )
88
 
89
  if __name__ == "__main__":