File size: 3,365 Bytes
c231a45
b34f0d5
7dcc8af
16edf41
b34f0d5
8144da3
c65ce97
 
c231a45
 
 
 
 
5d429a8
7dcc8af
f1d1009
c753d25
c231a45
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c753d25
c231a45
5d429a8
c231a45
 
 
 
 
 
 
092cc1c
 
c231a45
 
5d429a8
 
 
c231a45
 
 
092cc1c
c231a45
 
092cc1c
c753d25
 
c231a45
b34f0d5
16edf41
f1d1009
16edf41
 
 
c231a45
 
c753d25
5d429a8
c231a45
 
 
16edf41
5d429a8
16edf41
c231a45
 
 
 
 
 
 
 
 
 
 
16edf41
5d429a8
c231a45
 
5d429a8
16edf41
958e155
b34f0d5
5d429a8
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
# app.py
import gradio as gr
from huggingface_hub import InferenceClient
import os

# Cohere Command R+ 모델 ID 정의
COHERE_MODEL = "CohereForAI/c4ai-command-r-plus-08-2024"

def get_client(model_name):
    """
    모델 이름에 맞춰 InferenceClient 생성.
    토큰은 환경 변수에서 가져옴.
    """
    hf_token = os.getenv("HF_TOKEN")
    if not hf_token:
        raise ValueError("HuggingFace API 토큰이 필요합니다.")

    if model_name == "Cohere Command R+":
        model_id = COHERE_MODEL
    else:
        raise ValueError("유효하지 않은 모델 이름입니다.")
    return InferenceClient(model_id, token=hf_token)

def respond_cohere_qna(
    tone: str,
    reference1: str,
    reference2: str,
    reference3: str,
    system_message: str,
    max_tokens: int,
    temperature: float,
    top_p: float
):
    """
    Cohere Command R+ 모델을 이용해 블로그 생성 함수.
    """
    model_name = "Cohere Command R+"
    try:
        client = get_client(model_name)
    except ValueError as e:
        return f"오류: {str(e)}"

    question = f"말투: {tone} \n\n 참조글1: {reference1} \n\n 참조글2: {reference2} \n\n 참조글3: {reference3}"

    try:
        response_full = client.chat_completion(
            [
                {"role": "system", "content": system_message},
                {"role": "user", "content": question}
            ],
            max_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
        )
        assistant_message = response_full.choices[0].message.content
        return assistant_message
    except Exception as e:
        return f"오류가 발생했습니다: {str(e)}"

# Gradio UI 설정
with gr.Blocks() as demo:
    gr.Markdown("# 블로그 생성기")

    with gr.Row():
        tone = gr.Radio(
            choices=["친근하게", "일반적인", "전문적인"],
            value="일반적인",
            label="말투바꾸기"
        )

    reference1 = gr.Textbox(label="참조글 1", lines=3, placeholder="블로그 글에 포함할 주요 참조 내용 1")
    reference2 = gr.Textbox(label="참조글 2", lines=3, placeholder="블로그 글에 포함할 주요 참조 내용 2")
    reference3 = gr.Textbox(label="참조글 3", lines=3, placeholder="블로그 글에 포함할 주요 참조 내용 3")

    output = gr.Textbox(label="생성된 블로그 글", lines=10, interactive=False)

    with gr.Accordion("고급 설정 (Cohere)", open=False):
        system_message = gr.Textbox(
            value="""반드시 한글로 답변할 것.\n너는 블로그 작성을 도와주는 비서이다.\n사용자의 요구사항을 정확히 반영하여 작성하라.""",
            label="System Message",
            lines=3
        )
        max_tokens = gr.Slider(minimum=100, maximum=5000, value=2000, step=100, label="Max Tokens")
        temperature = gr.Slider(minimum=0.1, maximum=2.0, value=0.7, step=0.1, label="Temperature")
        top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-P")

    generate_button = gr.Button("생성")

    generate_button.click(
        fn=respond_cohere_qna,
        inputs=[tone, reference1, reference2, reference3, system_message, max_tokens, temperature, top_p],
        outputs=output
    )

if __name__ == "__main__":
    demo.launch()