File size: 6,327 Bytes
fa23262
 
273c30f
 
fa23262
273c30f
 
 
5bc2b7b
65a4f6a
 
273c30f
5bc2b7b
 
 
 
65a4f6a
 
 
 
273c30f
 
73dae5e
273c30f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
import gradio as gr

def update(name):
    """Build the greeting shown in the output box for *name* (Gradio callback)."""
    greeting = "Welcome to Gradio, {}!".format(name)
    return greeting

# Demo UI: a topic box plus two identical profile columns — one per dialogue
# participant (P01 = user, P02 = chatbot).  The original code reused the names
# addr/age/sex for both columns, so the first column's widgets were silently
# overwritten and unreachable; each column now gets its own variables.
demo = gr.Blocks()

with demo:
    gr.Markdown("각 질문에 대답 후 Enter 해주세요.\n\n")
    with gr.Row():
        topic = gr.Textbox(label="Topic", placeholder="대화 주제를 정해주세요 (e.g. 여가 생활, 일과 직업, 개인 및 관계, etc...)")
    with gr.Row():
        # Profile of speaker P01 (the user).
        # NOTE(review): the 지역 placeholder lists topic examples, not regions —
        # looks copy-pasted from the topic box; confirm intended text.
        with gr.Column():
            p01_addr = gr.Textbox(label="지역", placeholder="e.g. 여가 생활, 일과 직업, 개인 및 관계, etc...")
            p01_age = gr.Textbox(label="나이", placeholder="e.g. 20대 미만, 40대, 70대 이상, etc...")
            p01_sex = gr.Textbox(label="성별", placeholder="e.g. 남성, 여성, etc...")
        # Profile of speaker P02 (the chatbot).
        with gr.Column():
            p02_addr = gr.Textbox(label="지역", placeholder="e.g. 여가 생활, 일과 직업, 개인 및 관계, etc...")
            p02_age = gr.Textbox(label="나이", placeholder="e.g. 20대 미만, 40대, 70대 이상, etc...")
            p02_sex = gr.Textbox(label="성별", placeholder="e.g. 남성, 여성, etc...")
        out = gr.Textbox()
    btn = gr.Button("Run")
    # The original left this commented out and referenced an undefined `inp`,
    # leaving the Run button dead; wire it to the existing callback.
    btn.click(fn=update, inputs=topic, outputs=out)

demo.launch()


def main(model_name):
    warnings.filterwarnings("ignore")

    tokenizer = AutoTokenizer.from_pretrained('kakaobrain/kogpt', revision='KoGPT6B-ryan1.5b')
    special_tokens_dict = {'additional_special_tokens': ['<sep>', '<eos>', '<sos>', '#@์ด๋ฆ„#', '#@๊ณ„์ •#', '#@์‹ ์›#', '#@์ „๋ฒˆ#', '#@๊ธˆ์œต#', '#@๋ฒˆํ˜ธ#', '#@์ฃผ์†Œ#', '#@์†Œ์†#', '#@๊ธฐํƒ€#']}
    num_added_toks = tokenizer.add_special_tokens(special_tokens_dict)

    model = AutoModelForCausalLM.from_pretrained(model_name)
    model.resize_token_embeddings(len(tokenizer))
    model = model.cuda()

    info = ""
    
    while True:
        if info == "":
            print(
                f"์ง€๊ธˆ๋ถ€ํ„ฐ ๋Œ€ํ™” ์ •๋ณด๋ฅผ ์ž…๋ ฅ ๋ฐ›๊ฒ ์Šต๋‹ˆ๋‹ค.\n"
                f"๊ฐ ์งˆ๋ฌธ์— ๋Œ€๋‹ต ํ›„ Enter ํ•ด์ฃผ์„ธ์š”.\n"
                f"์•„๋ฌด ์ž…๋ ฅ ์—†์ด Enter ํ•  ๊ฒฝ์šฐ, ๋ฏธ๋ฆฌ ์ง€์ •๋œ ๊ฐ’ ์ค‘ ๋žœ๋ค์œผ๋กœ ์ •ํ•˜๊ฒŒ ๋ฉ๋‹ˆ๋‹ค.\n"
            )

            time.sleep(1)

            yon = "no"
        else:
            yon = input(
                f"์ด์ „ ๋Œ€ํ™” ์ •๋ณด๋ฅผ ๊ทธ๋Œ€๋กœ ์œ ์ง€ํ• ๊นŒ์š”? (yes : ์œ ์ง€, no : ์ƒˆ๋กœ ์ž‘์„ฑ) :"
            )
        
        if yon == "no":
            info = "์ผ์ƒ ๋Œ€ํ™” "

            topic = input("๋Œ€ํ™” ์ฃผ์ œ๋ฅผ ์ •ํ•ด์ฃผ์„ธ์š” (e.g. ์—ฌ๊ฐ€ ์ƒํ™œ, ์ผ๊ณผ ์ง์—…, ๊ฐœ์ธ ๋ฐ ๊ด€๊ณ„, etc...) :")
            if topic == "":
                topic = random.choice(['์—ฌ๊ฐ€ ์ƒํ™œ', '์‹œ์‚ฌ/๊ต์œก', '๋ฏธ์šฉ๊ณผ ๊ฑด๊ฐ•', '์‹์Œ๋ฃŒ', '์ƒ๊ฑฐ๋ž˜(์‡ผํ•‘)', '์ผ๊ณผ ์ง์—…', '์ฃผ๊ฑฐ์™€ ์ƒํ™œ', '๊ฐœ์ธ ๋ฐ ๊ด€๊ณ„', 'ํ–‰์‚ฌ'])
                print(topic)
            info += topic + "<sep>"

            def ask_info(who, ment):
                print(ment)
                text = who + ":"
                addr = input("์–ด๋”” ์‚ฌ์„ธ์š”? (e.g. ์„œ์šธํŠน๋ณ„์‹œ, ์ œ์ฃผ๋„, etc...) :").strip()
                if addr == "":
                    addr = random.choice(['์„œ์šธํŠน๋ณ„์‹œ', '๊ฒฝ๊ธฐ๋„', '๋ถ€์‚ฐ๊ด‘์—ญ์‹œ', '๋Œ€์ „๊ด‘์—ญ์‹œ', '๊ด‘์ฃผ๊ด‘์—ญ์‹œ', '์šธ์‚ฐ๊ด‘์—ญ์‹œ', '๊ฒฝ์ƒ๋‚จ๋„', '์ธ์ฒœ๊ด‘์—ญ์‹œ', '์ถฉ์ฒญ๋ถ๋„', '์ œ์ฃผ๋„', '๊ฐ•์›๋„', '์ถฉ์ฒญ๋‚จ๋„', '์ „๋ผ๋ถ๋„', '๋Œ€๊ตฌ๊ด‘์—ญ์‹œ', '์ „๋ผ๋‚จ๋„', '๊ฒฝ์ƒ๋ถ๋„', '์„ธ์ข…ํŠน๋ณ„์ž์น˜์‹œ', '๊ธฐํƒ€'])
                    print(addr)
                text += addr + " "

                age = input("๋‚˜์ด๊ฐ€? (e.g. 20๋Œ€, 70๋Œ€ ์ด์ƒ, etc...) :").strip()
                if age == "":
                    age = random.choice(['20๋Œ€', '30๋Œ€', '50๋Œ€', '20๋Œ€ ๋ฏธ๋งŒ', '60๋Œ€', '40๋Œ€', '70๋Œ€ ์ด์ƒ'])
                    print(age)
                text += age + " "

                sex = input("์„ฑ๋ณ„์ด? (e.g. ๋‚จ์„ฑ, ์—ฌ์„ฑ, etc... (?)) :").strip()
                if sex == "":
                    sex = random.choice(['๋‚จ์„ฑ', '์—ฌ์„ฑ'])
                    print(sex)
                text += sex + "<sep>"
                return text

            info += ask_info(who="P01", ment=f"\n๋‹น์‹ ์— ๋Œ€ํ•ด ์•Œ๋ ค์ฃผ์„ธ์š”.\n")
            info += ask_info(who="P02", ment=f"\n์ฑ—๋ด‡์— ๋Œ€ํ•ด ์•Œ๋ ค์ฃผ์„ธ์š”.\n")

        pp = info.replace('<sep>', '\n')
        print(
            f"\n----------------\n"
            f"<์ž…๋ ฅ ์ •๋ณด ํ™•์ธ> (P01 : ๋‹น์‹ , P02 : ์ฑ—๋ด‡)\n"
            f"{pp}"
            f"----------------\n"
            f"๋Œ€ํ™”๋ฅผ ์ข…๋ฃŒํ•˜๊ณ  ์‹ถ์œผ๋ฉด ์–ธ์ œ๋“ ์ง€ 'end' ๋ผ๊ณ  ๋งํ•ด์ฃผ์„ธ์š”~\n"
        )
        talk = []
        switch = True
        switch2 = True
        while True:
            inp = "P01<sos>"
            myinp = input("๋‹น์‹  : ")
            if myinp == "end":
                print("๋Œ€ํ™” ์ข…๋ฃŒ!")
                break
            inp += myinp + "<eos>"
            talk.append(inp)
            talk.append("P02<sos>")

            while True:
                now_inp = info + "".join(talk)
                inpu = tokenizer(now_inp, max_length=1024, truncation='longest_first', return_tensors='pt')
                seq_len = inpu.input_ids.size(1)
                if seq_len > 512 * 0.8 and switch:
                    print(
                        f"<์ฃผ์˜> ํ˜„์žฌ ๋Œ€ํ™” ๊ธธ์ด๊ฐ€ ๊ณง ์ตœ๋Œ€ ๊ธธ์ด์— ๋„๋‹ฌํ•ฉ๋‹ˆ๋‹ค. ({seq_len} / 512)"
                    )
                    switch = False

                if seq_len >= 512 and switch2:
                    print("<์ฃผ์˜> ๋Œ€ํ™” ๊ธธ์ด๊ฐ€ ๋„ˆ๋ฌด ๊ธธ์–ด์กŒ๊ธฐ ๋•Œ๋ฌธ์—, ์ดํ›„ ๋Œ€ํ™”๋Š” ๋งจ ์•ž์˜ ๋ฐœํ™”๋ฅผ ์กฐ๊ธˆ์”ฉ ์ง€์šฐ๋ฉด์„œ ์ง„ํ–‰๋ฉ๋‹ˆ๋‹ค.")
                    talk = talk[1:]
                    switch2 = False
                else:
                    break
            
            out = model.generate(
                inputs=inpu.input_ids.cuda(), 
                attention_mask=inpu.attention_mask.cuda(),
                max_length=512, 
                do_sample=True,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.encode('<eos>')[0]
                )
            output = tokenizer.batch_decode(out)
            print("์ฑ—๋ด‡ : " + output[0][len(now_inp):-5])
            talk[-1] += output[0][len(now_inp):]

        again = input(f"๋‹ค๋ฅธ ๋Œ€ํ™”๋ฅผ ์‹œ์ž‘ํ• ๊นŒ์š”? (yes : ์ƒˆ๋กœ์šด ์‹œ์ž‘, no : ์ข…๋ฃŒ) :")
        if again == "no":
            break