File size: 2,431 Bytes
e8bac0f
2db0d53
e8bac0f
04e9db1
8b62ce7
e8bac0f
04e9db1
 
8b62ce7
a5a2931
7ad1fa3
e8bac0f
a5a2931
 
 
263e495
 
 
a5a2931
2db0d53
a5a2931
85f74eb
 
a5a2931
85f74eb
2db0d53
cde7a7b
263e495
 
8b62ce7
 
 
 
 
 
 
 
 
 
263e495
 
 
8b62ce7
 
263e495
8b62ce7
 
263e495
a6549b1
 
 
 
 
 
 
a5a2931
a6549b1
a5a2931
a6549b1
a5a2931
2db0d53
a6549b1
2db0d53
a5a2931
2db0d53
 
a5a2931
 
a6549b1
83746e4
 
a5a2931
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import gradio as gr
from huggingface_hub import InferenceClient
import os
import logging
from gradio_client import Client  # ์ด๋ฏธ์ง€ ์ƒ์„ฑ API ํด๋ผ์ด์–ธํŠธ

# ๋กœ๊น… ์„ค์ •
logging.basicConfig(level=logging.INFO)

# ํ™˜๊ฒฝ ๋ณ€์ˆ˜์—์„œ Hugging Face API ํ† ํฐ์„ ๊ฐ€์ ธ์˜ต๋‹ˆ๋‹ค.
hf_client = InferenceClient("CohereForAI/c4ai-command-r-plus", token=os.getenv("HF_TOKEN"))

# ์ด๋ฏธ์ง€ ์ƒ์„ฑ API ํด๋ผ์ด์–ธํŠธ ์„ค์ •
client = Client("http://211.233.58.202:7960/")

def respond(message, history, system_message, max_tokens, temperature, top_p):
    """Gradio ChatInterface callback: generate an image from the user's message.

    Despite the chat-style signature, this handler only forwards ``message`` to
    the text-to-image endpoint and returns the image URL on success, or a
    human-readable error string on failure.

    Args:
        message: Latest user message; used verbatim as the image prompt.
        history: Chat history as (user, assistant) pairs — currently unused.
        system_message: System prompt from the UI — currently unused.
        max_tokens / temperature / top_p: LLM sampling controls — currently unused.

    Returns:
        str: image URL, or an error message.
    """
    # NOTE(review): the LLM client (hf_client) is never invoked here, so the
    # system prompt, chat history, and sampling parameters have no effect.
    # (The original built a full `messages` list and an unused
    # `full_system_message` string; both were dead code and are removed.)
    try:
        result = client.predict(
            prompt=message,
            seed=123,                # fixed seed — reproducible output
            randomize_seed=False,
            width=1024,
            height=576,
            guidance_scale=5,
            num_inference_steps=28,
            api_name="/infer_t2i"
        )
        # Assumes the endpoint returns a dict with a 'url' key on success and
        # an 'error' key on failure — TODO confirm against the server's API.
        if 'url' in result:
            return result['url']
        else:
            logging.error("Image generation failed with error: %s", result.get('error', 'Unknown error'))
            return "Failed to generate image."
    except Exception as e:
        # Broad catch is deliberate: a chat callback must return a string, not raise.
        logging.error("Error during API request: %s", str(e))
        return f"An error occurred: {str(e)}"

# UI theme identifier passed to Gradio.
theme = "Nymbo/Nymbo_Theme"
# Custom CSS: hide the default Gradio footer.
css = """
footer {
    visibility: hidden;
}
"""

# Gradio ์ฑ„ํŒ… ์ธํ„ฐํŽ˜์ด์Šค๋ฅผ ์„ค์ •ํ•ฉ๋‹ˆ๋‹ค.
demo = gr.ChatInterface(
    fn=respond,
    additional_inputs=[
        gr.Textbox(value="You are an AI assistant.", label="System Prompt"),
        gr.Slider(minimum=1, maximum=2000, value=512, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(
            minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"
        ),
    ],
    theme=theme,
    css=css
)

# Script entry point: start the Gradio server only when run directly.
if __name__ == "__main__":
    demo.launch()