import gradio as gr
import logging
from PIL import Image
import torch
from uroman import Uroman
from transformers import (
    BlipProcessor,
    BlipForConditionalGeneration,
    pipeline,
    AutoTokenizer,
    VitsModel,
)
logging.basicConfig(level=logging.INFO)

# ───────── 1. Model loading ─────────
processor = BlipProcessor.from_pretrained(
    "Salesforce/blip-image-captioning-large"
)
blip_model = BlipForConditionalGeneration.from_pretrained(
    "Salesforce/blip-image-captioning-large"
).to("cuda" if torch.cuda.is_available() else "cpu")
translation_pipeline = pipeline(
    "translation",
    model="facebook/nllb-200-distilled-600M",
    src_lang="eng_Latn",
    tgt_lang="kor_Hang",
    max_length=200,
    device=0 if torch.cuda.is_available() else -1,
)

# --- TTS (ko / en) ---
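# facebook/mms-tts-kor and facebook/mms-tts-eng are VITS text-to-speech checkpoints
# from Meta's MMS project; each one is paired with its own tokenizer below.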
tts_ko = VitsModel.from_pretrained("facebook/mms-tts-kor").to(
    "cuda" if torch.cuda.is_available() else "cpu"
)
tok_ko = AutoTokenizer.from_pretrained("facebook/mms-tts-kor")
tts_en = VitsModel.from_pretrained("facebook/mms-tts-eng").to(
    "cuda" if torch.cuda.is_available() else "cpu"
)
tok_en = AutoTokenizer.from_pretrained("facebook/mms-tts-eng")
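# The MMS-TTS tokenizers expect Latin-script input, so non-Latin text (here, Hangul)
# is romanized with uroman before tokenization.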
uroman = Uroman()

# ───────── 2. Shared helpers ─────────
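# tts() romanizes the text, tokenizes it, runs the VITS model, and returns a
# (sampling_rate, waveform) tuple that gr.Audio(type="numpy") can play directly.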
def tts(model, tokenizer, text: str):
    roman = uroman.romanize_string(text)
    ids = tokenizer(roman, return_tensors="pt").input_ids.long().to(model.device)
    with torch.no_grad():
        wav = model(input_ids=ids).waveform.squeeze().cpu().numpy()
    return model.config.sampling_rate, wav

def generate(img: Image.Image, lang: str):
    """
    lang == "ko" → Korean caption + speech
    lang == "en" → English caption + speech
    """
    if img is None:
        raise gr.Error("Please upload an image first 📷")
    # ① English caption
    pix = processor(images=img, return_tensors="pt").pixel_values.to(blip_model.device)
    cap_en = processor.batch_decode(
        blip_model.generate(pix, max_length=64), skip_special_tokens=True
    )[0].strip()
    if lang == "en":
        sr, wav = tts(tts_en, tok_en, cap_en)
        return cap_en, (sr, wav)
    # ② Translation (→ Korean)
    try:
        cap_ko = translation_pipeline(cap_en)[0]["translation_text"].strip()
    except Exception as e:
        logging.error(f"[ERROR] Translation failed: {e}")
        cap_ko = ""
    if not cap_ko:
        return "A translation error occurred.", None
    sr, wav = tts(tts_ko, tok_ko, cap_ko)
    return cap_ko, (sr, wav)

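# Example of calling generate() directly (illustrative sketch only; "sample.jpg" is a
# placeholder file name, not part of this Space):
#     caption, audio = generate(Image.open("sample.jpg"), "ko")
#     sr, wav = audio  # sampling rate and numpy waveform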

# ───────── 3. Gradio UI ─────────
with gr.Blocks(title="Image → Caption & TTS", css="footer{display:none;}") as demo:
    gr.Markdown(
        "## Image → Korean / English captions & speech\n"
        "BLIP (caption) → NLLB (translate) → VITS (TTS)"
    )
    img_state = gr.State()  # stores the most recently uploaded image
    input_img = gr.Image(type="pil", label="📷 Image upload")
    caption_box = gr.Textbox(label="Caption result")
    audio_play = gr.Audio(label="Audio playback", type="numpy")
    with gr.Row():
        ko_btn = gr.Button("Korean speech")
        en_btn = gr.Button("English")
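    # Both buttons read the image from img_state, so a caption and audio clip can be
    # generated in either language without re-uploading the image.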

    # Update the state whenever a new image is uploaded
    def store_img(img):
        return img
    input_img.change(store_img, inputs=input_img, outputs=img_state, queue=False)
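    # queue=False runs the state update outside the request queue (presumably so uploads stay responsive).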
    # Wire the buttons to the generation function
    ko_btn.click(fn=lambda img: generate(img, "ko"), inputs=img_state, outputs=[caption_box, audio_play])
    en_btn.click(fn=lambda img: generate(img, "en"), inputs=img_state, outputs=[caption_box, audio_play])

if __name__ == "__main__":
demo.launch()