yongyeol committed on
Commit
20017db
·
verified ·
1 Parent(s): 8e74b09

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -21
app.py CHANGED
@@ -1,26 +1,21 @@
1
  import gradio as gr
2
  from transformers import VisionEncoderDecoderModel, ViTImageProcessor, AutoTokenizer
 
 
3
  from PIL import Image
4
  import torch
5
- import requests
6
  import os
7
- from dotenv import load_dotenv
8
-
9
- # ───── 환경 변수 로딩 (토큰 안전하게 가져오기) ─────
10
- load_dotenv()
11
- HF_TOKEN = os.getenv("HF_TOKEN")
12
- if HF_TOKEN is None:
13
- raise ValueError("HF_TOKEN이 .env 파일에 없습니다.")
14
-
15
- # ───── Hugging Face Inference API 설정 ─────
16
- headers = {"Authorization": f"Bearer {HF_TOKEN}"}
17
- MUSICGEN_API = "https://api-inference.huggingface.co/models/facebook/musicgen-small"
18
 
19
  # ───── 이미지 캡셔닝 모델 로딩 ─────
20
  caption_model = VisionEncoderDecoderModel.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
21
  feature_extractor = ViTImageProcessor.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
22
  tokenizer = AutoTokenizer.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
23
 
 
 
 
 
24
  # ───── 이미지 → 설명 문장 생성 ─────
25
  def generate_caption(image):
26
  pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values
@@ -28,21 +23,20 @@ def generate_caption(image):
28
  caption = tokenizer.decode(output_ids[0], skip_special_tokens=True)
29
  return caption
30
 
31
- # ───── 설명 → 음악 생성 (MusicGen API 호출) ─────
32
  def generate_music(prompt):
33
- payload = {"inputs": prompt}
34
- response = requests.post(MUSICGEN_API, headers=headers, json=payload)
35
- if response.status_code == 200:
36
- return response.content
37
- else:
38
- raise Exception(f"MusicGen 오류: {response.status_code}, {response.text}")
39
 
40
  # ───── 전체 파이프라인 연결 ─────
41
  def process(image):
42
  caption = generate_caption(image)
43
  prompt = f"A cheerful melody inspired by: {caption}"
44
- audio = generate_music(prompt)
45
- return caption, (audio, "musicgen_output.wav")
46
 
47
  # ───── Gradio 인터페이스 구성 ─────
48
  demo = gr.Interface(
 
1
  import gradio as gr
2
  from transformers import VisionEncoderDecoderModel, ViTImageProcessor, AutoTokenizer
3
+ from audiocraft.models import MusicGen
4
+ from audiocraft.data.audio import audio_write
5
  from PIL import Image
6
  import torch
 
7
  import os
8
+ import tempfile
 
 
 
 
 
 
 
 
 
 
9
 
10
  # ───── 이미지 캡셔닝 모델 로딩 ─────
11
  caption_model = VisionEncoderDecoderModel.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
12
  feature_extractor = ViTImageProcessor.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
13
  tokenizer = AutoTokenizer.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
14
 
15
+ # ───── MusicGen 모델 로딩 ─────
16
+ musicgen = MusicGen.get_pretrained("facebook/musicgen-small")
17
+ musicgen.set_generation_params(duration=10) # 생성할 음악 길이 (초)
18
+
19
  # ───── 이미지 → 설명 문장 생성 ─────
20
  def generate_caption(image):
21
  pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values
 
23
  caption = tokenizer.decode(output_ids[0], skip_special_tokens=True)
24
  return caption
25
 
26
+ # ───── 설명 → 음악 생성 ─────
27
def generate_music(prompt):
    """Generate a short music clip from a text prompt with the local MusicGen model.

    Runs `musicgen.generate` on a batch of one prompt, writes the result to a
    fresh temp directory, and returns the path of the written WAV file.
    """
    wav = musicgen.generate([prompt])  # batch size 1 -> wav[0] is the clip tensor
    tmp_dir = tempfile.mkdtemp()
    # audio_write() appends the format extension itself, so pass only the stem:
    # passing ".../musicgen_output.wav" would create "musicgen_output.wav.wav"
    # and the returned path would point at a file that does not exist.
    stem = os.path.join(tmp_dir, "musicgen_output")
    # Move the tensor to CPU before writing in case generation ran on GPU.
    audio_write(stem, wav[0].cpu(), musicgen.sample_rate, strategy="loudness")
    return stem + ".wav"
 
33
 
34
  # ───── 전체 파이프라인 연결 ─────
35
def process(image):
    """Full pipeline: caption the image, then turn the caption into music.

    Returns the caption text and the path of the generated audio file.
    """
    caption = generate_caption(image)
    musical_prompt = f"A cheerful melody inspired by: {caption}"
    return caption, generate_music(musical_prompt)
40
 
41
  # ───── Gradio 인터페이스 구성 ─────
42
  demo = gr.Interface(