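"""Picvoice: a Gradio demo that captions an uploaded image with BLIP,
translates the English caption to Korean with NLLB-200 when requested,
and reads the caption aloud using MMS-TTS (VITS) voices."""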
import gradio as gr
import logging
from PIL import Image
import torch
from uroman import Uroman
from transformers import (
    BlipProcessor,
    BlipForConditionalGeneration,
    pipeline,
    AutoTokenizer,
    VitsModel,
)

logging.basicConfig(level=logging.INFO)


DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# BLIP image-captioning model (generates English captions)
processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
blip_model = BlipForConditionalGeneration.from_pretrained(
    "Salesforce/blip-image-captioning-large"
).to(DEVICE)

translation_pipeline = pipeline(
    "translation",
    model="facebook/nllb-200-distilled-600M",
    src_lang="eng_Latn",
    tgt_lang="kor_Hang",
    max_length=200,
    device=0 if torch.cuda.is_available() else -1,
)
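# src_lang/tgt_lang use NLLB's FLORES-200 codes; the pipeline is invoked as
# translation_pipeline(text)[0]["translation_text"] in generate() below.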


# MMS-TTS (VITS) voices for Korean and English
tts_ko = VitsModel.from_pretrained("facebook/mms-tts-kor").to(DEVICE)
tok_ko = AutoTokenizer.from_pretrained("facebook/mms-tts-kor")

tts_en = VitsModel.from_pretrained("facebook/mms-tts-eng").to(DEVICE)
tok_en = AutoTokenizer.from_pretrained("facebook/mms-tts-eng")

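# MMS-TTS tokenizers expect Latin-script input, so non-Latin text (here,
# Korean) is romanized with uroman before synthesis (see tts() below).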
uroman = Uroman()


def tts(model, tokenizer, text: str):
    """Synthesize speech with an MMS-TTS (VITS) model.

    Returns (sampling_rate, waveform) as expected by gr.Audio(type="numpy").
    """
    roman = uroman.romanize_string(text)
    ids = tokenizer(roman, return_tensors="pt").input_ids.long().to(model.device)
    with torch.no_grad():
        wav = model(input_ids=ids).waveform.squeeze().cpu().numpy()
    return model.config.sampling_rate, wav
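# Hypothetical usage: sr, wav = tts(tts_en, tok_en, "a dog running on the beach")
# sr is the model's sampling rate (16 kHz for MMS-TTS); wav is a 1-D numpy array.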


def generate(img: Image.Image, lang: str):
    """Generate a caption and matching speech for the uploaded image.

    lang == "ko" -> Korean caption + speech
    lang == "en" -> English caption + speech
    """
    if img is None:
        raise gr.Error("Please upload an image 📷")

    # BLIP: image -> English caption
    pix = processor(images=img, return_tensors="pt").pixel_values.to(blip_model.device)
    cap_en = processor.batch_decode(
        blip_model.generate(
            pix,
            max_length=128,          # allow longer captions
            min_length=20,           # avoid overly short captions
            num_beams=5,             # beam search improves quality (slightly slower)
            repetition_penalty=1.1,  # discourage repeated phrases
        ),
        skip_special_tokens=True,
    )[0].strip()

    if lang == "en":
        sr, wav = tts(tts_en, tok_en, cap_en)
        return cap_en, (sr, wav)

    # NLLB: English caption -> Korean, then Korean TTS
    try:
        cap_ko = translation_pipeline(cap_en)[0]["translation_text"].strip()
    except Exception as e:
        logging.error(f"[ERROR] Translation failed: {e}")
        cap_ko = ""
    if not cap_ko:
        return "A translation error occurred.", None

    sr, wav = tts(tts_ko, tok_ko, cap_ko)
    return cap_ko, (sr, wav)



with gr.Blocks(title="픽보이스(Picvoice)", css="footer {display: none !important;}") as demo:
    gr.Markdown("## 픽보이스(Picvoice)")

    img_state = gr.State()  # holds the most recently uploaded image

    input_img = gr.Image(type="pil", label="📷 Upload an image")
    caption_box = gr.Textbox(label="📑 Generated text")
    audio_play = gr.Audio(label="🔊 Audio playback", type="numpy")

    with gr.Row():
        ko_btn = gr.Button("Generate in Korean 🪄")
        en_btn = gr.Button("Generate in English 🪄")

    def store_img(img):
        return img

    # Mirror each upload into session state; queue=False skips the request queue
    # for this lightweight event.
    input_img.change(store_img, inputs=input_img, outputs=img_state, queue=False)

    ko_btn.click(
        fn=lambda img: generate(img, "ko"),
        inputs=img_state,
        outputs=[caption_box, audio_play],
    )
    en_btn.click(
        fn=lambda img: generate(img, "en"),
        inputs=img_state,
        outputs=[caption_box, audio_play],
    )

if __name__ == "__main__":
    demo.launch()
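    # Optional: demo.launch(share=True) exposes a public link;
    # demo.launch(server_port=7860) pins the local port.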