yongyeol commited on
Commit
78ea8dc
·
verified ·
1 Parent(s): 159859d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +39 -21
app.py CHANGED
@@ -3,36 +3,56 @@ import sys
3
  import types
4
  import subprocess
5
  import tempfile
 
 
 
 
6
 
7
- # ── 환경 변수 설정 ──────────────────────────────────────────────
8
  os.environ["HF_FORCE_SAFE_SERIALIZATION"] = "1"
9
- os.environ["XFORMERS_FORCE_DISABLE"] = "1" # xformers 비활성화
10
 
11
- # ── ✨ xformers 더미 모듈 삽입 ──────────────────────────────────
12
  dummy = types.ModuleType("xformers")
13
- dummy.ops = types.ModuleType("xformers.ops") # audiocraft가 ops 하위모듈도 찾음
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  sys.modules["xformers"] = dummy
15
- sys.modules["xformers.ops"] = dummy.ops
16
- # ───────────────────────────────────────────────────────────────
17
 
18
- # ── audiocraft 동적 설치 ───────────────────────────────────────
19
  try:
20
  from audiocraft.models import MusicGen
21
  except ModuleNotFoundError:
22
  subprocess.check_call([
23
  sys.executable, "-m", "pip", "install",
24
  "git+https://github.com/facebookresearch/audiocraft@main",
25
- "--use-pep517" # 의존성 포함 설치
26
  ])
27
  from audiocraft.models import MusicGen
28
 
29
- import gradio as gr
30
- from transformers import VisionEncoderDecoderModel, ViTImageProcessor, AutoTokenizer
31
  from audiocraft.data.audio import audio_write
32
- from PIL import Image
33
- import torch
34
 
35
- # ───── 이미지 캡셔닝 모델 로딩 ─────────────────────────────────
36
  caption_model = VisionEncoderDecoderModel.from_pretrained(
37
  "nlpconnect/vit-gpt2-image-captioning",
38
  use_safetensors=True,
@@ -41,32 +61,30 @@ caption_model = VisionEncoderDecoderModel.from_pretrained(
41
  feature_extractor = ViTImageProcessor.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
42
  tokenizer = AutoTokenizer.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
43
 
44
- # ───── MusicGen 모델 로딩 ─────────────────────────────────────
45
  musicgen = MusicGen.get_pretrained("facebook/musicgen-small")
46
- musicgen.set_generation_params(duration=10) # 생성 음악 길이(초)
47
 
48
- # ───── 이미지 설명 문장 생성 함수 ────────────────────────────
49
  def generate_caption(image: Image.Image) -> str:
50
  pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values
51
  output_ids = caption_model.generate(pixel_values, max_length=50)
52
  return tokenizer.decode(output_ids[0], skip_special_tokens=True)
53
 
54
- # ───── 설명 → 음악 생성 함수 ──────────────────────────────────
55
  def generate_music(prompt: str) -> str:
56
- wav = musicgen.generate([prompt]) # batch size = 1
57
  tmp_dir = tempfile.mkdtemp()
58
  audio_path = os.path.join(tmp_dir, "musicgen_output.wav")
59
  audio_write(audio_path, wav[0], musicgen.sample_rate, strategy="loudness")
60
  return audio_path
61
 
62
- # ───── 전체 파이프라인 ────────────────────────────────────────
63
  def process(image: Image.Image):
64
  caption = generate_caption(image)
65
  prompt = f"A cheerful melody inspired by: {caption}"
66
  audio_path = generate_music(prompt)
67
  return caption, audio_path
68
 
69
- # ───── Gradio 인터페이스 ─────────────────────────────────────
70
  demo = gr.Interface(
71
  fn=process,
72
  inputs=gr.Image(type="pil"),
@@ -74,7 +92,7 @@ demo = gr.Interface(
74
  gr.Text(label="AI가 생성한 그림 설명"),
75
  gr.Audio(label="생성된 AI 음악 (MusicGen)")
76
  ],
77
- title="🎨 AI 그림-음악 생성기",
78
  description="그림을 업로드하면 AI가 설명을 만들고, 설명을 바탕으로 음악을 생성해 들려줍니다."
79
  )
80
 
 
3
  import types
4
  import subprocess
5
  import tempfile
6
+ import torch
7
+ import gradio as gr
8
+ from transformers import VisionEncoderDecoderModel, ViTImageProcessor, AutoTokenizer
9
+ from PIL import Image
10
 
11
+ # ── 환경 변수 ────────────────────────────────────────────────
12
  os.environ["HF_FORCE_SAFE_SERIALIZATION"] = "1"
13
+ os.environ["XFORMERS_FORCE_DISABLE"] = "1" # 실제 xformers 비활성화
14
 
15
+ # ── ✨ xformers 더미 모듈 삽입 ─────────────────────────────────
16
  dummy = types.ModuleType("xformers")
17
+ dummy.__version__ = "0.0.0"
18
+
19
+ # 하위 모듈 xformers.ops
20
+ ops = types.ModuleType("xformers.ops")
21
+
22
+ def _fake_memory_efficient_attention(q, k, v, *_, dropout_p: float = 0.0, **__):
23
+ """
24
+ xformers.memory_efficient_attention 대체 구현.
25
+ PyTorch 2.x 기본 S-DPA로 처리해 속도는 느리지만 CPU에서도 동작합니다.
26
+ """
27
+ return torch.nn.functional.scaled_dot_product_attention(
28
+ q, k, v, dropout_p=dropout_p, is_causal=False
29
+ )
30
+
31
+ class _FakeLowerTriangularMask: # audiocraft 내부 타입 체크용 더미
32
+ pass
33
+
34
+ ops.memory_efficient_attention = _fake_memory_efficient_attention
35
+ ops.LowerTriangularMask = _FakeLowerTriangularMask
36
+
37
+ dummy.ops = ops
38
  sys.modules["xformers"] = dummy
39
+ sys.modules["xformers.ops"] = ops
40
+ # ────────────────────────────────────────────────────────────
41
 
42
+ # ── audiocraft 동적 설치 ─────────────────────────────────────
43
  try:
44
  from audiocraft.models import MusicGen
45
  except ModuleNotFoundError:
46
  subprocess.check_call([
47
  sys.executable, "-m", "pip", "install",
48
  "git+https://github.com/facebookresearch/audiocraft@main",
49
+ "--use-pep517"
50
  ])
51
  from audiocraft.models import MusicGen
52
 
 
 
53
  from audiocraft.data.audio import audio_write
 
 
54
 
55
+ # ── 이미지 캡셔닝 모델 ─────────────────────────────────────
56
  caption_model = VisionEncoderDecoderModel.from_pretrained(
57
  "nlpconnect/vit-gpt2-image-captioning",
58
  use_safetensors=True,
 
61
  feature_extractor = ViTImageProcessor.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
62
  tokenizer = AutoTokenizer.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
63
 
64
+ # ── MusicGen ───────────────────────────────────────────────
65
  musicgen = MusicGen.get_pretrained("facebook/musicgen-small")
66
+ musicgen.set_generation_params(duration=10) # 10초 음악
67
 
68
+ # ── 유틸 함수들 ─────────────────────────────────────────────
69
  def generate_caption(image: Image.Image) -> str:
70
  pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values
71
  output_ids = caption_model.generate(pixel_values, max_length=50)
72
  return tokenizer.decode(output_ids[0], skip_special_tokens=True)
73
 
 
74
  def generate_music(prompt: str) -> str:
75
+ wav = musicgen.generate([prompt]) # batch size = 1
76
  tmp_dir = tempfile.mkdtemp()
77
  audio_path = os.path.join(tmp_dir, "musicgen_output.wav")
78
  audio_write(audio_path, wav[0], musicgen.sample_rate, strategy="loudness")
79
  return audio_path
80
 
 
81
  def process(image: Image.Image):
82
  caption = generate_caption(image)
83
  prompt = f"A cheerful melody inspired by: {caption}"
84
  audio_path = generate_music(prompt)
85
  return caption, audio_path
86
 
87
+ # ── Gradio UI ──────────────────────────────────────────────
88
  demo = gr.Interface(
89
  fn=process,
90
  inputs=gr.Image(type="pil"),
 
92
  gr.Text(label="AI가 생성한 그림 설명"),
93
  gr.Audio(label="생성된 AI 음악 (MusicGen)")
94
  ],
95
+ title="🎨 AI 그림‑음악 생성기",
96
  description="그림을 업로드하면 AI가 설명을 만들고, 설명을 바탕으로 음악을 생성해 들려줍니다."
97
  )
98