wangrongsheng committed on
Commit
cd7ddd1
·
1 Parent(s): 4223a6d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +27 -37
app.py CHANGED
@@ -1,40 +1,30 @@
1
- from transformers import AutoModel, AutoTokenizer
2
  import gradio as gr
3
 
 
 
 
 
 
4
  tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
5
# Load ChatGLM-6B in fp16 on the GPU and switch straight to inference mode
# (.eval() returns the module itself, so the chain is equivalent to a
# separate `model = model.eval()` assignment).
model = AutoModel.from_pretrained(
    "THUDM/chatglm-6b", trust_remote_code=True
).half().cuda().eval()

# Display at most 20 dialogue turns; every turn occupies two boxes
# (one user query + one model reply).
MAX_TURNS = 20
MAX_BOXES = 2 * MAX_TURNS
10
-
11
-
12
def predict(input, history=None):
    """Run one chat turn and rebuild the fixed grid of chat text boxes.

    Args:
        input: The user's new query string (name kept for interface
            compatibility even though it shadows the builtin).
        history: List of (query, response) pairs from prior turns, or None
            on the first call.

    Returns:
        A list whose first element is the updated history (fed back into the
        gr.State) followed by MAX_BOXES gr.update payloads — visible boxes
        for each past query/response, hidden updates for the remainder.
    """
    if history is None:
        history = []
    # The immediate response is already included in the returned history,
    # so the first element of the tuple is not needed separately.
    _, history = model.chat(tokenizer, input, history)
    updates = []
    for past_query, past_response in history:
        updates.append(gr.update(visible=True, value="用户:" + past_query))
        updates.append(gr.update(visible=True, value="ChatGLM-6B:" + past_response))
    # Hide the unused boxes so the layout always has exactly MAX_BOXES slots.
    # (Use gr.update consistently; gr.Textbox.update is the deprecated form.)
    if len(updates) < MAX_BOXES:
        updates = updates + [gr.update(visible=False)] * (MAX_BOXES - len(updates))
    return [history] + updates
23
-
24
-
25
with gr.Blocks() as demo:
    # Conversation history lives in a Gradio State value shared across turns.
    state = gr.State([])

    # Pre-create MAX_BOXES hidden Markdown slots: even indices are user
    # queries ("提问:"), odd indices are model replies ("回复:").
    text_boxes = [
        gr.Markdown(visible=False, label="提问:" if idx % 2 == 0 else "回复:")
        for idx in range(MAX_BOXES)
    ]

    with gr.Row():
        with gr.Column(scale=4):
            txt = gr.Textbox(
                show_label=False,
                placeholder="Enter text and press enter",
            ).style(container=False)
        with gr.Column(scale=1):
            button = gr.Button("Generate")

    # Clicking feeds the textbox plus the state into predict, whose outputs
    # update the state and every chat box.
    button.click(predict, [txt, state], [state] + text_boxes)

demo.queue().launch(share=False)
 
1
+ import psutil
2
  import gradio as gr
3
 
4
+ from functools import partial
5
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
6
+
7
# Snapshot of host virtual memory. NOTE(review): `mem` is never read later
# in this file — confirm it isn't needed (e.g. for a capacity banner) before
# removing it and the psutil import.
mem = psutil.virtual_memory()

# trust_remote_code=True is required: ChatGLM ships its own modeling code.
tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)

# bfloat16 halves memory while keeping the float32 exponent range, suiting
# CPU inference. Restore .eval() (present in the previous revision) so
# dropout/training-only layers are disabled during inference.
model = AutoModelForSeq2SeqLM.from_pretrained(
    "THUDM/chatglm-6b", trust_remote_code=True
).bfloat16().eval()
12
+
13
def chat(query, history=None):
    """Run one chat turn against ChatGLM-6B (CPU, max_length capped at 512).

    Args:
        query: The user's new message.
        history: List of (query, response) pairs from earlier turns; None on
            the first call. A None sentinel replaces the original mutable
            default `history=[]`, which was shared across calls and leaked
            conversation history between independent sessions.

    Returns:
        (history, history) — the updated history twice, matching the
        Interface's ["chatbot", "state"] outputs.
    """
    if history is None:
        history = []
    # The fresh response is already folded into the returned history.
    _, history = model.chat(tokenizer, query, history, max_length=512)
    return history, history
16
+
17
# --- UI configuration -------------------------------------------------------
title = "ChatGLM-6B Chatbot"

# User-facing blurb (implicit concatenation keeps the text identical).
description = (
    "This is an unofficial chatbot application based on open source model "
    "ChatGLM-6B(https://github.com/THUDM/ChatGLM-6B), running on cpu"
    "(therefore max_length is limited to 512). \n"
    "If you want to use this chat bot in your space, 'Duplicate this space' "
    "by click the button close to 'Linked Models'. \n"
)

# Seed examples shown below the input box.
examples = [["Hello?"], ["你好。"], ["介绍清华"]]

# Wire the chat function into a simple text-in / chatbot-out interface;
# the "state" slots carry the conversation history across turns.
chatbot_interface = gr.Interface(
    fn=chat,
    inputs=["text", "state"],
    outputs=["chatbot", "state"],
    title=title,
    description=description,
    examples=examples,
)

chatbot_interface.launch()