# Dimple-7B demo — app.py
# (Hugging Face Space: rp-yu/Dimple-7B, commit f30487f; page-header residue removed)
try:
    import spaces  # HF ZeroGPU helper; only installed on Hugging Face Spaces
except ImportError:
    # Local run: provide a no-op stand-in so `@spaces.GPU(...)` still works.
    class spaces:
        @staticmethod
        def GPU(*args, **kwargs):
            # Accept any decorator arguments (duration=..., etc.) and return
            # the function unchanged — the real decorator's local behavior.
            def dummy(func):
                return func
            return dummy
import argparse
import json
import time
import gradio as gr
from filelock import FileLock
from PIL import Image
import threading
from utils import (
build_logger,
server_error_msg,
violates_moderation,
moderation_msg,
get_log_filename,
)
from conversation import Conversation
from model import (
FullSequenceStreamer,
get_model,
)
# Module-wide logger; build_logger writes to "dimple.log".
logger = build_logger("dimple", "dimple.log")

# Sentinel button states returned from event handlers:
# no_change_btn leaves a button untouched; the other two toggle interactivity.
no_change_btn = gr.Button()
enable_btn = gr.Button(interactive=True)
disable_btn = gr.Button(interactive=False)
@spaces.GPU(duration=10)
def make_zerogpu_happy():
    # No-op. Presumably exists so the HF ZeroGPU runtime detects at least one
    # @spaces.GPU-decorated function — TODO confirm against Spaces docs.
    pass
def write2file(path, content):
    """Append *content* to the file at *path*.

    A sibling ``<path>.lock`` file serializes concurrent writers (multiple
    Gradio workers may log at once).
    """
    with FileLock(f"{path}.lock"), open(path, "a") as fout:
        fout.write(content)
model, processor = get_model("cuda:0")
get_window_url_params = """
function() {
const params = new URLSearchParams(window.location.search);
url_params = Object.fromEntries(params);
console.log(url_params);
return url_params;
}
"""
def init_state(state=None):
    """Return a brand-new Conversation, discarding *state* if one was given."""
    if state is None:
        return Conversation()
    del state  # drop our local reference to the old conversation first
    return Conversation()
def vote_last_response(state, liked, request: gr.Request):
    """Append a vote record for the latest response to the JSONL log.

    Args:
        state: current Conversation (serialized via ``state.dict()``).
        liked: True/False for up/down votes, or the string "flag".
        request: Gradio request; its client IP is logged.
    """
    conv_data = {
        "tstamp": round(time.time(), 4),
        "like": liked,
        # Plain model id: the other log sites (vote_selected_response,
        # http_bot) use the unquoted form; the stray inner quotes here
        # made this record inconsistent.
        "model": "rp-yu/Dimple-7B",
        "state": state.dict(),
        "ip": request.client.host,
    }
    write2file(get_log_filename(), json.dumps(conv_data) + "\n")
def upvote_last_response(state, request: gr.Request):
    """Log an upvote, clear the input box, and disable the three vote buttons."""
    logger.info(f"upvote. ip: {request.client.host}")
    vote_last_response(state, True, request)
    cleared_box = gr.MultimodalTextbox(value=None, interactive=True)
    return (cleared_box, disable_btn, disable_btn, disable_btn)
def downvote_last_response(state, request: gr.Request):
    """Log a downvote, clear the input box, and disable the three vote buttons."""
    logger.info(f"downvote. ip: {request.client.host}")
    vote_last_response(state, False, request)
    cleared_box = gr.MultimodalTextbox(value=None, interactive=True)
    return (cleared_box, disable_btn, disable_btn, disable_btn)
def vote_selected_response(
    state, request: gr.Request, data: gr.LikeData
):
    """Persist a per-message like/dislike coming from the chatbot's like widget."""
    logger.info(
        f"Vote: {data.liked}, index: {data.index}, value: {data.value} , ip: {request.client.host}"
    )
    record = dict(
        tstamp=round(time.time(), 4),
        like=data.liked,
        index=data.index,
        model='rp-yu/Dimple-7B',
        state=state.dict(),
        ip=request.client.host,
    )
    write2file(get_log_filename(), json.dumps(record) + "\n")
    return
def flag_last_response(state, request: gr.Request):
    """Log a "flag" vote, clear the input box, and disable the vote buttons."""
    logger.info(f"flag. ip: {request.client.host}")
    vote_last_response(state, "flag", request)
    return (
        gr.MultimodalTextbox(value=None, interactive=True),
        disable_btn,
        disable_btn,
        disable_btn,
    )
def regenerate(state, image_process_mode, request: gr.Request):
    """Blank the last assistant message so http_bot re-answers the last user turn.

    NOTE(review): the click handler registers this with
    inputs=[state, system_prompt], so `image_process_mode` actually receives
    the system-prompt text — confirm the intended wiring before relying on
    the image-process-mode branch below.
    """
    logger.info(f"regenerate. ip: {request.client.host}")
    # state.messages[-1][-1] = None
    # Clear the most recent assistant message in place.
    state.update_message(Conversation.ASSISTANT, content='', image=None, idx=-1)
    prev_human_msg = state.messages[-2]
    # If the previous human message is structured as (text, images, mode),
    # keep text/images and swap in the new mode.
    if type(prev_human_msg[1]) in (tuple, list):
        prev_human_msg[1] = (*prev_human_msg[1][:2], image_process_mode)
    state.skip_next = False
    textbox = gr.MultimodalTextbox(value=None, interactive=True)
    # (state, chat history, cleared textbox) plus 5 disabled action buttons.
    return (state, state.to_gradio_chatbot(), textbox) + (disable_btn,) * 5
def clear_history(request: gr.Request):
    """Reset the conversation and return a pristine UI state."""
    logger.info(f"clear_history. ip: {request.client.host}")
    fresh = init_state()
    return (
        fresh,
        fresh.to_gradio_chatbot(),
        gr.MultimodalTextbox(value=None, interactive=True),
    ) + (disable_btn,) * 5
def add_text(state, message, system_prompt, request: gr.Request):
    """Validate and append a user turn (text and/or images) to the conversation.

    Sets ``state.skip_next`` so the downstream http_bot call knows whether to
    actually generate. Returns (state, chat history, textbox, *5 buttons).
    """
    logger.debug(f"state: {state}")  # was a stray debug print()
    if not state:
        state = init_state()
    images = message.get("files", [])
    text = message.get("text", "").strip()
    logger.info(f"add_text. ip: {request.client.host}. len: {len(text)}")
    textbox = gr.MultimodalTextbox(value=None, interactive=False)
    # Empty submission: skip generation and leave the buttons untouched.
    if len(text) <= 0 and len(images) == 0:
        state.skip_next = True
        return (state, state.to_gradio_chatbot(), textbox) + (no_change_btn,) * 5
    # `args` only exists when the module runs as a script (__main__);
    # guard the lookup so importing this module does not raise NameError here.
    if globals().get("args") is not None and args.moderate:
        flagged = violates_moderation(text)
        if flagged:
            state.skip_next = True
            textbox = gr.MultimodalTextbox(
                value={"text": moderation_msg}, interactive=True
            )
            return (state, state.to_gradio_chatbot(), textbox) + (no_change_btn,) * 5
    images = [Image.open(path).convert("RGB") for path in images]
    # If new images arrive while earlier user images exist, restart the
    # conversation instead of mixing image sets.
    if len(images) > 0 and len(state.get_images(source=state.USER)) > 0:
        state = init_state(state)
    state.set_system_message(system_prompt)
    state.append_message(Conversation.USER, text, images)
    state.skip_next = False
    return (state, state.to_gradio_chatbot(), textbox) + (
        disable_btn,
    ) * 5
def http_bot(
    state,
    temperature,
    top_p,
    p_threshold,
    alg_temp,
    max_new_tokens,
    steps,
    alg,
):
    """Generate the assistant's answer with diffusion decoding, streaming
    partial output to the chatbot.

    Yields (state, chat history, textbox, *5 buttons) tuples as decoding
    progresses, then appends a JSONL record of the finished turn to the log.
    Generation runs in a worker thread; tokens arrive via `streamer`.
    """
    start_tstamp = time.time()
    if hasattr(state, "skip_next") and state.skip_next:
        # This generate call is skipped due to invalid inputs
        yield (
            state,
            state.to_gradio_chatbot(),
            gr.MultimodalTextbox(interactive=False),
        ) + (no_change_btn,) * 5
        return

    # Collect user-supplied images; the processor expects None when absent.
    all_images = state.get_images(source=state.USER)
    all_image_paths = [state.save_image(image) for image in all_images]
    if len(all_images) == 0:
        all_images = None

    messages = state.get_prompt()
    text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True, add_vision_id=False
    )
    inputs = processor(
        text=text,
        images=all_images,
        videos=None,
        padding="longest",
        return_tensors="pt",
    ).to(model.device)
    input_ids = inputs.pop("input_ids")

    streamer = FullSequenceStreamer(
        processor.tokenizer,
        timeout=10,
        skip_special_tokens=True,
    )

    def run_generate():
        # Worker-thread body; results are consumed through `streamer`.
        model.diffusion_generate(
            input_ids,
            max_new_tokens=int(max_new_tokens),
            output_history=True,
            return_dict_in_generate=True,
            steps=int(steps),
            temperature=float(temperature),
            top_p=float(top_p),
            alg=alg,
            alg_temp=float(alg_temp),
            use_cache=True,
            alg_p_threshold=float(p_threshold),
            use_original_confidence=True,
            decoding_pipeline="dim",
            streamer=streamer,
            **inputs,
        )

    thread = threading.Thread(target=run_generate)
    thread.start()

    logger.info("==== wait for first token ====\n")
    state.append_message(Conversation.ASSISTANT, state.streaming_placeholder)
    yield (
        state,
        state.to_gradio_chatbot(),
        gr.MultimodalTextbox(interactive=False),
    ) + (disable_btn,) * 5

    ans = ""  # last streamed answer; pre-set so logging below cannot NameError
    try:
        # Stream output
        for ans in streamer:
            if len(ans) > 1:
                ans = "\n".join(ans)
            else:
                ans = ans[0]
            state.update_message(Conversation.ASSISTANT, ans, None)
            yield (
                state,
                state.to_gradio_chatbot(),
                gr.MultimodalTextbox(interactive=False),
            ) + (disable_btn,) * 5
    except Exception:
        # Surface a generic error message; re-enable regenerate/clear only.
        state.update_message(Conversation.ASSISTANT, server_error_msg, None)
        yield (
            state,
            state.to_gradio_chatbot(),
            gr.MultimodalTextbox(interactive=True),
        ) + (
            disable_btn,
            disable_btn,
            disable_btn,
            enable_btn,
            enable_btn,
        )
        return

    state.end_of_current_turn()
    yield (
        state,
        state.to_gradio_chatbot(),
        gr.MultimodalTextbox(interactive=True),
    ) + (enable_btn,) * 5

    finish_tstamp = time.time()
    logger.info(f"{ans}")
    data = {
        "tstamp": round(finish_tstamp, 4),
        "like": None,
        "model": "rp-yu/Dimple-7B",
        "start": round(start_tstamp, 4),
        # BUG FIX: previously logged round(start_tstamp, 4) here, so every
        # record showed finish == start; record the actual finish time.
        "finish": round(finish_tstamp, 4),
        "state": state.dict(),
        "images": all_image_paths,
    }
    write2file(get_log_filename(), json.dumps(data) + "\n")
title_html = """
<div style="width:100%; max-width:600px; margin:auto;">
<img src="https://cdn-uploads.huggingface.co/production/uploads/635364b3c41f548fe39db945/Iny16670lQgUwURiUfP-i.png" style="width:100%;"><br>
<a href="https://arxiv.org/abs/">[πŸ“œ Dimple Paper]</a><br>
<a href="https://github.com/yu-rp/Dimple">[🌟 Github]</a><br>
<a href="https://huggingface.co/rp-yu/Dimple-7B">[πŸ€— Huggingface Model]</a><br>
<a href="https://huggingface.co/spaces/rp-yu/dimple">[πŸ’¬ Huggingface Demo]</a><br>
</div>
"""
tos_markdown = """
<div style="text-align: right;">
Acknowledgement: This demo is built upon the Hugging Face Space of <a href="https://huggingface.co/spaces/OpenGVLab/InternVL" target="_blank">InternVL</a>.
</div>
"""
# .gradio-container {margin: 5px 10px 0 10px !important};
block_css = """
.gradio-container {margin: 0.1% 1% 0 1% !important; max-width: 98% !important;};
#buttons button {
min-width: min(120px,100%);
}
.gradient-text {
font-size: 28px;
width: auto;
font-weight: bold;
background: linear-gradient(45deg, red, orange, yellow, green, blue, indigo, violet);
background-clip: text;
-webkit-background-clip: text;
color: transparent;
}
.plain-text {
font-size: 22px;
width: auto;
font-weight: bold;
}
"""
def build_demo():
    """Construct and return the Gradio Blocks UI for the Dimple-7B demo.

    Layout: a settings column (system prompt + sampling/decoding controls and
    example inputs) next to the chatbot column, with vote/regenerate/clear
    buttons wired to the handlers defined above.
    """
    textbox = gr.MultimodalTextbox(
        interactive=True,
        file_types=["image"],
        placeholder="Enter message or upload file...",
        show_label=False,
    )

    with gr.Blocks(
        title="Dimple-7B",
        theme=gr.themes.Default(),
        css=block_css,
    ) as demo:
        # Per-session Conversation object (created lazily in add_text).
        state = gr.State()

        with gr.Row():
            with gr.Column(scale=2):
                gr.HTML(title_html)

                with gr.Accordion("Settings", open=False) as setting_row:
                    system_prompt = gr.Textbox(
                        value="You are a helpful assistant.",
                        label="System Prompt",
                        interactive=True,
                    )
                    temperature = gr.Slider(
                        minimum=0.0,
                        maximum=2.0,
                        value=0.2,
                        step=0.1,
                        interactive=True,
                        label="Temperature",
                    )
                    top_p = gr.Slider(
                        minimum=0.0,
                        maximum=1.0,
                        value=0.95,
                        step=0.1,
                        interactive=True,
                        label="Top P",
                    )
                    alg = gr.Radio(
                        choices=["origin", "maskgit_plus", "entropy"],
                        value="origin",
                        label="Selection Algorithm",
                        interactive=True,
                    )
                    p_threshold = gr.Slider(
                        minimum=0.,
                        maximum=1.0,
                        value=0.95,
                        step=0.01,
                        interactive=True,
                        label="Probability threshold for Confident Decoding",
                    )
                    alg_temp = gr.Slider(
                        minimum=0.0,
                        maximum=2.0,
                        value=0.2,
                        step=0.1,
                        interactive=True,
                        # Typo fix: was "Selectiion".
                        label="Temperature for Selection Algorithm",
                    )
                    max_new_tokens = gr.Slider(
                        minimum=1,
                        maximum=128,
                        value=64,
                        step=2,
                        interactive=True,
                        label="Max output tokens",
                    )
                    steps = gr.Slider(
                        minimum=1,
                        maximum=128,
                        value=64,
                        step=2,
                        interactive=True,
                        label="Number of decoding steps",
                    )

                examples = gr.Examples(
                    examples=[
                        [
                            {
                                "files": [
                                    "gallery/14.jfif",
                                ],
                                "text": "Please help me analyze this picture.",
                            }
                        ],
                        [
                            {
                                "files": [
                                    "gallery/prod_9.jpg",
                                ],
                                "text": "Please help me describe the image.",
                            }
                        ],
                        [
                            {
                                "files": [
                                    "gallery/15.PNG",
                                ],
                                "text": "Please help me analyze this picture.",
                            }
                        ],
                    ],
                    inputs=[textbox],
                )

            with gr.Column(scale=8):
                chatbot = gr.Chatbot(
                    elem_id="chatbot",
                    label="Dimple-7B",
                    height=580,
                    show_copy_button=True,
                    show_share_button=True,
                    avatar_images=[
                        "assets/human.png",
                        "assets/assistant.png",
                    ],
                    bubble_full_width=False,
                )
                with gr.Row():
                    with gr.Column(scale=8):
                        textbox.render()
                    with gr.Column(scale=1, min_width=50):
                        submit_btn = gr.Button(value="Send", variant="primary")
                with gr.Row(elem_id="buttons") as button_row:
                    upvote_btn = gr.Button(value="πŸ‘ Upvote", interactive=False)
                    downvote_btn = gr.Button(value="πŸ‘Ž Downvote", interactive=False)
                    flag_btn = gr.Button(value="⚠️ Flag", interactive=False)
                    # stop_btn = gr.Button(value="⏹️ Stop Generation", interactive=False)
                    regenerate_btn = gr.Button(
                        value="πŸ”„ Regenerate", interactive=False
                    )
                    clear_btn = gr.Button(value="πŸ—‘οΈ Clear", interactive=False)

        gr.Markdown(tos_markdown)
        url_params = gr.JSON(visible=False)

        # Register listeners
        btn_list = [upvote_btn, downvote_btn, flag_btn, regenerate_btn, clear_btn]
        upvote_btn.click(
            upvote_last_response,
            [state],
            [textbox, upvote_btn, downvote_btn, flag_btn],
        )
        downvote_btn.click(
            downvote_last_response,
            [state],
            [textbox, upvote_btn, downvote_btn, flag_btn],
        )
        chatbot.like(
            vote_selected_response,
            [state],
            [],
        )
        flag_btn.click(
            flag_last_response,
            [state],
            [textbox, upvote_btn, downvote_btn, flag_btn],
        )
        # NOTE(review): regenerate() declares (state, image_process_mode, ...)
        # but is wired with [state, system_prompt] — the system prompt lands in
        # the image_process_mode parameter. Confirm the intended input.
        regenerate_btn.click(
            regenerate,
            [state, system_prompt],
            [state, chatbot, textbox] + btn_list,
        ).then(
            http_bot,
            [
                state,
                temperature,
                top_p,
                p_threshold,
                alg_temp,
                max_new_tokens,
                steps,
                alg,
            ],
            [state, chatbot, textbox] + btn_list,
        )
        clear_btn.click(clear_history, None, [state, chatbot, textbox] + btn_list)
        textbox.submit(
            add_text,
            [state, textbox, system_prompt],
            [state, chatbot, textbox] + btn_list,
        ).then(
            http_bot,
            [
                state,
                temperature,
                top_p,
                p_threshold,
                alg_temp,
                max_new_tokens,
                steps,
                alg,
            ],
            [state, chatbot, textbox] + btn_list,
        )
        submit_btn.click(
            add_text,
            [state, textbox, system_prompt],
            [state, chatbot, textbox] + btn_list,
        ).then(
            http_bot,
            [
                state,
                temperature,
                top_p,
                p_threshold,
                alg_temp,
                max_new_tokens,
                steps,
                alg,
            ],
            [state, chatbot, textbox] + btn_list,
        )
    return demo
if __name__ == "__main__":
    # CLI flags for hosting the demo.
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="0.0.0.0")
    parser.add_argument("--port", type=int, default=7860)
    parser.add_argument("--concurrency-count", type=int, default=4)
    parser.add_argument("--share", action="store_true")
    parser.add_argument("--moderate", action="store_true")
    args = parser.parse_args()
    # Log the parsed args once (a duplicate logger.info(args) was removed).
    logger.info(f"args: {args}")

    demo = build_demo()
    demo.queue(api_open=False).launch(
        server_name=args.host,
        server_port=args.port,
        share=args.share,
        max_threads=args.concurrency_count,
    )