Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import logging | |
| from PIL import Image | |
| import torch | |
| from uroman import Uroman | |
| from transformers import ( | |
| BlipProcessor, | |
| BlipForConditionalGeneration, | |
| pipeline, | |
| AutoTokenizer, | |
| VitsModel, | |
| ) | |
logging.basicConfig(level=logging.INFO)

# Decide the device once so every model below is placed consistently.
USE_CUDA = torch.cuda.is_available()
DEVICE = "cuda" if USE_CUDA else "cpu"

# Hugging Face Hub checkpoints used by the app.
BLIP_CHECKPOINT = "Salesforce/blip-image-captioning-large"
NLLB_CHECKPOINT = "facebook/nllb-200-distilled-600M"
TTS_KO_CHECKPOINT = "facebook/mms-tts-kor"
TTS_EN_CHECKPOINT = "facebook/mms-tts-eng"

# BLIP: image -> English caption.
processor = BlipProcessor.from_pretrained(BLIP_CHECKPOINT)
blip_model = BlipForConditionalGeneration.from_pretrained(BLIP_CHECKPOINT).to(DEVICE)

# NLLB: English (eng_Latn) -> Korean (kor_Hang). ``pipeline`` takes a device
# ordinal here: 0 is the first CUDA device, -1 is the CPU.
translation_pipeline = pipeline(
    "translation",
    model=NLLB_CHECKPOINT,
    src_lang="eng_Latn",
    tgt_lang="kor_Hang",
    max_length=200,
    device=0 if USE_CUDA else -1,
)

# MMS/VITS text-to-speech models, each with its matching tokenizer.
tts_ko = VitsModel.from_pretrained(TTS_KO_CHECKPOINT).to(DEVICE)
tok_ko = AutoTokenizer.from_pretrained(TTS_KO_CHECKPOINT)
tts_en = VitsModel.from_pretrained(TTS_EN_CHECKPOINT).to(DEVICE)
tok_en = AutoTokenizer.from_pretrained(TTS_EN_CHECKPOINT)

# Romanizer applied to text before TTS tokenization (see ``tts``).
uroman = Uroman()
def tts(model, tokenizer, text: str):
    """Synthesize *text* as speech with an MMS/VITS text-to-speech model.

    Args:
        model: A ``VitsModel``-like object. It is called with ``input_ids=``
            and must return an object with a ``waveform`` tensor. It must
            also expose ``device`` and ``config.sampling_rate``.
        tokenizer: The tokenizer paired with *model*.
        text: The text to speak.

    Returns:
        ``(sampling_rate, waveform)``, where *waveform* is the squeezed model
        output as a CPU NumPy array. This is the tuple form that
        ``gr.Audio(type="numpy")`` accepts.
    """
    # MMS tokenizers advertise through ``is_uroman`` whether they expect
    # romanized (Latin-script) input. Only romanize when asked to. When the
    # attribute is absent, keep the previous behaviour of always romanizing.
    if getattr(tokenizer, "is_uroman", True):
        text = uroman.romanize_string(text)

    input_ids = tokenizer(text, return_tensors="pt").input_ids.long().to(model.device)
    with torch.no_grad():
        waveform = model(input_ids=input_ids).waveform
    return model.config.sampling_rate, waveform.squeeze().cpu().numpy()
def _caption_english(img: Image.Image) -> str:
    """Return the whitespace-stripped English BLIP caption for *img*."""
    pixel_values = processor(images=img, return_tensors="pt").pixel_values.to(
        blip_model.device
    )
    output_ids = blip_model.generate(pixel_values, max_length=64)
    return processor.batch_decode(output_ids, skip_special_tokens=True)[0].strip()


def generate(img: Image.Image, lang: str):
    """Caption *img* and speak the caption.

    Args:
        img: The uploaded image, or ``None`` when nothing has been uploaded.
        lang: ``"en"`` for an English caption and English speech. ``"ko"``
            for a Korean caption (the BLIP caption translated with NLLB) and
            Korean speech.

    Returns:
        ``(caption, (sampling_rate, waveform))`` on success. If translation
        raises or produces empty text, returns an error-message string and
        ``None`` for the audio instead.

    Raises:
        ValueError: If *lang* is neither ``"ko"`` nor ``"en"``.
        gr.Error: If *img* is ``None``.
    """
    # Previously any unexpected value silently took the Korean path.
    if lang not in ("ko", "en"):
        raise ValueError(f"unsupported lang {lang!r}; expected 'ko' or 'en'")
    if img is None:
        raise gr.Error("μ΄λ―Έμ§λ₯Ό μ λ‘λνμΈμ π·")

    cap_en = _caption_english(img)
    if lang == "en":
        return cap_en, tts(tts_en, tok_en, cap_en)

    try:
        cap_ko = translation_pipeline(cap_en)[0]["translation_text"].strip()
    except Exception as e:
        # Best effort: record the failure with its traceback, then fall
        # through to the user-facing error message below.
        logging.exception("[ERROR] λ²μ μ€ν¨: %s", e)
        cap_ko = ""
    if not cap_ko:
        return "λ²μ μ€λ₯κ° λ°μνμ΅λλ€.", None
    return cap_ko, tts(tts_ko, tok_ko, cap_ko)
# Gradio UI. Inside a ``gr.Blocks`` context, components are laid out in the
# order they are created, so the statement order below defines the page.
with gr.Blocks(title="Image β Caption & TTS", css="footer {display: none !important;}") as demo:
    gr.Markdown(
        "## μ΄λ―Έμ§ β νκΈ / English μΊ‘μ & μμ± λ³ν\n"
        "BLIP (caption) β NLLB (translate) β VITS (TTS)"
    )
    img_state = gr.State()  # holds the most recently uploaded image
    input_img = gr.Image(type="pil", label="π· μ΄λ―Έμ§ μ λ‘λ")
    caption_box = gr.Textbox(label="π μΊ‘μ κ²°κ³Ό")
    audio_play = gr.Audio(label="π μμ± μ¬μ", type="numpy")
    with gr.Row():
        ko_btn = gr.Button("νκΈλ‘ μμ±πͺ")
        en_btn = gr.Button("μμ΄λ‘ μμ±πͺ")

    def store_img(img):
        """Return the uploaded image unchanged so it is written into ``img_state``."""
        return img

    # Copy every change of the image input into ``img_state``. queue=False
    # lets this trivial update bypass Gradio's request queue.
    input_img.change(store_img, inputs=input_img, outputs=img_state, queue=False)
    # Both buttons caption the stored image; only the language argument
    # passed to ``generate`` differs. Results go to the caption box and player.
    ko_btn.click(fn=lambda img: generate(img, "ko"), inputs=img_state, outputs=[caption_box, audio_play])
    en_btn.click(fn=lambda img: generate(img, "en"), inputs=img_state, outputs=[caption_box, audio_play])

# Start the web server only when run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()