File size: 2,184 Bytes
f9ce6d5
167e985
f9ce6d5
 
f021c8c
167e985
 
f021c8c
f9ce6d5
 
 
f021c8c
167e985
f9ce6d5
 
 
f021c8c
 
f9ce6d5
 
167e985
b24bc5f
167e985
 
f9ce6d5
f021c8c
167e985
f9ce6d5
 
 
 
 
 
f021c8c
d7464be
167e985
 
 
 
 
 
 
 
 
 
 
 
 
 
f9ce6d5
f021c8c
 
f9ce6d5
167e985
f9ce6d5
167e985
f9ce6d5
 
167e985
f9ce6d5
167e985
 
 
f9ce6d5
 
 
167e985
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import gradio as gr
from transformers import pipeline
from huggingface_hub import InferenceClient

# ๊ฐ์ • ๋ถ„์„ ๋ชจ๋ธ ๋กœ๋“œ
sentiment_pipeline = pipeline("sentiment-analysis", model="beomi/KcELECTRA-base")

# ์ƒ์„ฑ ๋ชจ๋ธ (Zephyr)
client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")


# ๊ฐ์ • ๋ถ„์„ + ์žฌ์ž‘์„ฑ ํ•จ์ˆ˜
def rewrite_if_negative(
    message,
    history: list[tuple[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
):
    #๊ฐ์ • ๋ถ„์„
    result = sentiment_pipeline(message)[0]
    label = result['label']
    score = result['score']

    #๋ฉ”์‹œ์ง€ ์ดˆ๊ธฐํ™”
    messages = [{"role": "system", "content": system_message}]
    for val in history:
        if val[0]:
            messages.append({"role": "user", "content": val[0]})
        if val[1]:
            messages.append({"role": "assistant", "content": val[1]})

    #๋ฌธ์žฅ ์žฌ์ž‘์„ฑ ์—ฌ๋ถ€ ํŒ๋‹จ
    if True: #label == "LABEL_1": #and score > 0.8:
        messages.append({"role": "user", "content": f"๋‹ค์Œ ๋ฌธ์žฅ์„ ๊ณต๊ฐ ๊ฐ€๋Š” ๋ง๋กœ ๋ฐ”๊ฟ”์ค˜: {message}"})
        response = ""
        for chunk in client.chat_completion(
            messages,
            max_tokens=max_tokens,
            stream=True,
            temperature=temperature,
            top_p=top_p,
        ):
            token = chunk.choices[0].delta.content
            response += token
            yield response
    else:
        yield "ํ‘œํ˜„์ด ๊ดœ์ฐฎ."


# Gradio chat UI. The additional inputs are listed in the same order as
# rewrite_if_negative's parameters after (message, history).
demo = gr.ChatInterface(
    rewrite_if_negative,
    title="๋ฌธ์žฅ ์–ด์‹œ์Šคํ„ด์Šค",
    description="๋ฌธ์žฅ์„ ์ž…๋ ฅํ•˜๋ฉด ๊ฐ์ •์„ ๋ถ„์„ํ•˜๊ณ , ๋„ˆ๋ฌด ๋ถ€์ •์ ์ธ ๋งํˆฌ๋Š” ๊ณต๊ฐ ๊ฐ€๋Š” ํ‘œํ˜„์œผ๋กœ ๋ฐ”๊ฟ”์คŒ",
    theme="soft",
    additional_inputs=[
        # system_message
        gr.Textbox(label="System message", value="๋„ˆ๋Š” ๋ถ€๋“œ๋Ÿฌ์šด ๋งํˆฌ๋กœ ๋งํ•˜๋Š” AI์•ผ."),
        # max_tokens
        gr.Slider(label="Max new tokens", minimum=1, maximum=2048, step=1, value=512),
        # temperature
        gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=0.7),
        # top_p
        gr.Slider(label="Top-p (nucleus sampling)", minimum=0.1, maximum=1.0, step=0.05, value=0.95),
    ],
)

# Start the web server only when run as a script, not on import.
if __name__ == "__main__":
    demo.launch()