import gradio as gr
import logging
from PIL import Image
import torch
from uroman import Uroman
from transformers import (
    BlipProcessor,
    BlipForConditionalGeneration,
    pipeline,
    AutoTokenizer,
    VitsModel,
)

logging.basicConfig(level=logging.INFO)

device = "cuda" if torch.cuda.is_available() else "cpu"

# BLIP image-captioning model (produces English captions)
processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
blip_model = BlipForConditionalGeneration.from_pretrained(
    "Salesforce/blip-image-captioning-large"
).to(device)

# English → Korean translation (NLLB-200 distilled)
translation_pipeline = pipeline(
    "translation",
    model="facebook/nllb-200-distilled-600M",
    src_lang="eng_Latn",
    tgt_lang="kor_Hang",
    max_length=200,
    device=0 if torch.cuda.is_available() else -1,
)

# MMS-TTS models for Korean and English speech synthesis
tts_ko = VitsModel.from_pretrained("facebook/mms-tts-kor").to(device)
tok_ko = AutoTokenizer.from_pretrained("facebook/mms-tts-kor")
tts_en = VitsModel.from_pretrained("facebook/mms-tts-eng").to(device)
tok_en = AutoTokenizer.from_pretrained("facebook/mms-tts-eng")

uroman = Uroman()


def tts(model, tokenizer, text: str):
    # MMS-TTS expects romanized input, so run uroman before tokenizing
    roman = uroman.romanize_string(text)
    ids = tokenizer(roman, return_tensors="pt").input_ids.long().to(model.device)
    with torch.no_grad():
        wav = model(input_ids=ids).waveform.squeeze().cpu().numpy()
    return model.config.sampling_rate, wav


def generate(img: Image.Image, lang: str):
    """
    lang == "ko" → Korean caption + speech
    lang == "en" → English caption + speech
    """
    if img is None:
        raise gr.Error("이미지를 업로드하세요 📷")

    pix = processor(images=img, return_tensors="pt").pixel_values.to(blip_model.device)
    cap_en = processor.batch_decode(
        blip_model.generate(
            pix,
            max_length=128,          # allow longer captions
            min_length=20,           # avoid stopping too early
            num_beams=5,             # beam search for quality (slightly slower)
            temperature=0.7,         # only takes effect when do_sample=True; ignored under pure beam search
            repetition_penalty=1.1,  # discourage repeated phrases
        ),
        skip_special_tokens=True,
    )[0].strip()

    if lang == "en":
        sr, wav = tts(tts_en, tok_en, cap_en)
        return cap_en, (sr, wav)

    try:
        cap_ko = translation_pipeline(cap_en)[0]["translation_text"].strip()
    except Exception as e:
        logging.error(f"[ERROR] translation failed: {e}")
        cap_ko = ""
    if not cap_ko:
        return "번역 오류가 발생했습니다.", None

    sr, wav = tts(tts_ko, tok_ko, cap_ko)
    return cap_ko, (sr, wav)


with gr.Blocks(title="픽보이스(Picvoice)", css="footer {display: none !important;}") as demo:
    gr.Markdown("## 픽보이스(Picvoice)")

    img_state = gr.State()  # keeps the most recently uploaded image
    input_img = gr.Image(type="pil", label="📷 이미지 업로드")
    caption_box = gr.Textbox(label="📑 텍스트 생성 결과")
    audio_play = gr.Audio(label="🔊 음성 재생", type="numpy")

    with gr.Row():
        ko_btn = gr.Button("한글로 생성🪄")
        en_btn = gr.Button("영어로 생성🪄")

    def store_img(img):
        return img

    input_img.change(store_img, inputs=input_img, outputs=img_state, queue=False)

    ko_btn.click(
        fn=lambda img: generate(img, "ko"),
        inputs=img_state,
        outputs=[caption_box, audio_play],
    )
    en_btn.click(
        fn=lambda img: generate(img, "en"),
        inputs=img_state,
        outputs=[caption_box, audio_play],
    )


if __name__ == "__main__":
    demo.launch()
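

# Optional: a minimal sketch of exercising the caption → translation → TTS chain
# programmatically, without the Gradio UI (e.g. as a quick smoke test). The image
# path "sample.jpg" and the helper name _smoke_test are assumptions for
# illustration, not part of the app; call it manually if you want to try it.
def _smoke_test(img_path: str = "sample.jpg", lang: str = "ko"):
    caption, audio = generate(Image.open(img_path), lang)
    if audio is not None:
        sr, wav = audio
        logging.info(f"caption={caption!r}, sampling_rate={sr}, samples={wav.shape}")
    return caption, audio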