artificialguybr committed
Commit dcf6e59 · 1 Parent(s): 39d4f12

Update app.py

Files changed (1)
  1. app.py +61 -123
app.py CHANGED
@@ -1,134 +1,72 @@
-import os
 import gradio as gr
-import mdtex2html
-import torch
-from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline, MistralConfig

-# Initialize model and tokenizer
 model_name_or_path = "teknium/OpenHermes-2-Mistral-7B"

-model = AutoModelForCausalLM.from_pretrained(model_name_or_path,
-                                             device_map="auto",
-                                             trust_remote_code=False,
-                                             load_in_8bit=True,
-                                             revision="main")

-tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=True)
-config = MistralConfig()

-# Text parsing function
-def _parse_text(text):
-    lines = text.split("\n")
-    lines = [line for line in lines if line != ""]
-    count = 0
-    for i, line in enumerate(lines):
-        if "```" in line:
-            count += 1
-            items = line.split("`")
-            if count % 2 == 1:
-                lines[i] = f'<pre><code class="language-{items[-1]}">'
-            else:
-                lines[i] = f"<br></code></pre>"
-        else:
-            if i > 0:
-                if count % 2 == 1:
-                    line = line.replace("`", r"\`")
-                    line = line.replace("<", "&lt;")
-                    line = line.replace(">", "&gt;")
-                    line = line.replace(" ", "&nbsp;")
-                    line = line.replace("*", "&ast;")
-                    line = line.replace("_", "&lowbar;")
-                    line = line.replace("-", "&#45;")
-                    line = line.replace(".", "&#46;")
-                    line = line.replace("!", "&#33;")
-                    line = line.replace("(", "&#40;")
-                    line = line.replace(")", "&#41;")
-                    line = line.replace("$", "&#36;")
-                lines[i] = "<br>" + line
-    text = "".join(lines)
-    return text

-# Demo launching function
-def _launch_demo(args, model, tokenizer, config):
-    def predict(_query, _chatbot, _task_history):
-        print(f"User: {_parse_text(_query)}")
-        _chatbot.append((_parse_text(_query), ""))
-
-        # Prepare the chat template
-        messages = [
-            {"role": "system", "content": "You are Hermes 2."},
-            {"role": "user", "content": _query}
-        ]
-
-        # Tokenize using the chat template
-        gen_input = tokenizer.apply_chat_template(messages, return_tensors="pt")
-
-        # Debug: print the type and value of gen_input
-        print("Debug: ", type(gen_input), gen_input)
-
-        # If gen_input is a dictionary, move it to CUDA
-        if isinstance(gen_input, dict):
-            gen_input = {k: v.to('cuda') for k, v in gen_input.items()}
-        else:
-            gen_input = gen_input.to('cuda')
-
-        # Generate a response using the model
-        generated_ids = model.generate(**gen_input, max_length=300) if isinstance(gen_input, dict) else model.generate(gen_input, max_length=300)
-
-        # Decode the generated IDs to text
-        full_response = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
-
-        # Update the chatbot state
-        _chatbot[-1] = (_parse_text(_query), _parse_text(full_response))
-        yield _chatbot
-
-        print(f"History: {_task_history}")
-        _task_history.append((_query, full_response))
-        print(f"OpenHermes: {_parse_text(full_response)}")
-
-    def regenerate(_chatbot, _task_history):
-        if not _task_history:
-            yield _chatbot
-            return
-        item = _task_history.pop(-1)
-        _chatbot.pop(-1)
-        yield from predict(item[0], _chatbot, _task_history)
-
-    def reset_user_input():
-        return gr.update(value="")
-
-    def reset_state(_chatbot, _task_history):
-        _task_history.clear()
-        _chatbot.clear()
-        import gc
-        gc.collect()
-        torch.cuda.empty_cache()
-        return _chatbot
-
-    with gr.Blocks() as demo:
-        gr.Markdown("""
-        ## OpenHermes V2 - Mistral 7B: Mistral 7B Based by Teknium!
-        **Space created by [@artificialguybr](https://twitter.com/artificialguybr). Model by [@Teknium1](https://twitter.com/Teknium1). Thanks HF for GPU!**
-        **OpenHermes V2 Mistral 7B was trained on 900,000 instructions, and surpasses all previous versions of Hermes 13B and below, and matches 70B on some benchmarks!**
-        """)
-        chatbot = gr.Chatbot(label='OpenHermes-V2', elem_classes="control-height", queue=True)
-        query = gr.Textbox(lines=2, label='Input')
-        task_history = gr.State([])
-
         with gr.Row():
-            submit_btn = gr.Button("🚀 Submit")
-            empty_btn = gr.Button("🧹 Clear History")
-            regen_btn = gr.Button("🤔️ Regenerate")
-
-        submit_btn.click(predict, [query, chatbot, task_history], [chatbot], show_progress=True, queue=True)  # enable queue
-        submit_btn.click(reset_user_input, [], [query], queue=False)  # no queue for resetting
-        empty_btn.click(reset_state, [chatbot, task_history], outputs=[chatbot], show_progress=True, queue=False)  # no queue for clearing
-        regen_btn.click(regenerate, [chatbot, task_history], [chatbot], show_progress=True, queue=True)  # enable queue
-    demo.queue(max_size=20)
-    demo.launch()

-# Main execution
-if __name__ == "__main__":
-    _launch_demo(None, model, tokenizer, config)
 import gradio as gr
+import re
+from transformers import AutoModelForCausalLM, AutoTokenizer

 model_name_or_path = "teknium/OpenHermes-2-Mistral-7B"
+model = AutoModelForCausalLM.from_pretrained(model_name_or_path)
+tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)

+BASE_SYSTEM_MESSAGE = "I carefully provide accurate, factual, thoughtful, nuanced answers and am brilliant at reasoning."

+def make_prediction(prompt, max_tokens=None, temperature=None, top_p=None, top_k=None, repetition_penalty=None):
+    # Generate in one pass and yield the full decoded text (prompt included)
+    input_ids = tokenizer.encode(prompt, return_tensors="pt")
+    out = model.generate(input_ids, max_length=max_tokens, temperature=temperature, top_p=top_p, top_k=top_k, repetition_penalty=repetition_penalty)
+    text = tokenizer.decode(out[0], skip_special_tokens=True)
+    yield text

+def clear_chat(chat_history_state, chat_message):
+    chat_history_state = []
+    chat_message = ''
+    return chat_history_state, chat_message

+def user(message, history):
+    history = history or []
+    history.append([message, ""])
+    return "", history

+def chat(history, system_message, max_tokens, temperature, top_p, top_k, repetition_penalty):
+    history = history or []
+    # Build a ChatML prompt; <|im_start|>/<|im_end|> delimit each turn
+    if system_message.strip():
+        messages = "<|im_start|>system\n" + system_message.strip() + "<|im_end|>\n" + "\n".join(["\n".join(["<|im_start|>user\n" + item[0] + "<|im_end|>", "<|im_start|>assistant\n" + item[1] + "<|im_end|>"]) for item in history])
+    else:
+        messages = "<|im_start|>system\n" + BASE_SYSTEM_MESSAGE + "<|im_end|>\n" + "\n".join(["\n".join(["<|im_start|>user\n" + item[0] + "<|im_end|>", "<|im_start|>assistant\n" + item[1] + "<|im_end|>"]) for item in history])
+    messages = messages.rstrip()
+    if temperature == 0:
+        top_p = 1
+        top_k = -1
+    prediction = make_prediction(messages, max_tokens=max_tokens, temperature=temperature, top_p=top_p, top_k=top_k, repetition_penalty=repetition_penalty)
+    for tokens in prediction:
+        tokens = re.findall(r'(.*?)(\s|$)', tokens)
+        for subtoken in tokens:
+            subtoken = "".join(subtoken)
+            answer = subtoken
+            history[-1][1] += answer
+            yield history, history, ""

+with gr.Blocks() as demo:
+    with gr.Row():
+        with gr.Column():
+            gr.Markdown(f"""## Mistral-7B-OpenOrca Playground Space!""")
+    with gr.Row():
+        chatbot = gr.Chatbot(elem_id="chatbot")
+    with gr.Row():
+        message = gr.Textbox(label="What do you want to chat about?", placeholder="Ask me anything.", lines=3)
+    with gr.Row():
+        submit = gr.Button(value="Send message", variant="secondary")
+        clear = gr.Button(value="New topic", variant="secondary")
+    with gr.Accordion("Show Model Parameters", open=False):
         with gr.Row():
+            with gr.Column():
+                max_tokens = gr.Slider(20, 2500, step=20, value=500)
+                temperature = gr.Slider(0.0, 2.0, step=0.1, value=0.4)
+                top_p = gr.Slider(0.0, 1.0, step=0.05, value=0.95)
+                top_k = gr.Slider(1, 100, step=1, value=40)
+                repetition_penalty = gr.Slider(1.0, 2.0, step=0.1, value=1.1)
+        system_msg = gr.Textbox(BASE_SYSTEM_MESSAGE, lines=5)

+    chat_history_state = gr.State()
+    clear.click(clear_chat, inputs=[chat_history_state, message], outputs=[chat_history_state, message], queue=False)
+    clear.click(lambda: None, None, chatbot, queue=False)
+    submit_click_event = submit.click(fn=user, inputs=[message, chat_history_state], outputs=[message, chat_history_state], queue=True).then(fn=chat, inputs=[chat_history_state, system_msg, max_tokens, temperature, top_p, top_k, repetition_penalty], outputs=[chatbot, chat_history_state, message], queue=True)
+demo.queue(max_size=128, concurrency_count=48)
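
One thing the new version gives up is the old predict path's use of tokenizer.apply_chat_template: the ChatML prompt is now concatenated by hand inside chat. A minimal sketch of producing the same prompt from the tokenizer's bundled template is below; build_prompt is a hypothetical helper, and it assumes the teknium/OpenHermes-2-Mistral-7B tokenizer ships a ChatML chat template.

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("teknium/OpenHermes-2-Mistral-7B")

def build_prompt(history, system_message):
    # history holds [user_text, assistant_text] pairs, as in chat() above
    messages = [{"role": "system", "content": system_message.strip()}]
    for user_text, assistant_text in history:
        messages.append({"role": "user", "content": user_text})
        if assistant_text:  # skip the empty slot for the turn being generated
            messages.append({"role": "assistant", "content": assistant_text})
    # add_generation_prompt=True appends the assistant header so the model
    # continues the conversation as the assistant
    return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

This keeps the special tokens in one place (the tokenizer config) rather than scattered across string literals, which is harder to get wrong if the template ever changes.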
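make_prediction also decodes the entire generation in one shot, prompt included, and chat only simulates streaming by regex-splitting the finished text; note too that temperature, top_p, and top_k have no effect unless generate is called with do_sample=True. A streaming variant using transformers' TextIteratorStreamer could look like the sketch below, reusing the model and tokenizer defined above; the parameter defaults are illustrative, not from the commit.

from threading import Thread
from transformers import TextIteratorStreamer

def make_prediction(prompt, max_tokens=500, temperature=0.4, top_p=0.95, top_k=40, repetition_penalty=1.1):
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    # skip_prompt=True keeps the echoed prompt out of the chat window
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=max_tokens,   # counts only generated tokens, unlike max_length
        do_sample=temperature > 0,   # sampling knobs are ignored without do_sample=True
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        repetition_penalty=repetition_penalty,
    )
    # generate() blocks, so run it on a worker thread and yield text as it arrives
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()
    for new_text in streamer:
        yield new_text
    thread.join()

chat would then append each yielded piece to history[-1][1] and yield the updated history, leaving the Gradio wiring unchanged.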