import logging

import gradio as gr
import torch
from PIL import Image
from transformers import (
    AutoTokenizer,
    BlipForConditionalGeneration,
    BlipProcessor,
    VitsModel,
    pipeline,
)

# ─────────────── Logging setup ───────────────
logging.basicConfig(level=logging.INFO)

# ─────────────── 1. BLIP image captioning (generates English) ───────────────
processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large")

# ─────────────── 2. English → Korean translation: NLLB pipeline ───────────────
translation_pipeline = pipeline(
    "translation",
    model="facebook/nllb-200-distilled-600M",
    src_lang="eng_Latn",
    tgt_lang="kor_Hang",
    max_length=200,
)
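
# Shape of the pipeline's return value, for reference (the example input is
# illustrative, not from a real run):
#   translation_pipeline("a dog running on the beach")
#   # -> [{"translation_text": "..."}] (a list of dicts keyed by "translation_text")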

# ─────────────── 3. Korean TTS: direct VITS loading ───────────────
tts_model = VitsModel.from_pretrained("facebook/mms-tts-kor")
tts_tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-kor")
tts_model.to("cuda" if torch.cuda.is_available() else "cpu")
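
# Note: only the VITS model is moved to GPU here; the BLIP model and the NLLB
# pipeline stay on their default (CPU) device unless moved explicitly.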


def synthesize_tts(text: str):
    inputs = tts_tokenizer(text, return_tensors="pt")
    # ⚠ fix: input_ids must stay a LongTensor; only the device changes here
    input_ids = inputs["input_ids"].to(tts_model.device)
    with torch.no_grad():
        output = tts_model(input_ids=input_ids)
    waveform = output.waveform.squeeze().cpu().numpy()
    return tts_model.config.sampling_rate, waveform
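
# Illustrative standalone use (the input string is hypothetical):
#   sr, wav = synthesize_tts("μ•ˆλ…•ν•˜μ„Έμš”")
#   # sr is tts_model.config.sampling_rate and wav a 1-D float numpy array,
#   # i.e. the (sample_rate, data) tuple format gr.Audio(type="numpy") accepts.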

# ─────────────── 4. Image → caption + translation + speech output ───────────────
def describe_and_speak(img: Image.Image):
    logging.info("[DEBUG] describe_and_speak called")

    # β‘  Generate the English caption
    pixel_values = processor(images=img, return_tensors="pt").pixel_values
    generated_ids = blip_model.generate(pixel_values, max_length=64)
    caption_en = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
    logging.info(f"[DEBUG] caption_en: {caption_en}")
    print(f"[DEBUG] caption_en: {caption_en}")

    # β‘‘ Translate to Korean
    try:
        result = translation_pipeline(caption_en)
        caption_ko = result[0]["translation_text"].strip()
    except Exception as e:
        logging.error(f"[ERROR] translation failed: {e}")
        caption_ko = ""
    logging.info(f"[DEBUG] caption_ko: {caption_ko}")
    print(f"[DEBUG] caption_ko: {caption_ko}")

    if not caption_ko:
        return "Could not generate a description for this image.", None

    # β‘’ Synthesize speech
    try:
        sr, wav = synthesize_tts(caption_ko)
        return caption_ko, (sr, wav)
    except Exception as e:
        logging.error(f"[ERROR] TTS failed: {e}")
        return caption_ko, None


# ─────────────── 5. Gradio interface ───────────────
demo = gr.Interface(
    fn=describe_and_speak,
    inputs=gr.Image(type="pil", sources=["upload", "camera"], label="Input image"),
    outputs=[
        gr.Textbox(label="Korean caption"),
        gr.Audio(label="Audio playback", type="numpy"),
    ],
    title="Image → Korean Caption & Speech",
    description="Generate an English caption with BLIP → translate to Korean with NLLB → synthesize speech with VITS",
)

if __name__ == "__main__":
    demo.launch(debug=True)