imagetospeech / app.py
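
# Picvoice (ν”½λ³΄μ΄μŠ€): image-to-speech demo.
# BLIP generates an English caption, NLLB can translate it to Korean, and
# MMS-TTS (VITS) synthesizes speech; the app is served through a Gradio Blocks UI.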
import gradio as gr
import logging
from PIL import Image
import torch
from uroman import Uroman
from transformers import (
    BlipProcessor,
    BlipForConditionalGeneration,
    pipeline,
    AutoTokenizer,
    VitsModel,
)
logging.basicConfig(level=logging.INFO)
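
# BLIP processor and captioning model (produces English captions).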
processor = BlipProcessor.from_pretrained(
    "Salesforce/blip-image-captioning-large"
)
blip_model = BlipForConditionalGeneration.from_pretrained(
    "Salesforce/blip-image-captioning-large"
).to("cuda" if torch.cuda.is_available() else "cpu")
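
# English -> Korean translation pipeline (NLLB-200, distilled 600M).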
translation_pipeline = pipeline(
    "translation",
    model="facebook/nllb-200-distilled-600M",
    src_lang="eng_Latn",
    tgt_lang="kor_Hang",
    max_length=200,
    device=0 if torch.cuda.is_available() else -1,
)
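
# MMS-TTS (VITS) models and tokenizers for Korean and English speech synthesis.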
tts_ko = VitsModel.from_pretrained("facebook/mms-tts-kor").to(
    "cuda" if torch.cuda.is_available() else "cpu"
)
tok_ko = AutoTokenizer.from_pretrained("facebook/mms-tts-kor")
tts_en = VitsModel.from_pretrained("facebook/mms-tts-eng").to(
    "cuda" if torch.cuda.is_available() else "cpu"
)
tok_en = AutoTokenizer.from_pretrained("facebook/mms-tts-eng")
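
# uroman romanizes non-Latin text (e.g. Hangul) into Latin script, which the MMS-TTS tokenizers expect.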
uroman = Uroman()
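
# Synthesize speech with a VITS model; returns (sampling_rate, waveform) in the format gr.Audio expects.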
def tts(model, tokenizer, text: str):
    roman = uroman.romanize_string(text)
    ids = tokenizer(roman, return_tensors="pt").input_ids.long().to(model.device)
    with torch.no_grad():
        wav = model(input_ids=ids).waveform.squeeze().cpu().numpy()
    return model.config.sampling_rate, wav
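
# Caption the image with BLIP, optionally translate the caption to Korean, then synthesize speech.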
def generate(img: Image.Image, lang: str):
    """
    lang == "ko" β†’ Korean caption + speech
    lang == "en" β†’ English caption + speech
    """
    if img is None:
        raise gr.Error("이미지λ₯Ό μ—…λ‘œλ“œν•˜μ„Έμš” πŸ“·")
    pix = processor(images=img, return_tensors="pt").pixel_values.to(blip_model.device)
    cap_en = processor.batch_decode(
        blip_model.generate(
            pix,
            max_length=128,          # allow longer captions
            min_length=20,           # avoid stopping too early
            num_beams=5,             # better beam-search quality (slightly slower)
            temperature=0.7,         # controls diversity
            repetition_penalty=1.1,
        ),
        skip_special_tokens=True,
    )[0].strip()
    if lang == "en":
        sr, wav = tts(tts_en, tok_en, cap_en)
        return cap_en, (sr, wav)
    try:
        cap_ko = translation_pipeline(cap_en)[0]["translation_text"].strip()
    except Exception as e:
        logging.error(f"[ERROR] Translation failed: {e}")
        cap_ko = ""
    if not cap_ko:
        return "λ²ˆμ—­ 였λ₯˜κ°€ λ°œμƒν–ˆμŠ΅λ‹ˆλ‹€.", None
    sr, wav = tts(tts_ko, tok_ko, cap_ko)
    return cap_ko, (sr, wav)
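
# Gradio Blocks UI: upload an image, then generate a caption and audio in Korean or English.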
with gr.Blocks(title="ν”½λ³΄μ΄μŠ€(Picvoice)", css="footer {display: none !important;}") as demo:
    gr.Markdown("## ν”½λ³΄μ΄μŠ€(Picvoice)")
    img_state = gr.State()  # holds the most recently uploaded image
    input_img = gr.Image(type="pil", label="πŸ“· 이미지 μ—…λ‘œλ“œ")
    caption_box = gr.Textbox(label="πŸ“‘ ν…μŠ€νŠΈ 생성 κ²°κ³Ό")
    audio_play = gr.Audio(label="πŸ”Š μŒμ„± μž¬μƒ", type="numpy")
    with gr.Row():
        ko_btn = gr.Button("ν•œκΈ€λ‘œ 생성πŸͺ„")
        en_btn = gr.Button("μ˜μ–΄λ‘œ 생성πŸͺ„")

    def store_img(img):
        return img

    input_img.change(store_img, inputs=input_img, outputs=img_state, queue=False)
    ko_btn.click(fn=lambda img: generate(img, "ko"), inputs=img_state, outputs=[caption_box, audio_play])
    en_btn.click(fn=lambda img: generate(img, "en"), inputs=img_state, outputs=[caption_box, audio_play])
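
# Launch the Gradio server when run as a script (e.g. `python app.py`).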
if __name__ == "__main__":
    demo.launch()