# Hugging Face Spaces status when this file was captured: Sleeping
import gradio as gr | |
import logging | |
from PIL import Image | |
import torch | |
from uroman import Uroman | |
from transformers import ( | |
BlipProcessor, | |
BlipForConditionalGeneration, | |
pipeline, | |
AutoTokenizer, | |
VitsModel, | |
) | |
# Root logger emits INFO and above.
logging.basicConfig(level=logging.INFO)

# Resolve the compute device once; every model below is placed on it.
_HAS_CUDA = torch.cuda.is_available()
_TORCH_DEVICE = "cuda" if _HAS_CUDA else "cpu"

# --- Image captioning (BLIP) -------------------------------------------------
_BLIP_CHECKPOINT = "Salesforce/blip-image-captioning-large"
processor = BlipProcessor.from_pretrained(_BLIP_CHECKPOINT)
blip_model = BlipForConditionalGeneration.from_pretrained(_BLIP_CHECKPOINT)
blip_model = blip_model.to(_TORCH_DEVICE)

# --- English -> Korean translation (NLLB) ------------------------------------
# The pipeline takes a device index: 0 for the first GPU, -1 for CPU.
translation_pipeline = pipeline(
    task="translation",
    model="facebook/nllb-200-distilled-600M",
    src_lang="eng_Latn",
    tgt_lang="kor_Hang",
    max_length=200,
    device=0 if _HAS_CUDA else -1,
)

# --- Text-to-speech (VITS checkpoints), one model/tokenizer pair per language -
_TTS_KO_CHECKPOINT = "facebook/mms-tts-kor"
tts_ko = VitsModel.from_pretrained(_TTS_KO_CHECKPOINT).to(_TORCH_DEVICE)
tok_ko = AutoTokenizer.from_pretrained(_TTS_KO_CHECKPOINT)

_TTS_EN_CHECKPOINT = "facebook/mms-tts-eng"
tts_en = VitsModel.from_pretrained(_TTS_EN_CHECKPOINT).to(_TORCH_DEVICE)
tok_en = AutoTokenizer.from_pretrained(_TTS_EN_CHECKPOINT)

# Romanizer applied to text before it is tokenized for speech synthesis.
uroman = Uroman()
def tts(model, tokenizer, text: str):
    """Synthesize speech for ``text`` with the given model/tokenizer pair.

    The text is first romanized with the module-level ``uroman`` instance,
    then tokenized and fed to ``model`` with gradient tracking disabled.

    Args:
        model: Speech model exposing ``device``, ``config.sampling_rate`` and
            returning an object with a ``waveform`` tensor when called.
        tokenizer: Tokenizer matching ``model``.
        text: Text to speak.

    Returns:
        A ``(sampling_rate, waveform)`` tuple, where ``waveform`` is the
        squeezed model output converted to a NumPy array on the CPU.
    """
    romanized = uroman.romanize_string(text)
    encoded = tokenizer(romanized, return_tensors="pt")
    # Token ids must live on the same device as the model.
    token_ids = encoded.input_ids.long().to(model.device)

    with torch.no_grad():
        output = model(input_ids=token_ids)

    waveform = output.waveform.squeeze().cpu().numpy()
    return model.config.sampling_rate, waveform
def generate(img: Image.Image, lang: str):
    """Caption an image and synthesize speech for that caption.

    An English caption is generated with the module-level BLIP ``processor``
    and ``blip_model``. When ``lang == "en"`` the English caption is spoken
    as-is. For any other value (the UI passes ``"ko"``) the caption is
    translated to Korean with ``translation_pipeline`` and the Korean text is
    spoken instead.

    Args:
        img: The uploaded image, or ``None`` when nothing has been uploaded.
        lang: ``"en"`` for an English caption + speech, ``"ko"`` for a Korean
            caption + speech.

    Returns:
        ``(caption, (sampling_rate, waveform))`` on success, or
        ``(error_message, None)`` when the translation raised or came back
        empty.

    Raises:
        gr.Error: If ``img`` is ``None``.
    """
    if img is None:
        raise gr.Error("μ΄λ―Έμ§λ₯Ό μ λ‘λνμΈμ π·")

    pix = processor(images=img, return_tensors="pt").pixel_values.to(blip_model.device)
    generated_ids = blip_model.generate(
        pix,
        max_length=128,  # allow longer captions
        min_length=20,  # keep generation from stopping too early
        num_beams=5,  # beam search: better quality, somewhat slower
        # NOTE: no `temperature` here. Without `do_sample=True` generation is
        # plain beam search, where temperature is ignored (and only triggers
        # a warning on recent transformers releases).
        repetition_penalty=1.1,
    )
    cap_en = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()

    if lang == "en":
        sr, wav = tts(tts_en, tok_en, cap_en)
        return cap_en, (sr, wav)

    try:
        cap_ko = translation_pipeline(cap_en)[0]["translation_text"].strip()
    except Exception as e:
        # Best-effort boundary: keep the UI responsive, but record the full
        # traceback so the failure can actually be diagnosed from the logs.
        logging.exception("[ERROR] λ²μ μ€ν¨: %s", e)
        cap_ko = ""

    if not cap_ko:
        # Translation failed or returned nothing: report it in the caption
        # box and leave the audio output empty.
        return "λ²μ μ€λ₯κ° λ°μνμ΅λλ€.", None

    sr, wav = tts(tts_ko, tok_ko, cap_ko)
    return cap_ko, (sr, wav)
# Gradio UI: upload an image, then pick the language for caption + speech.
with gr.Blocks(
    title="ν½λ³΄μ΄μ€(Picvoice)",
    css="footer {display: none !important;}",
) as demo:
    gr.Markdown("## ν½λ³΄μ΄μ€(Picvoice)")

    # Holds the most recently uploaded image.
    img_state = gr.State()

    input_img = gr.Image(type="pil", label="π· μ΄λ―Έμ§ μ λ‘λ")
    caption_box = gr.Textbox(label="π ν μ€νΈ μμ± κ²°κ³Ό")
    audio_play = gr.Audio(label="π μμ± μ¬μ", type="numpy")

    with gr.Row():
        ko_btn = gr.Button("νκΈλ‘ μμ±πͺ")
        en_btn = gr.Button("μμ΄λ‘ μμ±πͺ")

    def store_img(img):
        """Return the image unchanged so it can be written into ``img_state``."""
        return img

    def _generate_korean(img):
        """Click handler: caption + speech in Korean."""
        return generate(img, "ko")

    def _generate_english(img):
        """Click handler: caption + speech in English."""
        return generate(img, "en")

    # Mirror every change of the image component into the session state.
    input_img.change(
        fn=store_img,
        inputs=input_img,
        outputs=img_state,
        queue=False,
    )

    # Both buttons read the stored image and fill the caption and audio outputs.
    ko_btn.click(
        fn=_generate_korean,
        inputs=img_state,
        outputs=[caption_box, audio_play],
    )
    en_btn.click(
        fn=_generate_english,
        inputs=img_state,
        outputs=[caption_box, audio_play],
    )

# Start the app only when this file is run as a script.
if __name__ == "__main__":
    demo.launch()