Spaces:
Sleeping
Sleeping
import gradio as gr | |
import logging | |
from PIL import Image | |
from transformers import ( | |
BlipProcessor, | |
BlipForConditionalGeneration, | |
pipeline, | |
AutoTokenizer, | |
VitsModel | |
) | |
import torch | |
from uroman import Uroman | |
# ─────────────── Logging configuration ───────────────
logging.basicConfig(level=logging.INFO)

# Probe CUDA a single time and derive both device spellings used below:
# a torch device string for the models and an integer index for the pipeline.
_HAS_CUDA = torch.cuda.is_available()
_MODEL_DEVICE = "cuda" if _HAS_CUDA else "cpu"
_PIPELINE_DEVICE = 0 if _HAS_CUDA else -1

# ─────────────── 1. BLIP image captioning (English captions) ───────────────
_BLIP_CHECKPOINT = "Salesforce/blip-image-captioning-large"
processor = BlipProcessor.from_pretrained(_BLIP_CHECKPOINT)
blip_model = BlipForConditionalGeneration.from_pretrained(_BLIP_CHECKPOINT)
blip_model.to(_MODEL_DEVICE)

# ─────────────── 2. English → Korean translation ───────────────
translation_pipeline = pipeline(
    "translation",
    model="facebook/nllb-200-distilled-600M",
    src_lang="eng_Latn",
    tgt_lang="kor_Hang",
    max_length=200,
    device=_PIPELINE_DEVICE,
)

# ─────────────── 3. Korean TTS ───────────────
_TTS_CHECKPOINT = "facebook/mms-tts-kor"
tts_tokenizer = AutoTokenizer.from_pretrained(_TTS_CHECKPOINT)
tts_model = VitsModel.from_pretrained(_TTS_CHECKPOINT)
tts_model.to(_MODEL_DEVICE)

# Romanizer instance; synthesize_tts calls its romanize_string() on the
# input text before tokenization.
uroman = Uroman()
def synthesize_tts(text: str):
    """Convert a Korean sentence into a VITS TTS waveform.

    The text is romanized with ``uroman``, tokenized with ``tts_tokenizer``,
    and passed through ``tts_model`` with gradient tracking disabled.

    Args:
        text: Sentence to synthesize.

    Returns:
        A ``(sampling_rate, waveform)`` tuple: the sampling rate comes from
        ``tts_model.config`` and the waveform is the model output with
        singleton dimensions squeezed out, moved to the CPU and converted to
        a NumPy array.
    """
    latin_text = uroman.romanize_string(text)
    encoded = tts_tokenizer(latin_text, return_tensors="pt")
    # Token ids must live on the same device as the model weights.
    token_ids = encoded["input_ids"].long().to(tts_model.device)

    with torch.no_grad():
        audio = tts_model(input_ids=token_ids).waveform

    samples = audio.squeeze().cpu().numpy()
    return tts_model.config.sampling_rate, samples
# ─────────────── 4. Image → caption + translation + speech output ───────────────
def describe_and_speak(img: Image.Image):
    """Caption an image in English, translate the caption to Korean, and voice it.

    Args:
        img: Image delivered by the Gradio image input. May be ``None`` when
            the input has been cleared; both outputs are then cleared too.

    Returns:
        A ``(text, audio)`` pair. ``text`` is the Korean caption, or a Korean
        failure notice when translation failed or produced nothing. ``audio``
        is the ``(sampling_rate, waveform)`` tuple from ``synthesize_tts``,
        or ``None`` when no speech could be produced.
    """
    # NOTE(review): the Korean literals in this function were restored from
    # mis-encoded text; confirm them against the original file.
    logging.info("[DEBUG] describe_and_speak 호출")

    # The change event also fires when the image is cleared, in which case
    # there is nothing to caption; reset the outputs instead of crashing.
    if img is None:
        return "", None

    # (1) English caption
    pixel_values = processor(images=img, return_tensors="pt").pixel_values.to(blip_model.device)
    generated_ids = blip_model.generate(pixel_values, max_length=64)
    caption_en = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
    logging.info("[DEBUG] caption_en: %s", caption_en)

    # (2) Translation; any failure degrades to an empty caption, handled below.
    try:
        result = translation_pipeline(caption_en)
        caption_ko = result[0]["translation_text"].strip()
    except Exception as e:
        # logging.exception keeps the traceback that logging.error dropped.
        logging.exception("[ERROR] 번역 오류: %s", e)
        caption_ko = ""
    logging.info("[DEBUG] caption_ko: %s", caption_ko)

    if not caption_ko:
        return "이미지에 대한 설명을 생성할 수 없습니다.", None

    # (3) TTS; on failure the caption is still returned, without audio.
    try:
        sr, wav = synthesize_tts(caption_ko)
    except Exception as e:
        logging.exception("[ERROR] TTS 오류: %s", e)
        return caption_ko, None
    return caption_ko, (sr, wav)
# βββββββββββββββ 5. Gradio μΈν°νμ΄μ€ βββββββββββββββ | |
with gr.Blocks( | |
title="μ΄λ―Έμ§ β νκΈ μΊ‘μ & μμ± λ³ν", | |
css="footer {display: none !important;}" # νΈν° μ¨κΈ°κΈ° | |
) as demo: | |
gr.Markdown( | |
"## μ΄λ―Έμ§ β νκΈ μΊ‘μ & μμ± λ³ν\n" | |
"BLIPμΌλ‘ μμ΄ μΊ‘μ μμ± β NLLBλ‘ νκ΅μ΄ λ²μ β VITSλ‘ μμ± μμ±" | |
) | |
# μ λ ₯/μΆλ ₯ μ»΄ν¬λνΈ | |
input_img = gr.Image( | |
type="pil", | |
sources=["upload", "webcam"], | |
label="μ λ ₯ μ΄λ―Έμ§" | |
) | |
caption_out = gr.Textbox(label="νκΈ μΊ‘μ ") | |
audio_out = gr.Audio(label="μμ± μ¬μ", type="numpy") | |
# μ΄λ―Έμ§κ° λ³κ²½(μ λ‘λβ§μΊ‘μ²)λ λλ§λ€ ν¨μ μλ μ€ν | |
input_img.change( | |
fn=describe_and_speak, | |
inputs=input_img, | |
outputs=[caption_out, audio_out], | |
queue=True # λμ μ μ μ μμ | |
) | |
if __name__ == "__main__": | |
demo.launch(debug=True) | |