yongyeol committed on
Commit
343dde8
·
verified ·
1 Parent(s): e3eaf60

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +43 -23
app.py CHANGED
@@ -1,42 +1,62 @@
1
- import os, tempfile, requests
2
  import gradio as gr
3
  from PIL import Image
4
  from transformers import pipeline
5
 
6
- # ────────────────────── 1. 캑셔닝 νŒŒμ΄ν”„λΌμΈ ──────────────────────
 
 
 
 
 
7
  caption_pipe = pipeline(
8
  "image-to-text",
9
- model="Salesforce/blip-image-captioning-base", # tiny λͺ¨λΈλ‘œ λ°”κΎΈλ €λ©΄ μ—¬κΈ°λ§Œ μˆ˜μ •
10
- device=-1, # -1 β†’ CPU, 0 이상 β†’ GPU ID (Spaces CPU라면 -1 μœ μ§€)
 
 
 
 
 
 
 
11
  )
12
 
13
- # ────────────────────── 2. MusicGen(Inf-API) ─────────────────────
14
- HF_TOKEN = os.getenv("HF_TOKEN")
15
- HEADERS = {"Authorization": f"Bearer {HF_TOKEN}"}
16
- MUSIC_API = "https://api-inference.huggingface.co/models/facebook/musicgen-small"
 
 
 
 
 
17
 
18
- def generate_music(prompt: str, duration=10) -> str:
19
- payload = {"inputs": prompt, "parameters": {"duration": duration}}
20
- r = requests.post(MUSIC_API, headers=HEADERS, json=payload, timeout=120)
21
- r.raise_for_status()
22
  tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
23
- tmp.write(r.content)
24
- tmp.close()
25
  return tmp.name
26
 
27
- # ────────────────────── 3. 전체 νŒŒμ΄ν”„λΌμΈ ──────────────────────
28
- def process(image: Image.Image):
29
- caption = caption_pipe(image)[0]["generated_text"]
30
- music = generate_music(f"A cheerful melody inspired by: {caption}")
31
- return caption, music
 
 
32
 
33
- # ────────────────────── 4. Gradio UI ────────────────────────────
 
 
34
  demo = gr.Interface(
35
  fn=process,
36
  inputs=gr.Image(type="pil"),
37
- outputs=[gr.Text(), gr.Audio()],
38
- title="🎨 둜컬 BLIP-base + MusicGen-API",
39
- description="CPUμ—μ„œ BLIP-base둜 μΊ‘μ…˜μ„ μƒμ„±ν•˜κ³ , ν•΄λ‹Ή μΊ‘μ…˜μ„ MusicGen-small Inference API둜 전달해 10초 μŒμ•…μ„ λ§Œλ“­λ‹ˆλ‹€."
 
 
 
 
40
  ).queue()
41
 
42
  if __name__ == "__main__":
 
1
+ import os, tempfile, soundfile as sf
2
  import gradio as gr
3
  from PIL import Image
4
  from transformers import pipeline
5
 
6
+ # ────────────────────────────────────────────────
7
+ # 1. νŒŒμ΄ν”„λΌμΈ λ‘œλ“œ (CPU: device=-1)
8
+ # ────────────────────────────────────────────────
9
# Model checkpoints used by the two pipelines below.
CAPTION_ID = "Salesforce/blip-image-captioning-base"  # smaller footprint: blip-image-captioning
MUSIC_ID = "facebook/musicgen-melody"  # smaller footprint: musicgen-small

# Target length of a generated clip, in seconds.
MUSIC_SECONDS = 10
# MusicGen's length is set by a token budget rather than seconds.
# NOTE(review): ~50 tokens per second is taken from the MusicGen docs
# (256 tokens is about 5 s) -- confirm for the checkpoint in MUSIC_ID.
MUSICGEN_TOKENS_PER_SECOND = 50

# Image -> caption pipeline, pinned to CPU via device=-1.
caption_pipe = pipeline(
    "image-to-text",
    model=CAPTION_ID,
    device=-1,
)

# Text -> audio pipeline, pinned to CPU via device=-1.
# The previous `{"duration": 10}` is not a `generate()` argument, so the
# intended clip length is expressed as `max_new_tokens` instead.
# NOTE(review): some transformers versions replace construction-time
# generate_kwargs with the per-call ones -- callers that need a guaranteed
# length should also pass `max_new_tokens` when invoking the pipeline.
music_pipe = pipeline(
    "text-to-audio",
    model=MUSIC_ID,
    device=-1,
    generate_kwargs={"max_new_tokens": MUSIC_SECONDS * MUSICGEN_TOKENS_PER_SECOND},
)
24
 
25
+ # ────────────────────────────────────────────────
26
+ # 2. μœ ν‹Έ ν•¨μˆ˜
27
+ # ────────────────────────────────────────────────
28
def generate_caption(img: Image.Image) -> str:
    """Return the text that ``caption_pipe`` generates for *img*."""
    predictions = caption_pipe(img)
    first = predictions[0]  # only the first prediction is used
    return first["generated_text"]
30
+
31
def generate_music(prompt: str, duration: int = 10) -> str:
    """Generate a music clip for *prompt* and return the path of a WAV file.

    Args:
        prompt: Text description handed to ``music_pipe``.
        duration: Approximate clip length in seconds.

    Returns:
        Path of a temporary ``.wav`` file. It is created with
        ``delete=False``, so removing it is left to the caller.
    """
    # MusicGen's length is controlled by a token budget, not seconds.
    # NOTE(review): ~50 tokens per second comes from the MusicGen docs
    # (256 tokens is about 5 s) -- confirm for the loaded checkpoint.
    max_new_tokens = int(duration * 50)

    result = music_pipe(
        prompt,
        forward_params={"do_sample": True, "max_new_tokens": max_new_tokens},
    )
    # A single prompt yields one dict; a batched call yields a list of dicts.
    # Accept both instead of indexing [0] unconditionally.
    if isinstance(result, (list, tuple)):
        result = result[0]
    audio, sr = result["audio"], result["sampling_rate"]

    # Drop singleton batch/channel axes so the writer receives (frames,)
    # rather than (1, frames).
    # NOTE(review): assumes `audio` is an array with .squeeze() and mono output.
    audio = audio.squeeze()

    # Reserve a unique path and close our own handle before soundfile
    # reopens the file by name; the original left the handle open.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
        wav_path = tmp.name
    sf.write(wav_path, audio, sr)
    return wav_path
38
 
39
+ # ────────────────────────────────────────────────
40
+ # 3. 전체 νŒŒμ΄ν”„λΌμΈ
41
+ # ────────────────────────────────────────────────
42
def process(image):
    """Caption *image*, then generate music from that caption.

    Returns a ``(caption, audio)`` pair: the caption string and whatever
    ``generate_music`` returns for the derived prompt.
    """
    description = generate_caption(image)
    music_prompt = f"A cheerful melody inspired by: {description}"
    clip = generate_music(music_prompt)
    return description, clip
46
 
47
+ # ────────────────────────────────────────────────
48
+ # 4. Gradio UI
49
+ # ────────────────────────────────────────────────
50
# Gradio front end: one PIL image in, caption text and generated audio out.
# The user-facing strings were mis-encoded (UTF-8 bytes decoded through a
# single-byte code page); they are restored to the intended Korean text.
demo = gr.Interface(
    fn=process,
    inputs=gr.Image(type="pil"),
    outputs=[
        # "AI-generated picture description"
        gr.Text(label="AI가 생성한 그림 설명"),
        # "Generated AI music (MusicGen)"
        gr.Audio(label="생성된 AI 음악 (MusicGen)"),
    ],
    # "Local BLIP-base + MusicGen-melody"
    title="🎨 로컬 BLIP-base + MusicGen-melody",
    # "Upload an image and BLIP-base writes a description; MusicGen-melody
    # then turns that description into 10 seconds of music."
    description="이미지를 업로드하면 BLIP-base가 설명을 생성하고, "
    "그 설명으로 MusicGen-melody가 10초 음악을 만듭니다.",
).queue()
61
 
62
  if __name__ == "__main__":