File size: 3,943 Bytes
dd48380 affab3c dd48380 fbcf846 73bf78a fbcf846 73bf78a dd48380 73bf78a dd48380 fbcf846 dd48380 73bf78a dd48380 fbcf846 dd48380 73bf78a dd48380 73bf78a 317f842 73bf78a dd48380 73bf78a dd48380 c4e3ff4 dd48380 fbcf846 dd48380 fbcf846 dd48380 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 |
import argparse
import urllib
import urllib.request
from pathlib import Path

import chatglm_cpp
import gradio as gr
DEFAULT_MODEL_PATH = "chatglm3-6b.bin"

# Fetch the quantized weights only once: the file is multiple gigabytes, so
# re-running the app must not trigger a re-download.
if not Path(DEFAULT_MODEL_PATH).exists():
    urllib.request.urlretrieve(
        "https://huggingface.co/Braddy/chatglm3-6b-chitchat/resolve/main/q5_1.bin?download=true",
        DEFAULT_MODEL_PATH,
    )

# CLI flags control model location and generation defaults; the sliders in the
# UI below are seeded from these values but can be changed per request.
parser = argparse.ArgumentParser()
parser.add_argument("-m", "--model", default=DEFAULT_MODEL_PATH, type=Path, help="model path")
parser.add_argument("--mode", default="chat", type=str, choices=["chat", "generate"], help="inference mode")
parser.add_argument("-l", "--max_new_tokens", default=64, type=int, help="max total output tokens")
parser.add_argument("-c", "--max_context_length", default=1024, type=int, help="max context length")
parser.add_argument("--top_k", default=40, type=int, help="top-k sampling")
parser.add_argument("--top_p", default=0.75, type=float, help="top-p sampling")
parser.add_argument("--temp", default=0.5, type=float, help="temperature")
parser.add_argument("--repeat_penalty", default=1.0, type=float, help="penalize repeat sequence of tokens")
parser.add_argument("-t", "--threads", default=0, type=int, help="number of threads for inference")
parser.add_argument("--plain", action="store_true", help="display in plain text without markdown support")
args = parser.parse_args()

# Load the quantized ChatGLM pipeline and pin a persona via a system prompt
# that is prepended to every conversation in predict().
pipeline = chatglm_cpp.Pipeline(args.model)
system_message = chatglm_cpp.ChatMessage(role="system", content="请你现在扮演一个软件工程师,名字叫做贺英旭。你需要以这个身份和朋友们对话。")
def postprocess(text):
    """Prepare *text* for the chatbot widget.

    When the --plain flag is set the text is wrapped in <pre> tags so the
    widget renders it verbatim instead of as markdown; otherwise it is
    returned untouched.
    """
    return f"<pre>{text}</pre>" if args.plain else text
def predict(input, chatbot, max_new_tokens, top_p, temperature, messages):
    """Stream a model reply for *input*, updating the chat display live.

    Appends the user turn to both the visible chat log and the message
    history, then yields ``(chatbot, messages)`` after every streamed chunk
    so Gradio refreshes incrementally. The final yield follows appending the
    merged assistant message to the history.
    """
    chatbot.append((postprocess(input), ""))
    messages.append(chatglm_cpp.ChatMessage(role="user", content=input))

    reply_so_far = ""
    streamed_chunks = []
    # The pinned system persona is prepended on every call but never stored
    # in the per-session history.
    for piece in pipeline.chat(
        [system_message] + messages,
        max_new_tokens=max_new_tokens,
        max_context_length=args.max_context_length,
        do_sample=temperature > 0,
        top_k=args.top_k,
        top_p=top_p,
        temperature=temperature,
        repetition_penalty=args.repeat_penalty,
        num_threads=args.threads,
        stream=True,
    ):
        reply_so_far += piece.content
        streamed_chunks.append(piece)
        # Replace only the assistant half of the newest chat row.
        chatbot[-1] = (chatbot[-1][0], postprocess(reply_so_far))
        yield chatbot, messages

    messages.append(pipeline.merge_streaming_messages(streamed_chunks))
    yield chatbot, messages
def reset_user_input():
    """Return a Gradio update that blanks the input textbox."""
    cleared = gr.update(value="")
    return cleared
def reset_state():
    """Clear both the visible chat log and the stored message history."""
    fresh_chatlog = []
    fresh_history = []
    return fresh_chatlog, fresh_history
# Static HTML header rendered above the chat widget.
title = """
<div style="text-align: center">
<h1>Chitchat</h1>
<p style="text-align: center;">Feel free to talk about anything :)</p>
</div>
"""

with gr.Blocks() as demo:
    gr.HTML(title)
    chatbot = gr.Chatbot(height=300, label="Check this out!")
    with gr.Row():
        with gr.Column(scale=4):
            user_input = gr.Textbox(show_label=False, placeholder="Input...", lines=8)
            submitBtn = gr.Button("Submit", variant="primary")
        with gr.Column(scale=1):
            # Sliders are seeded from the CLI defaults but adjustable per request.
            max_new_tokens = gr.Slider(0, 512, value=args.max_new_tokens, step=1.0, label="Maximum output tokens", interactive=True)
            top_p = gr.Slider(0, 1, value=args.top_p, step=0.01, label="Top P", interactive=True)
            temperature = gr.Slider(0, 1, value=args.temp, step=0.01, label="Temperature", interactive=True)
            emptyBtn = gr.Button("Clear History")

    # Per-session conversation history (list of ChatMessage objects).
    messages = gr.State([])

    # First handler streams the model reply into the chatbot; the second
    # clears the textbox once the message has been submitted.
    submitBtn.click(
        predict,
        [user_input, chatbot, max_new_tokens, top_p, temperature, messages],
        [chatbot, messages],
        show_progress=True,
    )
    submitBtn.click(reset_user_input, [], [user_input])
    emptyBtn.click(reset_state, outputs=[chatbot, messages], show_progress=True)

demo.queue().launch(share=False, inbrowser=True)