yongyeol commited on
Commit
6748e07
·
verified ·
1 Parent(s): d7b41a8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +39 -94
app.py CHANGED
@@ -1,114 +1,57 @@
1
- import os, sys, types, subprocess, tempfile
2
- import torch, gradio as gr
3
- from transformers import (
4
- VisionEncoderDecoderModel, ViTImageProcessor, AutoTokenizer
5
- )
6
  from PIL import Image
7
 
8
  # ─────────────────────────────────────────────────────────────
9
- # 0. 환경 변수
10
- # ─────────────────────────────────────────────────────────────
11
- os.environ["HF_FORCE_SAFE_SERIALIZATION"] = "1"
12
- os.environ["XFORMERS_FORCE_DISABLE"] = "1" # audiocraft 내부 플래그
13
-
14
  # ─────────────────────────────────────────────────────────────
15
- # 1. xformers 더미 모듈 주입 (GPU 종속 제거)
16
- # ─────────────────────────────────────────────────────────────
17
- dummy = types.ModuleType("xformers")
18
- dummy.__version__ = "0.0.0"
19
- ops = types.ModuleType("xformers.ops")
20
-
21
- def _fake_mea(q, k, v, *_, dropout_p: float = 0.0, **__):
22
- return torch.nn.functional.scaled_dot_product_attention(
23
- q, k, v, dropout_p=dropout_p, is_causal=False
24
- )
25
 
26
- class _FakeLowerTriangularMask: # audiocraft가 존재 여부만 확인
27
- pass
28
-
29
- ops.memory_efficient_attention = _fake_mea
30
- ops.LowerTriangularMask = _FakeLowerTriangularMask
31
- dummy.ops = ops
32
- sys.modules["xformers"] = dummy
33
- sys.modules["xformers.ops"] = ops
34
 
35
  # ─────────────────────────────────────────────────────────────
36
- # 2. (선택) 설치하지 않은 모듈만 안전망으로 스텁 처리 ★
37
- # - 이미 requirements.txt에서 설치한 모듈(librosa, av 등)은
38
- # 스텁 대상에서 제거합니다.
39
  # ─────────────────────────────────────────────────────────────
40
- for name in ("pesq", "pystoi", "soxr"): # 필요시만 남김
41
- if name not in sys.modules:
42
- sys.modules[name] = types.ModuleType(name)
 
43
 
44
- # ─────────────────────────────────────────────────────────────
45
- # 3. audiocraft (MusicGen) 불러오기
46
- # ─────────────────────────────────────────────────────────────
47
- try:
48
- from audiocraft.models import MusicGen
49
- from audiocraft.data.audio import audio_write
50
- except ModuleNotFoundError:
51
- subprocess.check_call([
52
- sys.executable, "-m", "pip", "install",
53
- "git+https://github.com/facebookresearch/audiocraft@main",
54
- "--no-deps", "--use-pep517"
55
- ])
56
- subprocess.check_call([sys.executable, "-m", "pip", "install",
57
- "encodec", "librosa", "av", "torchdiffeq",
58
- "torchmetrics", "num2words"])
59
- from audiocraft.models import MusicGen
60
- from audiocraft.data.audio import audio_write
61
 
62
  # ─────────────────────────────────────────────────────────────
63
- # 4. 이미지 캡셔닝 모델
64
  # ─────────────────────────────────────────────────────────────
65
- # 4. 이미지 캡셔닝 모델 ------------------------------------
66
- caption_model = VisionEncoderDecoderModel.from_pretrained(
67
- "nlpconnect/vit-gpt2-image-captioning",
68
- use_safetensors=True, # 그대로
69
- low_cpu_mem_usage=False, # ← meta 로딩 비활성화
70
- device_map=None # ← Accelerate 자동 분할 끄기
71
- ).eval() # 평가 모드
72
-
73
- feature_extractor = ViTImageProcessor.from_pretrained(
74
- "nlpconnect/vit-gpt2-image-captioning"
75
- )
76
- tokenizer = AutoTokenizer.from_pretrained(
77
- "nlpconnect/vit-gpt2-image-captioning"
78
- )
79
-
80
 
81
- # ─────────────────────────────────────────────────────────────
82
- # 5. MusicGen 모델
83
- # ─────────────────────────────────────────────────────────────
84
- musicgen = MusicGen.get_pretrained("facebook/musicgen-small")
85
- musicgen.set_generation_params(duration=10)
86
 
87
  # ─────────────────────────────────────────────────────────────
88
- # 6. 파이프라인 함수
89
  # ─────────────────────────────────────────────────────────────
90
- def generate_caption(image: Image.Image) -> str:
91
- with torch.no_grad():
92
- pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values
93
- output_ids = caption_model.generate(pixel_values, max_length=50)
94
- return tokenizer.decode(output_ids[0], skip_special_tokens=True)
95
-
96
-
97
-
98
- def generate_music(prompt: str) -> str:
99
- wav = musicgen.generate([prompt]) # batch size = 1
100
- tmpdir = tempfile.mkdtemp()
101
- path = os.path.join(tmpdir, "musicgen.wav")
102
- audio_write(path, wav[0], musicgen.sample_rate, strategy="loudness")
103
- return path
104
-
105
- def process(image: Image.Image):
106
  caption = generate_caption(image)
107
- path = generate_music(f"A cheerful melody inspired by: {caption}")
108
- return caption, path
109
 
110
  # ─────────────────────────────────────────────────────────────
111
- # 7. Gradio UI
112
  # ─────────────────────────────────────────────────────────────
113
  demo = gr.Interface(
114
  fn=process,
@@ -117,8 +60,10 @@ demo = gr.Interface(
117
  gr.Text(label="AI가 생성한 그림 설명"),
118
  gr.Audio(label="생성된 AI 음악 (MusicGen)")
119
  ],
120
- title="🎨 AI 그림-음악 생성기",
121
- description="그림을 업로드하면 AI가 설명을 만들고, 설명을 바탕으로 10초 길이의 음악을 생성해 들려줍니다."
 
 
122
  )
123
 
124
  if __name__ == "__main__":
 
1
+ import os, io, base64, tempfile, requests
2
+ import gradio as gr
 
 
 
3
  from PIL import Image
4
 
5
  # ─────────────────────────────────────────────────────────────
6
+ # 1. 환경 변수 & HF Inference API 설정
 
 
 
 
7
  # ─────────────────────────────────────────────────────────────
8
+ HF_TOKEN = os.getenv("HF_TOKEN")
9
+ if not HF_TOKEN:
10
+ raise RuntimeError("HF_TOKEN 비밀 값이 설정되어 있지 않습니다. Spaces Settings → Secrets에서 등록해 주세요.")
 
 
 
 
 
 
 
11
 
12
+ HEADERS = {"Authorization": f"Bearer {HF_TOKEN}"}
13
+ CAPTION_API = "https://api-inference.huggingface.co/models/Salesforce/blip-image-captioning-base"
14
+ MUSIC_API = "https://api-inference.huggingface.co/models/facebook/musicgen-small"
 
 
 
 
 
15
 
16
  # ─────────────────────────────────────────────────────────────
17
+ # 2. 이미지 캡션 생성 (BLIP-base via API)
 
 
18
  # ─────────────────────────────────────────────────────────────
19
+ def generate_caption(image_pil: Image.Image) -> str:
20
+ buf = io.BytesIO()
21
+ image_pil.save(buf, format="PNG")
22
+ buf.seek(0)
23
 
24
+ # binary upload 방식
25
+ response = requests.post(CAPTION_API, headers=HEADERS, data=buf.getvalue(), timeout=60)
26
+ response.raise_for_status()
27
+ result = response.json()
28
+ # API 응답: [{"generated_text": "..."}]
29
+ return result[0]["generated_text"]
 
 
 
 
 
 
 
 
 
 
 
30
 
31
  # ─────────────────────────────────────────────────────────────
32
+ # 3. MusicGen-small 음악 생성 (10초, via API)
33
  # ─────────────────────────────────────────────────────────────
34
+ def generate_music(prompt: str, duration: int = 10) -> str:
35
+ payload = {"inputs": prompt, "parameters": {"duration": duration}}
36
+ response = requests.post(MUSIC_API, headers=HEADERS, json=payload, timeout=120)
37
+ response.raise_for_status()
 
 
 
 
 
 
 
 
 
 
 
38
 
39
+ # API 응답은 WAV 바이너리
40
+ tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
41
+ tmp.write(response.content)
42
+ tmp.close()
43
+ return tmp.name
44
 
45
  # ─────────────────────────────────────────────────────────────
46
+ # 4. 전체 파이프라인
47
  # ─────────────────────────────────────────────────────────────
48
+ def process(image):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
  caption = generate_caption(image)
50
+ audio = generate_music(f"A cheerful melody inspired by: {caption}")
51
+ return caption, audio
52
 
53
  # ─────────────────────────────────────────────────────────────
54
+ # 5. Gradio 인터페이스
55
  # ─────────────────────────────────────────────────────────────
56
  demo = gr.Interface(
57
  fn=process,
 
60
  gr.Text(label="AI가 생성한 그림 설명"),
61
  gr.Audio(label="생성된 AI 음악 (MusicGen)")
62
  ],
63
+ title="🎨 AI 그림-음악 생성기 (Inference API 버전)",
64
+ description="이미지를 업로드하면 BLIP-base가 설명을 생성하고, 해당 설명으로 MusicGen-small이 10초 음악을 만듭니다.",
65
+ concurrency_count=1, # 메모리 보호용: 동시 1요청
66
+ cache_examples=False
67
  )
68
 
69
  if __name__ == "__main__":