import gradio as gr
import logging
from PIL import Image
import torch
from uroman import Uroman
from transformers import (
    BlipProcessor,
    BlipForConditionalGeneration,
    pipeline,
    AutoTokenizer,
    VitsModel,
)
logging.basicConfig(level=logging.INFO)
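
# Image captioning: BLIP (Salesforce/blip-image-captioning-large)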
processor = BlipProcessor.from_pretrained(
    "Salesforce/blip-image-captioning-large"
)
blip_model = BlipForConditionalGeneration.from_pretrained(
    "Salesforce/blip-image-captioning-large"
).to("cuda" if torch.cuda.is_available() else "cpu")
translation_pipeline = pipeline(
    "translation",
    model="facebook/nllb-200-distilled-600M",
    src_lang="eng_Latn",
    tgt_lang="kor_Hang",
    max_length=200,
    device=0 if torch.cuda.is_available() else -1,
)
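
# Text-to-speech: MMS VITS checkpoints for Korean and English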
tts_ko = VitsModel.from_pretrained("facebook/mms-tts-kor").to(
    "cuda" if torch.cuda.is_available() else "cpu"
)
tok_ko = AutoTokenizer.from_pretrained("facebook/mms-tts-kor")
tts_en = VitsModel.from_pretrained("facebook/mms-tts-eng").to(
    "cuda" if torch.cuda.is_available() else "cpu"
)
tok_en = AutoTokenizer.from_pretrained("facebook/mms-tts-eng")
uroman = Uroman()
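
# Synthesize speech: the text is romanized with uroman before tokenization
# (the MMS VITS tokenizers work on Latin-script input); returns a
# (sampling_rate, waveform) tuple suitable for gr.Audio.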
def tts(model, tokenizer, text: str):
    roman = uroman.romanize_string(text)
    ids = tokenizer(roman, return_tensors="pt").input_ids.long().to(model.device)
    with torch.no_grad():
        wav = model(input_ids=ids).waveform.squeeze().cpu().numpy()
    return model.config.sampling_rate, wav
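
# Full pipeline: BLIP caption (English) → optional NLLB translation → VITS speech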
def generate(img: Image.Image, lang: str):
    """
    lang == "ko" → Korean caption + speech
    lang == "en" → English caption + speech
    """
    if img is None:
        raise gr.Error("Please upload an image 📷")
    # English caption from BLIP
    pix = processor(images=img, return_tensors="pt").pixel_values.to(blip_model.device)
    cap_en = processor.batch_decode(
        blip_model.generate(pix, max_length=64), skip_special_tokens=True
    )[0].strip()
    if lang == "en":
        sr, wav = tts(tts_en, tok_en, cap_en)
        return cap_en, (sr, wav)
    # Translate the English caption to Korean with NLLB
    try:
        cap_ko = translation_pipeline(cap_en)[0]["translation_text"].strip()
    except Exception as e:
        logging.error(f"[ERROR] Translation failed: {e}")
        cap_ko = ""
    if not cap_ko:
        return "A translation error occurred.", None
    sr, wav = tts(tts_ko, tok_ko, cap_ko)
    return cap_ko, (sr, wav)
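
# Gradio UI: upload an image, then generate a Korean or English caption with speech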
with gr.Blocks(title="Image → Caption & TTS", css="footer {display: none !important;}") as demo:
    gr.Markdown(
        "## Image → Korean / English captioning & speech synthesis\n"
        "BLIP (caption) → NLLB (translate) → VITS (TTS)"
    )
    img_state = gr.State()  # stores the most recently uploaded image
    input_img = gr.Image(type="pil", label="📷 Upload an image")
    caption_box = gr.Textbox(label="Caption")
    audio_play = gr.Audio(label="Audio playback", type="numpy")
    with gr.Row():
        ko_btn = gr.Button("Speech in Korean")
        en_btn = gr.Button("Speech in English")

    def store_img(img):
        return img

    input_img.change(store_img, inputs=input_img, outputs=img_state, queue=False)
    ko_btn.click(fn=lambda img: generate(img, "ko"), inputs=img_state, outputs=[caption_box, audio_play])
    en_btn.click(fn=lambda img: generate(img, "en"), inputs=img_state, outputs=[caption_box, audio_play])
if __name__ == "__main__":
    demo.launch()