Spaces:
Sleeping
Sleeping
# app.py | |
import gradio as gr | |
from huggingface_hub import InferenceClient | |
import os | |
# Cohere Command R+ 모델 ID 정의 | |
COHERE_MODEL = "CohereForAI/c4ai-command-r-plus-08-2024" | |
def get_client(model_name): | |
""" | |
모델 이름에 맞춰 InferenceClient 생성. | |
토큰은 환경 변수에서 가져옴. | |
""" | |
hf_token = os.getenv("HF_TOKEN") | |
if not hf_token: | |
raise ValueError("HuggingFace API 토큰이 필요합니다.") | |
if model_name == "Cohere Command R+": | |
model_id = COHERE_MODEL | |
else: | |
raise ValueError("유효하지 않은 모델 이름입니다.") | |
return InferenceClient(model_id, token=hf_token) | |
def respond_cohere_qna( | |
tone: str, | |
reference1: str, | |
reference2: str, | |
reference3: str, | |
system_message: str, | |
max_tokens: int, | |
temperature: float, | |
top_p: float | |
): | |
""" | |
Cohere Command R+ 모델을 이용해 블로그 생성 함수. | |
""" | |
model_name = "Cohere Command R+" | |
try: | |
client = get_client(model_name) | |
except ValueError as e: | |
return f"오류: {str(e)}" | |
question = f"말투: {tone} \n\n 참조글1: {reference1} \n\n 참조글2: {reference2} \n\n 참조글3: {reference3}" | |
try: | |
response_full = client.chat_completion( | |
[ | |
{"role": "system", "content": system_message}, | |
{"role": "user", "content": question} | |
], | |
max_tokens=max_tokens, | |
temperature=temperature, | |
top_p=top_p, | |
) | |
assistant_message = response_full.choices[0].message.content | |
return assistant_message | |
except Exception as e: | |
return f"오류가 발생했습니다: {str(e)}" | |
# Gradio UI 설정 | |
with gr.Blocks() as demo: | |
gr.Markdown("# 블로그 생성기") | |
with gr.Row(): | |
tone = gr.Radio( | |
choices=["친근하게", "일반적인", "전문적인"], | |
value="일반적인", | |
label="말투바꾸기" | |
) | |
reference1 = gr.Textbox(label="참조글 1", lines=3, placeholder="블로그 글에 포함할 주요 참조 내용 1") | |
reference2 = gr.Textbox(label="참조글 2", lines=3, placeholder="블로그 글에 포함할 주요 참조 내용 2") | |
reference3 = gr.Textbox(label="참조글 3", lines=3, placeholder="블로그 글에 포함할 주요 참조 내용 3") | |
output = gr.Textbox(label="생성된 블로그 글", lines=10, interactive=False) | |
with gr.Accordion("고급 설정 (Cohere)", open=False): | |
system_message = gr.Textbox( | |
value="""반드시 한글로 답변할 것.\n너는 블로그 작성을 도와주는 비서이다.\n사용자의 요구사항을 정확히 반영하여 작성하라.""", | |
label="System Message", | |
lines=3 | |
) | |
max_tokens = gr.Slider(minimum=100, maximum=5000, value=2000, step=100, label="Max Tokens") | |
temperature = gr.Slider(minimum=0.1, maximum=2.0, value=0.7, step=0.1, label="Temperature") | |
top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-P") | |
generate_button = gr.Button("생성") | |
generate_button.click( | |
fn=respond_cohere_qna, | |
inputs=[tone, reference1, reference2, reference3, system_message, max_tokens, temperature, top_p], | |
outputs=output | |
) | |
if __name__ == "__main__": | |
demo.launch() | |