Update app.py
app.py CHANGED
@@ -58,7 +58,6 @@ class StarlingBot:
         gc.collect()
         torch.cuda.empty_cache()
 
-starling_bot = StarlingBot()
 examples = [
     [
         "The following dialogue is a conversation between Emmanuel Macron and Elon Musk:", # user_message
@@ -69,21 +68,37 @@ examples = [
         1.9, # repetition_penalty
     ]
 ]
-
-
-    fn=starling_bot.predict,
-    inputs=[
-        gr.Textbox(label="🌟🤩User Message", type="text", lines=5),
-        gr.Textbox(label="💫🌠Starling Assistant Message or Instructions ", lines=2),
-        gr.Textbox(label="💫🌠Starling System Prompt or Instruction", lines=2),
-        gr.Checkbox(label="Advanced", value=False),
-        gr.Slider(label="Temperature", value=0.7, minimum=0.05, maximum=1.0, step=0.05),
-        gr.Slider(label="Max new tokens", value=100, minimum=25, maximum=256, step=1),
-        gr.Slider(label="Top-p (nucleus sampling)", value=0.90, minimum=0.01, maximum=0.99, step=0.05),
-        gr.Slider(label="Repetition penalty", value=1.9, minimum=1.0, maximum=2.0, step=0.05)
-    ],
-    outputs="text",
-    gr.Markdown(title),
-    gr.Markdown(description),
-    theme="ParityError/Anime"
-)
+# Initialize StarlingBot
+starling_bot = StarlingBot()
+
+def gradio_starling(user_message, assistant_message, system_message, do_sample, temperature, max_new_tokens, top_p, repetition_penalty):
+    response = starling_bot.predict(user_message, assistant_message, system_message, do_sample, temperature, max_new_tokens, top_p, repetition_penalty)
+    return response
+
+with gr.Blocks(theme="ParityError/Anime") as demo:
+    gr.Markdown(title)
+    gr.Markdown(description)
+    with gr.Row():
+        system_message = gr.Textbox(label="Optional💫🌠Starling System Message", lines=2)
+        assistant_message = gr.Textbox(label="💫🌠Starling Assistant Message", lines=2)
+        user_message = gr.Textbox(label="Your Message", lines=3)
+    with gr.Row():
+        do_sample = gr.Checkbox(label="Advanced", value=True)
+
+    with gr.Accordion("Advanced Settings", open=lambda do_sample: do_sample):
+        with gr.Row():
+            temperature = gr.Slider(label="Temperature", value=0.7, minimum=0.05, maximum=1.0, step=0.05)
+            max_new_tokens = gr.Slider(label="Max new tokens", value=100, minimum=25, maximum=256, step=1)
+            top_p = gr.Slider(label="Top-p (nucleus sampling)", value=0.90, minimum=0.01, maximum=0.99, step=0.05)
+            repetition_penalty = gr.Slider(label="Repetition penalty", value=1.9, minimum=1.0, maximum=2.0, step=0.05)
+
+    submit_button = gr.Button("Submit")
+    output_text = gr.Textbox(label="💫🌠Starling Response")
+
+    submit_button.click(
+        gradio_starling,
+        inputs=[user_message, assistant_message, system_message, do_sample, temperature, max_new_tokens, top_p, repetition_penalty],
+        outputs=output_text
+    )
+
+demo.launch()
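One note on the added layout: `gr.Accordion`'s `open` parameter is a plain boolean, so the `open=lambda do_sample: do_sample` on the `with gr.Accordion(...)` line is just treated as a truthy value and the "Advanced Settings" section stays open regardless of the "Advanced" checkbox. If the checkbox is meant to drive the accordion, the usual Gradio pattern is a `.change` listener on the checkbox. The snippet below is a minimal sketch of that wiring, not part of this commit; it reuses the component names from the added code and assumes a Gradio version where `gr.update` is available.

import gradio as gr

# Hypothetical sketch (not from the commit): let the "Advanced" checkbox
# open and close the Advanced Settings accordion via a change listener.
with gr.Blocks() as sketch_demo:
    do_sample = gr.Checkbox(label="Advanced", value=True)
    with gr.Accordion("Advanced Settings", open=True) as advanced_settings:
        temperature = gr.Slider(label="Temperature", value=0.7, minimum=0.05, maximum=1.0, step=0.05)

    # When the checkbox changes, apply its value to the accordion's `open` prop.
    do_sample.change(
        fn=lambda checked: gr.update(open=checked),
        inputs=do_sample,
        outputs=advanced_settings,
    )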