yongyeol committed
Commit 8586da3 · verified · 1 Parent(s): 7d2e0c9

Update app.py

Files changed (1)
  1. app.py +77 -69
app.py CHANGED
@@ -1,107 +1,115 @@
  import gradio as gr
  import logging
  from PIL import Image
  from transformers import (
      BlipProcessor,
      BlipForConditionalGeneration,
      pipeline,
      AutoTokenizer,
-     VitsModel
  )
- import torch
- from uroman import Uroman

- # ─────────────── Logging setup ───────────────
  logging.basicConfig(level=logging.INFO)

- # ─────────────── 1. BLIP image captioning (English) ───────────────
- processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
- blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large")
- blip_model.to("cuda" if torch.cuda.is_available() else "cpu")

- # ─────────────── 2. English → Korean translation ───────────────
  translation_pipeline = pipeline(
      "translation",
      model="facebook/nllb-200-distilled-600M",
      src_lang="eng_Latn",
      tgt_lang="kor_Hang",
      max_length=200,
-     device=0 if torch.cuda.is_available() else -1
  )

- # ─────────────── 3. Korean TTS ───────────────
- tts_model = VitsModel.from_pretrained("facebook/mms-tts-kor")
- tts_tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-kor")
- tts_model.to("cuda" if torch.cuda.is_available() else "cpu")

  uroman = Uroman()

- def synthesize_tts(text: str):
-     """Convert a Korean sentence into a VITS-TTS waveform."""
-     romanized = uroman.romanize_string(text)
-     inputs = tts_tokenizer(romanized, return_tensors="pt")
-     input_ids = inputs["input_ids"].long().to(tts_model.device)
      with torch.no_grad():
-         output = tts_model(input_ids=input_ids)
-     waveform = output.waveform.squeeze().cpu().numpy()
-     return tts_model.config.sampling_rate, waveform

- # ─────────────── 4. Image → caption + translation + speech ───────────────
- def describe_and_speak(img: Image.Image):
-     logging.info("[DEBUG] describe_and_speak called")

      # ① English caption
-     pixel_values = processor(images=img, return_tensors="pt").pixel_values.to(blip_model.device)
-     generated_ids = blip_model.generate(pixel_values, max_length=64)
-     caption_en = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
-     logging.info(f"[DEBUG] caption_en: {caption_en}")

-     # ② Translation
      try:
-         result = translation_pipeline(caption_en)
-         caption_ko = result[0]["translation_text"].strip()
      except Exception as e:
-         logging.error(f"[ERROR] translation error: {e}")
-         caption_ko = ""
-     logging.info(f"[DEBUG] caption_ko: {caption_ko}")

-     if not caption_ko:
-         return "Could not generate a description for this image.", None

-     # ③ TTS
-     try:
-         sr, wav = synthesize_tts(caption_ko)
-         return caption_ko, (sr, wav)
-     except Exception as e:
-         logging.error(f"[ERROR] TTS error: {e}")
-         return caption_ko, None
-
- # ─────────────── 5. Gradio interface ───────────────
- with gr.Blocks(
-     title="Image → Korean caption & speech",
-     css="footer {display: none !important;}"  # hide the footer
- ) as demo:
      gr.Markdown(
-         "## Image → Korean caption & speech\n"
-         "BLIP generates an English caption → NLLB translates it to Korean → VITS synthesizes speech"
      )

-     # Input/output components
-     input_img = gr.Image(
-         type="pil",
-         sources=["upload", "webcam"],
-         label="Input image"
-     )
-     caption_out = gr.Textbox(label="Korean caption")
-     audio_out = gr.Audio(label="Audio playback", type="numpy")
-
-     # Re-run automatically whenever the image changes (upload/capture)
-     input_img.change(
-         fn=describe_and_speak,
-         inputs=input_img,
-         outputs=[caption_out, audio_out],
-         queue=True  # safe for concurrent users
-     )

  if __name__ == "__main__":
-     demo.launch(debug=True)

  import gradio as gr
  import logging
  from PIL import Image
+ import torch
+ from uroman import Uroman
  from transformers import (
      BlipProcessor,
      BlipForConditionalGeneration,
      pipeline,
      AutoTokenizer,
+     VitsModel,
  )

  logging.basicConfig(level=logging.INFO)

+ # ───────── 1. Model loading ─────────
+ processor = BlipProcessor.from_pretrained(
+     "Salesforce/blip-image-captioning-large"
+ )
+ blip_model = BlipForConditionalGeneration.from_pretrained(
+     "Salesforce/blip-image-captioning-large"
+ ).to("cuda" if torch.cuda.is_available() else "cpu")

  translation_pipeline = pipeline(
      "translation",
      model="facebook/nllb-200-distilled-600M",
      src_lang="eng_Latn",
      tgt_lang="kor_Hang",
      max_length=200,
+     device=0 if torch.cuda.is_available() else -1,
  )

+ # --- TTS (ko / en) ---
+ tts_ko = VitsModel.from_pretrained("facebook/mms-tts-kor").to(
+     "cuda" if torch.cuda.is_available() else "cpu"
+ )
+ tok_ko = AutoTokenizer.from_pretrained("facebook/mms-tts-kor")
+
+ tts_en = VitsModel.from_pretrained("facebook/mms-tts-eng").to(
+     "cuda" if torch.cuda.is_available() else "cpu"
+ )
+ tok_en = AutoTokenizer.from_pretrained("facebook/mms-tts-eng")

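+ # Note: the MMS-TTS tokenizers expect romanized input for non-Latin scripts such as
+ # Korean, so text is passed through uroman before tokenization (see tts() below).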
  uroman = Uroman()

+
+ # ───────── 2. Shared helpers ─────────
+ def tts(model, tokenizer, text: str):
+     roman = uroman.romanize_string(text)
+     ids = tokenizer(roman, return_tensors="pt").input_ids.long().to(model.device)
      with torch.no_grad():
+         wav = model(input_ids=ids).waveform.squeeze().cpu().numpy()
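+     # gr.Audio(type="numpy") can play this (sampling_rate, ndarray) pair directly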
+     return model.config.sampling_rate, wav

+
+ def generate(img: Image.Image, lang: str):
+     """
+     lang == "ko" → Korean caption + speech
+     lang == "en" → English caption + speech
+     """
+     if img is None:
+         raise gr.Error("Please upload an image first 📷")

      # ① English caption
+     pix = processor(images=img, return_tensors="pt").pixel_values.to(blip_model.device)
+     cap_en = processor.batch_decode(
+         blip_model.generate(pix, max_length=64), skip_special_tokens=True
+     )[0].strip()
+
+     if lang == "en":
+         sr, wav = tts(tts_en, tok_en, cap_en)
+         return cap_en, (sr, wav)

+     # ② Translation (→ ko)
      try:
+         cap_ko = translation_pipeline(cap_en)[0]["translation_text"].strip()
      except Exception as e:
+         logging.error(f"[ERROR] translation failed: {e}")
+         cap_ko = ""
+     if not cap_ko:
+         return "A translation error occurred.", None

+     sr, wav = tts(tts_ko, tok_ko, cap_ko)
+     return cap_ko, (sr, wav)

+
+ # ───────── 3. Gradio UI ─────────
+ with gr.Blocks(title="Image → Caption & TTS", css="footer{display:none;}") as demo:
      gr.Markdown(
+         "## Image → Korean / English caption & speech\n"
+         "BLIP (caption) → NLLB (translate) → VITS (TTS)"
      )

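+     # gr.State is kept per session, so each visitor's buttons act on their own last upload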
+     img_state = gr.State()  # stores the most recent image
+
+     input_img = gr.Image(type="pil", label="📷 Upload image")
+     caption_box = gr.Textbox(label="📑 Caption result")
+     audio_play = gr.Audio(label="🔊 Audio playback", type="numpy")
+
+     with gr.Row():
+         ko_btn = gr.Button("Korean")
+         en_btn = gr.Button("English")
+
+     # Update the state whenever a new image is uploaded
+     def store_img(img):
+         return img
+
+     input_img.change(store_img, inputs=input_img, outputs=img_state, queue=False)
+
+     # Wire the buttons to the generate function
+     ko_btn.click(fn=lambda img: generate(img, "ko"), inputs=img_state, outputs=[caption_box, audio_play])
+     en_btn.click(fn=lambda img: generate(img, "en"), inputs=img_state, outputs=[caption_box, audio_play])

  if __name__ == "__main__":
+     demo.launch()