File size: 3,943 Bytes
dd48380
 
 
 
 
 
 
 
 
 
affab3c
dd48380
 
 
 
 
 
 
fbcf846
73bf78a
 
fbcf846
73bf78a
dd48380
 
 
 
 
 
73bf78a
dd48380
 
 
 
 
 
 
 
fbcf846
dd48380
 
73bf78a
dd48380
 
fbcf846
dd48380
 
 
 
 
 
 
 
 
 
 
73bf78a
 
 
 
 
 
 
dd48380
 
 
 
 
 
 
 
 
 
 
73bf78a
317f842
73bf78a
 
 
 
 
dd48380
 
73bf78a
dd48380
c4e3ff4
dd48380
 
 
 
 
fbcf846
dd48380
 
 
 
 
 
 
 
fbcf846
dd48380
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import argparse
import urllib.request
from pathlib import Path

import chatglm_cpp
import gradio as gr

# Local filename the quantized model weights are stored under.
DEFAULT_MODEL_PATH = "chatglm3-6b.bin"

# Fetch the weights only when they are not already on disk, so restarting
# the app does not re-download a multi-GB file. NOTE: `urllib.request` must
# be imported explicitly — a bare `import urllib` does not guarantee the
# `request` submodule is loaded.
if not Path(DEFAULT_MODEL_PATH).exists():
    urllib.request.urlretrieve(
        "https://huggingface.co/Braddy/chatglm3-6b-chitchat/resolve/main/q5_1.bin?download=true",
        DEFAULT_MODEL_PATH,
    )

parser = argparse.ArgumentParser()
parser.add_argument("-m", "--model", default=DEFAULT_MODEL_PATH, type=Path, help="model path")
parser.add_argument("--mode", default="chat", type=str, choices=["chat", "generate"], help="inference mode")
parser.add_argument("-l", "--max_new_tokens", default=64, type=int, help="max total output tokens")
parser.add_argument("-c", "--max_context_length", default=1024, type=int, help="max context length")
parser.add_argument("--top_k", default=40, type=int, help="top-k sampling")
parser.add_argument("--top_p", default=0.75, type=float, help="top-p sampling")
parser.add_argument("--temp", default=0.5, type=float, help="temperature")
parser.add_argument("--repeat_penalty", default=1.0, type=float, help="penalize repeat sequence of tokens")
parser.add_argument("-t", "--threads", default=0, type=int, help="number of threads for inference")
parser.add_argument("--plain", action="store_true", help="display in plain text without markdown support")
args = parser.parse_args()

# Inference pipeline plus the fixed persona system prompt prepended to every chat.
pipeline = chatglm_cpp.Pipeline(args.model)
system_message = chatglm_cpp.ChatMessage(role="system", content="请你现在扮演一个软件工程师,名字叫做贺英旭。你需要以这个身份和朋友们对话。")


def postprocess(text):
    """Return *text* prepared for the chatbot widget.

    In ``--plain`` mode the text is fenced in a ``<pre>`` element so the
    widget shows it verbatim instead of rendering markdown.
    """
    return f"<pre>{text}</pre>" if args.plain else text


def predict(input, chatbot, max_new_tokens, top_p, temperature, messages):
    """Stream a model reply to *input*, updating the transcript live.

    Appends the user turn to both the visible *chatbot* transcript and the
    *messages* history, then yields ``(chatbot, messages)`` after every
    streamed chunk so Gradio re-renders the partial response.
    """
    # Record the user's turn in the UI and in the session history.
    chatbot.append((postprocess(input), ""))
    messages.append(chatglm_cpp.ChatMessage(role="user", content=input))

    # The persona system prompt is prepended per call, never stored in history.
    conversation = [system_message] + messages

    options = {
        "max_new_tokens": max_new_tokens,
        "max_context_length": args.max_context_length,
        # Sampling is disabled entirely at temperature 0 (greedy decoding).
        "do_sample": temperature > 0,
        "top_k": args.top_k,
        "top_p": top_p,
        "temperature": temperature,
        "repetition_penalty": args.repeat_penalty,
        "num_threads": args.threads,
        "stream": True,
    }

    partial = ""
    pieces = []
    for piece in pipeline.chat(conversation, **options):
        partial += piece.content
        pieces.append(piece)
        # Replace the placeholder reply with the text streamed so far.
        chatbot[-1] = (chatbot[-1][0], postprocess(partial))
        yield chatbot, messages

    # Fold the streamed chunks into one assistant message for the history.
    messages.append(pipeline.merge_streaming_messages(pieces))
    yield chatbot, messages


def reset_user_input():
    """Return the update that clears the input textbox after a submit."""
    return gr.update(value="")


def reset_state():
    """Return empty transcript and history, wiping the conversation."""
    fresh_transcript = []
    fresh_history = []
    return fresh_transcript, fresh_history

# Static HTML banner rendered above the chat widget.
# Fixes user-facing typos: "Chichat" -> "Chitchat" (matches the model repo
# name) and "Free feel" -> "Feel free".
title = """
<div style="text-align: center">
    <h1>Chitchat</h1>
    <p style="text-align: center;">Feel free to talk about anything :)</p>
</div>
"""


with gr.Blocks() as demo:
    gr.HTML(title)

    chatbot = gr.Chatbot(height=300, label="Check this out!")
    with gr.Row():
        # Wide column: message entry plus the submit button.
        with gr.Column(scale=4):
            user_input = gr.Textbox(show_label=False, placeholder="Input...", lines=8)
            submit_btn = gr.Button("Submit", variant="primary")
        # Narrow column: sampling controls seeded from the CLI defaults.
        with gr.Column(scale=1):
            max_new_tokens = gr.Slider(0, 512, value=args.max_new_tokens, step=1.0, label="Maximum output tokens", interactive=True)
            top_p = gr.Slider(0, 1, value=args.top_p, step=0.01, label="Top P", interactive=True)
            temperature = gr.Slider(0, 1, value=args.temp, step=0.01, label="Temperature", interactive=True)
            clear_btn = gr.Button("Clear History")

    # Per-session message history (the system prompt is not stored here).
    messages = gr.State([])

    submit_btn.click(
        predict,
        [user_input, chatbot, max_new_tokens, top_p, temperature, messages],
        [chatbot, messages],
        show_progress=True,
    )
    # A second handler on the same click clears the textbox.
    submit_btn.click(reset_user_input, [], [user_input])

    clear_btn.click(reset_state, outputs=[chatbot, messages], show_progress=True)

# share=False keeps the app local-only; inbrowser=True opens a browser tab.
demo.queue().launch(share=False, inbrowser=True)