blog-chatGPT_01 / app.py
Kims12's picture
Update app.py
c231a45 verified
raw
history blame
3.37 kB
# app.py
import gradio as gr
from huggingface_hub import InferenceClient
import os
# Cohere Command R+ 모델 ID 정의
COHERE_MODEL = "CohereForAI/c4ai-command-r-plus-08-2024"
def get_client(model_name):
"""
모델 이름에 맞춰 InferenceClient 생성.
토큰은 환경 변수에서 가져옴.
"""
hf_token = os.getenv("HF_TOKEN")
if not hf_token:
raise ValueError("HuggingFace API 토큰이 필요합니다.")
if model_name == "Cohere Command R+":
model_id = COHERE_MODEL
else:
raise ValueError("유효하지 않은 모델 이름입니다.")
return InferenceClient(model_id, token=hf_token)
def respond_cohere_qna(
tone: str,
reference1: str,
reference2: str,
reference3: str,
system_message: str,
max_tokens: int,
temperature: float,
top_p: float
):
"""
Cohere Command R+ 모델을 이용해 블로그 생성 함수.
"""
model_name = "Cohere Command R+"
try:
client = get_client(model_name)
except ValueError as e:
return f"오류: {str(e)}"
question = f"말투: {tone} \n\n 참조글1: {reference1} \n\n 참조글2: {reference2} \n\n 참조글3: {reference3}"
try:
response_full = client.chat_completion(
[
{"role": "system", "content": system_message},
{"role": "user", "content": question}
],
max_tokens=max_tokens,
temperature=temperature,
top_p=top_p,
)
assistant_message = response_full.choices[0].message.content
return assistant_message
except Exception as e:
return f"오류가 발생했습니다: {str(e)}"
# Gradio UI 설정
with gr.Blocks() as demo:
gr.Markdown("# 블로그 생성기")
with gr.Row():
tone = gr.Radio(
choices=["친근하게", "일반적인", "전문적인"],
value="일반적인",
label="말투바꾸기"
)
reference1 = gr.Textbox(label="참조글 1", lines=3, placeholder="블로그 글에 포함할 주요 참조 내용 1")
reference2 = gr.Textbox(label="참조글 2", lines=3, placeholder="블로그 글에 포함할 주요 참조 내용 2")
reference3 = gr.Textbox(label="참조글 3", lines=3, placeholder="블로그 글에 포함할 주요 참조 내용 3")
output = gr.Textbox(label="생성된 블로그 글", lines=10, interactive=False)
with gr.Accordion("고급 설정 (Cohere)", open=False):
system_message = gr.Textbox(
value="""반드시 한글로 답변할 것.\n너는 블로그 작성을 도와주는 비서이다.\n사용자의 요구사항을 정확히 반영하여 작성하라.""",
label="System Message",
lines=3
)
max_tokens = gr.Slider(minimum=100, maximum=5000, value=2000, step=100, label="Max Tokens")
temperature = gr.Slider(minimum=0.1, maximum=2.0, value=0.7, step=0.1, label="Temperature")
top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-P")
generate_button = gr.Button("생성")
generate_button.click(
fn=respond_cohere_qna,
inputs=[tone, reference1, reference2, reference3, system_message, max_tokens, temperature, top_p],
outputs=output
)
if __name__ == "__main__":
demo.launch()