# imagetospeech / app.py
import gradio as gr
import logging
from PIL import Image
from transformers import (
    BlipProcessor,
    BlipForConditionalGeneration,
    pipeline,
    AutoTokenizer,
    VitsModel
)
import torch
from uroman import Uroman
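# Pipeline: BLIP generates an English caption for the image, NLLB-200 translates
# it into Korean, and MMS-TTS (a VITS model) synthesizes Korean speech from it.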
# ─────────────── Logging setup ───────────────
logging.basicConfig(level=logging.INFO)
# ─────────────── 1. BLIP image captioning (English) ───────────────
processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large")
blip_model.to("cuda" if torch.cuda.is_available() else "cpu")
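# The same processor is reused in describe_and_speak(): it converts a PIL image
# into pixel_values and decodes the token IDs that blip_model.generate() returns.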
# ─────────────── 2. English → Korean translation ───────────────
translation_pipeline = pipeline(
    "translation",
    model="facebook/nllb-200-distilled-600M",
    src_lang="eng_Latn",
    tgt_lang="kor_Hang",
    max_length=200,
    device=0 if torch.cuda.is_available() else -1
)
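# NLLB-200 uses FLORES-200 language codes: "eng_Latn" = English (Latin script),
# "kor_Hang" = Korean (Hangul script).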
# ─────────────── 3. Korean TTS ───────────────
tts_model = VitsModel.from_pretrained("facebook/mms-tts-kor")
tts_tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-kor")
tts_model.to("cuda" if torch.cuda.is_available() else "cpu")
uroman = Uroman()
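# The MMS-TTS checkpoints are trained on romanized text, so input in a non-Latin
# script such as Hangul has to be romanized with uroman before tokenization.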
def synthesize_tts(text: str):
    """Convert a Korean sentence into a VITS-TTS waveform."""
    romanized = uroman.romanize_string(text)
    inputs = tts_tokenizer(romanized, return_tensors="pt")
    input_ids = inputs["input_ids"].long().to(tts_model.device)
    with torch.no_grad():
        output = tts_model(input_ids=input_ids)
    waveform = output.waveform.squeeze().cpu().numpy()
    return tts_model.config.sampling_rate, waveform
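# The (sampling_rate, 1-D float ndarray) pair returned above is the tuple format
# that gr.Audio(type="numpy") can play back directly.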
# ─────────────── 4. Image → caption + translation + speech output ───────────────
def describe_and_speak(img: Image.Image):
    logging.info("[DEBUG] describe_and_speak called")
    # β‘  English caption
    pixel_values = processor(images=img, return_tensors="pt").pixel_values.to(blip_model.device)
    generated_ids = blip_model.generate(pixel_values, max_length=64)
    caption_en = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
    logging.info(f"[DEBUG] caption_en: {caption_en}")
    # β‘‘ Translation
    try:
        result = translation_pipeline(caption_en)
        caption_ko = result[0]["translation_text"].strip()
    except Exception as e:
        logging.error(f"[ERROR] Translation error: {e}")
        caption_ko = ""
    logging.info(f"[DEBUG] caption_ko: {caption_ko}")
    if not caption_ko:
        return "이미지에 λŒ€ν•œ μ„€λͺ…을 생성할 수 μ—†μŠ΅λ‹ˆλ‹€.", None  # "Could not generate a description for the image."
    # β‘’ TTS
    try:
        sr, wav = synthesize_tts(caption_ko)
        return caption_ko, (sr, wav)
    except Exception as e:
        logging.error(f"[ERROR] TTS error: {e}")
        return caption_ko, None
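# describe_and_speak returns (Korean caption, (sample_rate, waveform)) on success
# and (caption or fallback message, None) on failure, matching
# outputs=[caption_out, audio_out]; a None audio value leaves the player empty.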
# ─────────────── 5. Gradio interface ───────────────
with gr.Blocks(
    title="이미지 β†’ ν•œκΈ€ μΊ‘μ…˜ & μŒμ„± λ³€ν™˜",  # "Image → Korean caption & speech"
    css="footer {display: none !important;}"  # hide the Gradio footer
) as demo:
    # "Generate English caption with BLIP → translate to Korean with NLLB → synthesize speech with VITS"
    gr.Markdown(
        "## 이미지 β†’ ν•œκΈ€ μΊ‘μ…˜ & μŒμ„± λ³€ν™˜\n"
        "BLIP으둜 μ˜μ–΄ μΊ‘μ…˜ 생성 β†’ NLLB둜 ν•œκ΅­μ–΄ λ²ˆμ—­ β†’ VITS둜 μŒμ„± 생성"
    )
    # Input/output components
    input_img = gr.Image(
        type="pil",
        sources=["upload", "webcam"],
        label="μž…λ ₯ 이미지"  # "Input image"
    )
    caption_out = gr.Textbox(label="ν•œκΈ€ μΊ‘μ…˜")  # "Korean caption"
    audio_out = gr.Audio(label="μŒμ„± μž¬μƒ", type="numpy")  # "Audio playback"
    # Run the pipeline automatically whenever the image changes (upload or webcam capture)
    input_img.change(
        fn=describe_and_speak,
        inputs=input_img,
        outputs=[caption_out, audio_out],
        queue=True  # queue requests so concurrent users are handled safely
    )
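# debug=True blocks on launch and surfaces error tracebacks in the console,
# which makes failures visible in the Space logs.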
if __name__ == "__main__":
    demo.launch(debug=True)