|
import gradio as gr |
|
from huggingface_hub import InferenceClient |
|
import os |
|
from typing import Optional |
|
|
|
|
|
|
|
|
|
|
|
|
|
COHERE_MODEL = "CohereForAI/c4ai-command-r-plus-08-2024" |
|
|
|
def get_client(model_name: str): |
|
""" |
|
모델 이름에 맞춰 InferenceClient 생성. |
|
토큰은 환경 변수에서 가져옴. |
|
""" |
|
hf_token = os.getenv("HF_TOKEN") |
|
if not hf_token: |
|
raise ValueError("HuggingFace API 토큰(HF_TOKEN)이 설정되지 않았습니다.") |
|
|
|
if model_name == "Cohere Command R+": |
|
model_id = COHERE_MODEL |
|
else: |
|
raise ValueError("유효하지 않은 모델 이름입니다.") |
|
return InferenceClient(model_id, token=hf_token) |
|
|
|
def respond_cohere_qna( |
|
question: str, |
|
system_message: str, |
|
max_tokens: int, |
|
temperature: float, |
|
top_p: float |
|
): |
|
""" |
|
Cohere Command R+ 모델을 이용해 한 번의 질문(question)에 대한 답변을 반환하는 함수. |
|
""" |
|
model_name = "Cohere Command R+" |
|
try: |
|
client = get_client(model_name) |
|
except ValueError as e: |
|
return f"오류: {str(e)}" |
|
|
|
messages = [ |
|
{"role": "system", "content": system_message}, |
|
{"role": "user", "content": question} |
|
] |
|
|
|
try: |
|
response_full = client.chat_completion( |
|
messages, |
|
max_tokens=max_tokens, |
|
temperature=temperature, |
|
top_p=top_p, |
|
) |
|
assistant_message = response_full.choices[0].message.content |
|
return assistant_message |
|
except Exception as e: |
|
return f"오류가 발생했습니다: {str(e)}" |
|
|
|
|
|
|
|
|
|
COHERE_SYSTEM_MESSAGE = """반드시 한글로 답변할 것. |
|
너는 최고의 비서이다. |
|
내가 요구하는 것들을 최대한 자세하고 정확하게 답변하라. |
|
""" |
|
COHERE_MAX_TOKENS = 4000 |
|
COHERE_TEMPERATURE = 0.7 |
|
COHERE_TOP_P = 0.95 |
|
|
|
|
|
|
|
|
|
with gr.Blocks() as demo: |
|
gr.Markdown("# 블로그 생성기") |
|
|
|
|
|
tone_radio = gr.Radio( |
|
label="말투바꾸기", |
|
choices=["친근하게", "일반적인", "전문적인"], |
|
value="일반적인" |
|
) |
|
|
|
|
|
ref1 = gr.Textbox(label="참조글 1") |
|
ref2 = gr.Textbox(label="참조글 2") |
|
ref3 = gr.Textbox(label="참조글 3") |
|
|
|
output_box = gr.Textbox(label="결과", lines=8, interactive=False) |
|
|
|
def generate_blog(tone_value, ref1_value, ref2_value, ref3_value): |
|
|
|
|
|
question = ( |
|
f"~~\n" |
|
f"말투: {tone_value}\n" |
|
f"참조글1: {ref1_value}\n" |
|
f"참조글2: {ref2_value}\n" |
|
f"참조글3: {ref3_value}\n" |
|
) |
|
|
|
|
|
response = respond_cohere_qna( |
|
question=question, |
|
system_message=COHERE_SYSTEM_MESSAGE, |
|
max_tokens=COHERE_MAX_TOKENS, |
|
temperature=COHERE_TEMPERATURE, |
|
top_p=COHERE_TOP_P |
|
) |
|
return response |
|
|
|
generate_button = gr.Button("생성하기") |
|
generate_button.click( |
|
fn=generate_blog, |
|
inputs=[tone_radio, ref1, ref2, ref3], |
|
outputs=output_box |
|
) |
|
|
|
if __name__ == "__main__": |
|
demo.launch() |
|
|