yongyeol committed on
Commit 8e74b09 · verified · 1 Parent(s): a09b053

Update app.py

Files changed (1):
app.py  +32 -17
app.py CHANGED
@@ -4,42 +4,57 @@ from PIL import Image
 import torch
 import requests
 import os
+from dotenv import load_dotenv
 
-# Load caption model
+# ───── Load environment variables (fetch the token safely) ─────
+load_dotenv()
+HF_TOKEN = os.getenv("HF_TOKEN")
+if HF_TOKEN is None:
+    raise ValueError("HF_TOKEN is not set in the .env file.")
+
+# ───── Hugging Face Inference API configuration ─────
+headers = {"Authorization": f"Bearer {HF_TOKEN}"}
+MUSICGEN_API = "https://api-inference.huggingface.co/models/facebook/musicgen-small"
+
+# ───── Load the image-captioning model ─────
 caption_model = VisionEncoderDecoderModel.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
 feature_extractor = ViTImageProcessor.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
 tokenizer = AutoTokenizer.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
 
-# Load ChatTTS (via inference API)
-CHAT_TTS_API = "https://api-inference.huggingface.co/models/2Noise/ChatTTS"
-headers = {"Authorization": f"Bearer {os.getenv('HF_TOKEN')}"}
-
+# ───── Image → caption generation ─────
 def generate_caption(image):
     pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values
-    output_ids = caption_model.generate(pixel_values, max_length=50)  # <- ✅ beam search removed
+    output_ids = caption_model.generate(pixel_values, max_length=50)
     caption = tokenizer.decode(output_ids[0], skip_special_tokens=True)
     return caption
 
-
-def tts_audio(text):
-    payload = {"inputs": text}
-    response = requests.post(CHAT_TTS_API, headers=headers, json=payload)
+# ───── Caption → music generation (MusicGen API call) ─────
+def generate_music(prompt):
+    payload = {"inputs": prompt}
+    response = requests.post(MUSICGEN_API, headers=headers, json=payload)
     if response.status_code == 200:
         return response.content
     else:
-        raise Exception(f"TTS API error: {response.status_code}, {response.text}")
-
+        raise Exception(f"MusicGen error: {response.status_code}, {response.text}")
 
+# ───── Wire the full pipeline together ─────
 def process(image):
     caption = generate_caption(image)
-    audio = tts_audio(caption)
-    return caption, (audio, "result.wav")
+    prompt = f"A cheerful melody inspired by: {caption}"
+    audio = generate_music(prompt)
+    return caption, (audio, "musicgen_output.wav")
 
+# ───── Gradio interface setup ─────
 demo = gr.Interface(
     fn=process,
     inputs=gr.Image(type="pil"),
-    outputs=[gr.Text(label="Description"), gr.Audio(label="TTS voice")],
-    title="🎨 AI picture description narrator",
+    outputs=[
+        gr.Text(label="AI-generated picture description"),
+        gr.Audio(label="Generated AI music (MusicGen)")
+    ],
+    title="🎨 AI picture-to-music generator",
+    description="Upload a picture: the AI writes a description, then composes music based on it for you to hear."
 )
 
-demo.launch()
+if __name__ == "__main__":
+    demo.launch()
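
A note on the new return value: the serverless Inference API hands back raw audio bytes, while Gradio's gr.Audio output component expects a file path or a (sample_rate, numpy_array) tuple rather than a (bytes, filename) pair, so the tuple returned by process may not render. A minimal alternative sketch, assuming the endpoint returns WAV-encoded bytes; the temp-file approach and the generate_music_to_file name are illustrative, not part of this commit:

import os
import tempfile

import requests

MUSICGEN_API = "https://api-inference.huggingface.co/models/facebook/musicgen-small"
headers = {"Authorization": f"Bearer {os.getenv('HF_TOKEN')}"}

def generate_music_to_file(prompt: str) -> str:
    # Call the MusicGen endpoint and persist the raw bytes to a temp file;
    # gr.Audio can then be handed the resulting file path directly.
    response = requests.post(MUSICGEN_API, headers=headers, json={"inputs": prompt})
    response.raise_for_status()
    # Assumption: the endpoint returns WAV bytes; the container format is not
    # guaranteed by the commit, so the .wav suffix may need adjusting.
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
        f.write(response.content)
        return f.name

With this variant, process would return caption, generate_music_to_file(prompt), and no filename tuple is needed.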
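
Separately, the serverless endpoint commonly answers HTTP 503 while facebook/musicgen-small cold-starts, which the bare status-code check in generate_music surfaces as an exception on the first request. A retry sketch, assuming the usual 503-with-JSON loading response; the estimated_time field and the retry bounds are assumptions, not verified against this commit:

import time

import requests

def query_with_retry(url: str, headers: dict, payload: dict, max_retries: int = 5) -> bytes:
    # Retry while the hosted model is still loading (HTTP 503), then
    # raise for any other error status.
    for _ in range(max_retries):
        response = requests.post(url, headers=headers, json=payload)
        if response.status_code == 503:
            # Assumption: the loading response carries an "estimated_time"
            # hint; default to 10 s and cap the wait if it is absent.
            try:
                wait = float(response.json().get("estimated_time", 10))
            except ValueError:
                wait = 10.0
            time.sleep(min(wait, 30.0))
            continue
        response.raise_for_status()
        return response.content
    raise RuntimeError("MusicGen endpoint did not become ready after retries")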