import gradio as gr
import logging
from PIL import Image
import torch
from uroman import Uroman
from transformers import (
    BlipProcessor,
    BlipForConditionalGeneration,
    pipeline,
    AutoTokenizer,
    VitsModel,
)
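
# Runtime dependencies (assumed): gradio, torch, transformers, pillow, and the
# `uroman` package from PyPI (a Python port of the uroman romanizer).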

logging.basicConfig(level=logging.INFO)
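
# Flow: BLIP captions the image in English → NLLB translates the caption to
# Korean → MMS-TTS (VITS) synthesizes speech in the selected language.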

# ───────── 1. Load models ─────────
processor = BlipProcessor.from_pretrained(
    "Salesforce/blip-image-captioning-large"
)
blip_model = BlipForConditionalGeneration.from_pretrained(
    "Salesforce/blip-image-captioning-large"
).to("cuda" if torch.cuda.is_available() else "cpu")

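# NLLB-200 (distilled, 600M) translates the English caption into Korean.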
translation_pipeline = pipeline(
    "translation",
    model="facebook/nllb-200-distilled-600M",
    src_lang="eng_Latn",
    tgt_lang="kor_Hang",
    max_length=200,
    device=0 if torch.cuda.is_available() else -1,
)

# --- TTS (ko / en) ---
tts_ko = VitsModel.from_pretrained("facebook/mms-tts-kor").to(
    "cuda" if torch.cuda.is_available() else "cpu"
)
tok_ko = AutoTokenizer.from_pretrained("facebook/mms-tts-kor")

tts_en = VitsModel.from_pretrained("facebook/mms-tts-eng").to(
    "cuda" if torch.cuda.is_available() else "cpu"
)
tok_en = AutoTokenizer.from_pretrained("facebook/mms-tts-eng")

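# uroman romanizes text into Latin script; the MMS-TTS tokenizers for
# non-Latin-script languages such as Korean expect uroman-ized input.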
uroman = Uroman()


# ───────── 2. Shared helpers ─────────
def tts(model, tokenizer, text: str):
    roman = uroman.romanize_string(text)
    ids = tokenizer(roman, return_tensors="pt").input_ids.long().to(model.device)
    with torch.no_grad():
        wav = model(input_ids=ids).waveform.squeeze().cpu().numpy()
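    # (sampling_rate, waveform) matches the tuple format that gr.Audio
    # with type="numpy" accepts.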
    return model.config.sampling_rate, wav


def generate(img: Image.Image, lang: str):
    """
    lang == "ko" β†’ ν•œκ΅­μ–΄ μΊ‘μ…˜+μŒμ„±
    lang == "en" β†’ μ˜μ–΄  μΊ‘μ…˜+μŒμ„±
    """
    if img is None:
        raise gr.Error("Please upload an image first 📷")

    # ① English caption
    pix = processor(images=img, return_tensors="pt").pixel_values.to(blip_model.device)
    cap_en = processor.batch_decode(
        blip_model.generate(pix, max_length=64), skip_special_tokens=True
    )[0].strip()

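    # For English output, skip translation and synthesize the caption directly.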
    if lang == "en":
        sr, wav = tts(tts_en, tok_en, cap_en)
        return cap_en, (sr, wav)

    # ② Translate (en → ko)
    try:
        cap_ko = translation_pipeline(cap_en)[0]["translation_text"].strip()
    except Exception as e:
        logging.error(f"[ERROR] Translation failed: {e}")
        cap_ko = ""
    if not cap_ko:
        return "λ²ˆμ—­ 였λ₯˜κ°€ λ°œμƒν–ˆμŠ΅λ‹ˆλ‹€.", None

    sr, wav = tts(tts_ko, tok_ko, cap_ko)
    return cap_ko, (sr, wav)


# ───────── 3. Gradio UI ─────────
with gr.Blocks(title="Image β†’ Caption & TTS", css="footer{display:none;}") as demo:
    gr.Markdown(
        "## 이미지 β†’ ν•œκΈ€ / English μΊ‘μ…˜ & μŒμ„± λ³€ν™˜\n"
        "BLIP (caption) β†’ NLLB (translate) β†’ VITS (TTS)"
    )

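    # Keep the uploaded image in session state so both buttons can reuse it.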
    img_state = gr.State()  # most recently uploaded image

    input_img = gr.Image(type="pil", label="📷 Upload image")
    caption_box = gr.Textbox(label="📑 Caption result")
    audio_play = gr.Audio(label="🔊 Audio playback", type="numpy")

    with gr.Row():
        ko_btn = gr.Button("Generate Korean")
        en_btn = gr.Button("English")

    # Update the stored image whenever a new one is uploaded
    def store_img(img):
        return img

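    # queue=False runs the state update immediately instead of through the queue.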
    input_img.change(store_img, inputs=input_img, outputs=img_state, queue=False)

    # Wire each button to generate() with the matching language code
    ko_btn.click(fn=lambda img: generate(img, "ko"), inputs=img_state, outputs=[caption_box, audio_play])
    en_btn.click(fn=lambda img: generate(img, "en"), inputs=img_state, outputs=[caption_box, audio_play])

if __name__ == "__main__":
    demo.launch()