File size: 3,107 Bytes
b34f0d5
7dcc8af
398cd15
f706b9e
 
 
 
 
 
8144da3
c65ce97
 
398cd15
f706b9e
e1ec564
398cd15
f706b9e
398cd15
f706b9e
398cd15
e1ec564
c753d25
 
 
 
 
 
398cd15
092cc1c
c753d25
 
 
092cc1c
398cd15
092cc1c
c753d25
092cc1c
c753d25
 
 
 
092cc1c
 
 
 
 
 
 
 
 
c753d25
092cc1c
c753d25
 
f9b088b
e1ec564
f9b088b
 
b34f0d5
e1ec564
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
398cd15
e1ec564
 
 
 
 
 
 
 
398cd15
660d467
 
e1ec564
 
 
 
 
958e155
f9b088b
 
 
b34f0d5
e1ec564
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import gradio as gr
from huggingface_hub import InferenceClient
import os
from typing import Optional

#############################
# [기본코드] - 수정/삭제 불가
#############################

# Cohere Command R+ 모델 ID 정의
COHERE_MODEL = "CohereForAI/c4ai-command-r-plus-08-2024"

def get_client():
    """
    Cohere Command R+ 모델을 위한 InferenceClient 생성.
    환경 변수에서 HuggingFace API 토큰을 가져옴.
    """
    hf_token = os.getenv("HUGGINGFACE_TOKEN")
    if not hf_token:
        raise ValueError("HuggingFace API 토큰이 환경 변수에 설정되지 않았습니다.")
    return InferenceClient(COHERE_MODEL, token=hf_token)

def respond_cohere_qna(
    question: str,
    system_message: str,
    max_tokens: int,
    temperature: float,
    top_p: float
):
    """
    Cohere Command R+ 모델을 이용해 한 번의 질문(question)에 대한 답변을 반환하는 함수.
    """
    try:
        client = get_client()
    except ValueError as e:
        return f"오류: {str(e)}"

    messages = [
        {"role": "system", "content": system_message},
        {"role": "user", "content": question}
    ]

    try:
        response_full = client.chat_completion(
            messages,
            max_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
        )
        assistant_message = response_full.choices[0].message.content
        return assistant_message
    except Exception as e:
        return f"오류가 발생했습니다: {str(e)}"

#############################
# [UI 부분] - 수정된 부분
#############################

with gr.Blocks() as demo:
    gr.Markdown("# 블로그 생성기")

    # 말투 선택 라디오 버튼
    tone_radio = gr.Radio(
        choices=["친근한", "전문적인", "일반", "상품후기"],
        label="말투바꾸기",
        value="일반"
    )

    # 참조글 입력
    reference1 = gr.Textbox(label="참조글1", lines=2)
    reference2 = gr.Textbox(label="참조글2", lines=2)
    reference3 = gr.Textbox(label="참조글3", lines=2)

    # 생성된 블로그 글 출력
    generated_blog = gr.Textbox(label="생성된 블로그 글", lines=10, interactive=False)

    # 전송 버튼
    submit_button = gr.Button("생성")

    def generate_blog(tone, ref1, ref2, ref3):
        # 참조글을 합쳐서 질문 구성
        question = f"말투: {tone}\n참조글1: {ref1}\n참조글2: {ref2}\n참조글3: {ref3}"
        system_message = "블로그 글을 생성해주세요. 주어진 참조글을 바탕으로 요청된 말투에 맞게 작성하세요."
        return respond_cohere_qna(
            question=question,
            system_message=system_message,
            max_tokens=1000,
            temperature=0.7,
            top_p=0.95
        )

    submit_button.click(
        fn=generate_blog,
        inputs=[tone_radio, reference1, reference2, reference3],
        outputs=generated_blog
    )

#############################
# 메인 실행부
#############################
if __name__ == "__main__":
    demo.launch()