# imagetospeech / app.py
import gradio as gr
import logging
from PIL import Image
import torch
from uroman import Uroman
from transformers import (
BlipProcessor,
BlipForConditionalGeneration,
pipeline,
AutoTokenizer,
VitsModel,
)
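
# Pipeline overview (as described in the UI below): BLIP produces an English
# caption for the uploaded image, NLLB-200 translates it to Korean, and
# MMS-VITS synthesizes speech in the selected language.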
logging.basicConfig(level=logging.INFO)

# Select the compute device once and reuse it for every model below.
device = "cuda" if torch.cuda.is_available() else "cpu"

# ───────── 1. Load models ─────────
processor = BlipProcessor.from_pretrained(
    "Salesforce/blip-image-captioning-large"
)
blip_model = BlipForConditionalGeneration.from_pretrained(
    "Salesforce/blip-image-captioning-large"
).to(device)
translation_pipeline = pipeline(
    "translation",
    model="facebook/nllb-200-distilled-600M",
    src_lang="eng_Latn",
    tgt_lang="kor_Hang",
    max_length=200,
    device=0 if torch.cuda.is_available() else -1,
)
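# NLLB-200 expects FLORES-200 language codes: "eng_Latn" is English (Latin
# script) and "kor_Hang" is Korean (Hangul).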

# --- TTS (ko / en) ---
tts_ko = VitsModel.from_pretrained("facebook/mms-tts-kor").to(device)
tok_ko = AutoTokenizer.from_pretrained("facebook/mms-tts-kor")
tts_en = VitsModel.from_pretrained("facebook/mms-tts-eng").to(device)
tok_en = AutoTokenizer.from_pretrained("facebook/mms-tts-eng")
uroman = Uroman()
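# The MMS-TTS checkpoints expect romanized input for non-Latin scripts, so
# Korean text is romanized with uroman before tokenization (see tts() below).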

# ───────── 2. Shared helpers ─────────
def tts(model, tokenizer, text: str):
    """Synthesize `text` with a VITS model and return (sampling_rate, waveform)."""
    roman = uroman.romanize_string(text)
    ids = tokenizer(roman, return_tensors="pt").input_ids.long().to(model.device)
    with torch.no_grad():
        wav = model(input_ids=ids).waveform.squeeze().cpu().numpy()
    return model.config.sampling_rate, wav
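
# Hypothetical usage sketch (not part of the app flow):
#   sr, wav = tts(tts_en, tok_en, "a dog running on the beach")
#   # -> sr is an int sample rate, wav a 1-D float numpy array playable by gr.Audio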

def generate(img: Image.Image, lang: str):
    """
    lang == "ko" → Korean caption + speech
    lang == "en" → English caption + speech
    """
    if img is None:
        raise gr.Error("Please upload an image first 📷")
# β‘  μ˜μ–΄ μΊ‘μ…˜
pix = processor(images=img, return_tensors="pt").pixel_values.to(blip_model.device)
cap_en = processor.batch_decode(
blip_model.generate(pix, max_length=64), skip_special_tokens=True
)[0].strip()
if lang == "en":
sr, wav = tts(tts_en, tok_en, cap_en)
return cap_en, (sr, wav)
    # ② Translate the caption to Korean
    try:
        cap_ko = translation_pipeline(cap_en)[0]["translation_text"].strip()
    except Exception as e:
        logging.error(f"[ERROR] Translation failed: {e}")
        cap_ko = ""
    if not cap_ko:
        return "A translation error occurred.", None
    sr, wav = tts(tts_ko, tok_ko, cap_ko)
    return cap_ko, (sr, wav)
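
# generate() always returns (caption_text, audio); when translation fails the
# audio slot is None and only the error message is shown in the caption box.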

# ───────── 3. Gradio UI ─────────
with gr.Blocks(title="Image → Caption & TTS", css="footer{display:none;}") as demo:
    gr.Markdown(
        "## Image → Korean / English caption & speech\n"
        "BLIP (caption) → NLLB (translate) → VITS (TTS)"
    )
    img_state = gr.State()  # stores the most recently uploaded image
    input_img = gr.Image(type="pil", label="📷 Upload image")
    caption_box = gr.Textbox(label="📑 Caption result")
    audio_play = gr.Audio(label="🔊 Audio playback", type="numpy")
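    # gr.Audio with type="numpy" plays a (sample_rate, waveform) tuple,
    # which is exactly what tts() returns.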
    with gr.Row():
        ko_btn = gr.Button("Generate Korean")
        en_btn = gr.Button("English")
# 이미지 μ—…λ‘œλ“œ μ‹œ state μ—…λ°μ΄νŠΈ
def store_img(img):
return img
input_img.change(store_img, inputs=input_img, outputs=img_state, queue=False)
# λ²„νŠΌ ↔ 생성 ν•¨μˆ˜ μ—°κ²°
ko_btn.click(fn=lambda img: generate(img, "ko"), inputs=img_state, outputs=[caption_box, audio_play])
en_btn.click(fn=lambda img: generate(img, "en"), inputs=img_state, outputs=[caption_box, audio_play])
if __name__ == "__main__":
demo.launch()