abrakjamson committed · Commit 85e58bb
Parent: 4da1fb0

Disable input while generating
app.py CHANGED
@@ -24,10 +24,12 @@ model = AutoModelForCausalLM.from_pretrained(
     trust_remote_code=True,
     use_safetensors=True
 )
-
-print(f"Is CUDA available: {torch.cuda.is_available()}")
-if torch.cuda.is_available():
+cuda = torch.cuda.is_available()
+print(f"Is CUDA available: {cuda}")
+model = model.to("cuda:0" if cuda else "cpu")
+if cuda:
     print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")
+
 
 model = ControlModel(model, list(range(-5, -18, -1)))
 
@@ -87,7 +89,8 @@ def generate_response(system_prompt, user_message, history, max_new_tokens, repi
     Returns a list of tuples, the user message and the assistant response,
     which Gradio uses to update the chatbot history
     """
-
+    global previous_turn
+    previous_turn = user_message
     # Separate checkboxes and sliders based on type
     # The first x in args are the checkbox names (the file names)
     # The second x in args are the slider values
@@ -139,7 +142,10 @@ def generate_response(system_prompt, user_message, history, max_new_tokens, repi
         "repetition_penalty": repetition_penalty.value,
     }
 
-
+    timeout = 120.0
+    if cuda:
+        timeout = 10.0
+    _streamer = TextIteratorStreamer(tokenizer, timeout=timeout, skip_prompt=True, skip_special_tokens=False,)
 
     generate_kwargs = dict(
         input_ids,
@@ -155,6 +161,9 @@ def generate_response(system_prompt, user_message, history, max_new_tokens, repi
 
     # Display the response as it streams in, prepending the control vector info
     partial_message = ""
+    #show the control vector info while we wait for the first token
+    temp_output = "*" + assistant_message_title + "*" + "\n\n*Please wait*..." + partial_message
+    yield history + [(user_message, temp_output)]
     for new_token in _streamer:
         if new_token != '<' and new_token != '</s>': # seems to hit EOS correctly without this needed
             partial_message += new_token
@@ -181,14 +190,17 @@ def generate_response(system_prompt, user_message, history, max_new_tokens, repi
 
     # Update conversation history
     history.append((user_message, assistant_response_display))
-
+    return history
 
 def generate_response_with_retry(system_prompt, user_message, history, max_new_tokens, repitition_penalty, do_sample, *args):
     # Remove last user input and assistant response from history, then call generate_response()
+    global previous_turn
+    previous_ueser_message = previous_turn
     if history:
         history = history[0:-1]
-
-
+    # Using the previous turn's text, even though it isn't in the textbox anymore
+    for output in generate_response(system_prompt, previous_ueser_message, history, max_new_tokens, repetition_penalty, do_sample, *args):
+        yield [output, previous_ueser_message]
 
 # Function to reset the conversation history
 def reset_chat():
@@ -281,7 +293,7 @@ def set_preset_stoner(*args):
     for check in model_names_and_indexes:
         if check == "Angry":
             new_checkbox_values.append(True)
-            new_slider_values.append(0.
+            new_slider_values.append(0.4)
         elif check == "Right-leaning":
             new_checkbox_values.append(True)
             new_slider_values.append(-0.5)
@@ -323,6 +335,15 @@ def set_preset_facts(*args):
 
     return new_checkbox_values + new_slider_values
 
+def disable_controls():
+    return gr.update(interactive= False, value= "⌛ Processing"), gr.update(interactive=False)
+
+def enable_controls():
+    return gr.update(interactive= True, value= "💬 Submit"), gr.update(interactive= True)
+
+def clear_input(input_textbox):
+    return ""
+
 tooltip_css = """
 /* Tooltip container */
 .tooltip {
@@ -560,10 +581,22 @@ with gr.Blocks(
     inputs_list = [system_prompt, user_input, chatbot, max_new_tokens, repetition_penalty, do_sample] + control_checks + control_sliders
 
     # Define button actions
+    # Disable the submit button while processing
+    submit_button.click(
+        disable_controls,
+        inputs= None,
+        outputs= [submit_button, user_input]
+    )
     submit_button.click(
         generate_response,
         inputs=inputs_list,
         outputs=[chatbot]
+    ).then(
+        clear_input,
+        inputs= user_input,
+        outputs= user_input
+    ).then(
+        enable_controls, inputs=None, outputs=[submit_button, user_input]
     )
 
     user_input.submit(
@@ -575,7 +608,11 @@ with gr.Blocks(
     retry_button.click(
         generate_response_with_retry,
         inputs=inputs_list,
-        outputs=[chatbot]
+        outputs=[chatbot, user_input]
+    ).then(
+        clear_input,
+        inputs= user_input,
+        outputs= user_input
     )
 
     new_chat_button.click(
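
The heart of the change is the pair of handlers registered on submit_button: the first click disables the controls immediately, and the second streams the response and re-enables the controls via a .then() chain. Below is a minimal, self-contained sketch of that pattern using the Gradio Blocks API as this file does; echo_bot is an illustrative stand-in for generate_response, not code from this Space.

import time
import gradio as gr

def disable_controls():
    # Lock both controls and swap the button label while the model runs
    return gr.update(interactive=False, value="⌛ Processing"), gr.update(interactive=False)

def enable_controls():
    return gr.update(interactive=True, value="💬 Submit"), gr.update(interactive=True)

def echo_bot(message, history):
    # Stand-in generator: streams the user's message back one character at a time
    partial = ""
    for ch in message:
        time.sleep(0.05)
        partial += ch
        yield history + [(message, partial)]

with gr.Blocks() as demo:
    chatbot = gr.Chatbot()
    user_input = gr.Textbox()
    submit_button = gr.Button("💬 Submit")

    # Both click handlers fire on one click: the first disables the UI
    # synchronously, the second does the slow work and then re-enables it.
    submit_button.click(disable_controls, inputs=None, outputs=[submit_button, user_input])
    submit_button.click(
        echo_bot, inputs=[user_input, chatbot], outputs=[chatbot]
    ).then(
        enable_controls, inputs=None, outputs=[submit_button, user_input]
    )

demo.queue()  # generator callbacks need the queue enabled
demo.launch()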
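
The commit also scales the streamer timeout to the device (10 s on GPU, 120 s on CPU) so a stalled generation raises instead of hanging the handler forever. A rough sketch of how TextIteratorStreamer's timeout interacts with threaded generation follows; "gpt2" is a placeholder model, not the one this Space loads.

from threading import Thread
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

input_ids = tokenizer("Hello, my name is", return_tensors="pt").input_ids

# If no new token arrives within `timeout` seconds, iterating the streamer
# raises queue.Empty, ending the request instead of blocking indefinitely.
streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True)

# generate() runs in a worker thread and pushes tokens into the streamer
thread = Thread(target=model.generate,
                kwargs=dict(input_ids=input_ids, max_new_tokens=30, streamer=streamer))
thread.start()

partial = ""
for new_token in streamer:  # blocks up to `timeout` per token
    partial += new_token
print(partial)
thread.join()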