Hilda Cran May mikeee committed on
Commit
fce4951
0 Parent(s):

Duplicate from mikeee/qwen-7b-chat


Co-authored-by: mikeee <[email protected]>

Files changed (7)
  1. .gitattributes +35 -0
  2. .gitignore +1 -0
  3. .ruff.toml +17 -0
  4. README.md +13 -0
  5. app.py +535 -0
  6. example_list.py +56 -0
  7. requirements.txt +19 -0
.gitattributes ADDED
@@ -0,0 +1,35 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1 @@
+ .ruff_cache
.ruff.toml ADDED
@@ -0,0 +1,17 @@
+ # Assume Python 3.10.
+ target-version = "py310"
+ # Raise the maximum line length to 300 characters.
+ line-length = 300
+
+ # pyflakes, pycodestyle, isort
+ # flake8 YTT, pydocstyle D, pylint PLC
+ select = ["F", "E", "W", "I001", "YTT", "D", "PLC"]
+ # select = ["ALL"]
+
+ # D103 Missing docstring in public function
+ # D101 Missing docstring in public class
+ # `multi-line-summary-first-line` (D212)
+ # `one-blank-line-before-class` (D203)
+ extend-ignore = ["D103", "D101", "D212", "D203"]
+
+ exclude = [".venv"]
README.md ADDED
@@ -0,0 +1,13 @@
+ ---
+ title: Qwen 7b Chat
+ emoji: ⚡
+ colorFrom: purple
+ colorTo: yellow
+ sdk: gradio
+ sdk_version: 3.39.0
+ app_file: app.py
+ pinned: false
+ duplicated_from: mikeee/qwen-7b-chat
+ ---
+
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,535 @@
+ """
+ Run qwen 7b chat.
+
+ transformers 4.31.0
+
+ import torch
+ torch.cuda.empty_cache()
+
+ model.chat(
+     tokenizer: transformers.tokenization_utils.PreTrainedTokenizer,
+     query: str,
+     history: Optional[List[Tuple[str, str]]],
+     system: str = 'You are a helpful assistant.',
+     append_history: bool = True,
+     stream: Optional[bool] = <object object at 0x7f905797ec20>,
+     stop_words_ids: Optional[List[List[int]]] = None,
+     **kwargs) -> Tuple[str, List[Tuple[str, str]]]
+ )
+
+ model.generation_config
+ GenerationConfig {
+   "chat_format": "chatml",
+   "do_sample": true,
+   "eos_token_id": 151643,
+   "max_new_tokens": 512,
+   "max_window_size": 6144,
+   "pad_token_id": 151643,
+   "top_k": 0,
+   "top_p": 0.5,
+   "transformers_version": "4.31.0",
+   "trust_remote_code": true
+ }
+ """
+ # pylint: disable=line-too-long, invalid-name, no-member, redefined-outer-name, missing-function-docstring, missing-class-docstring, broad-except,
+ import gc
+ import os
+ import sys
+ import time
+ from collections import deque
+ from dataclasses import asdict, dataclass
+ from textwrap import dedent
+ from types import SimpleNamespace
+ from typing import List, Optional
+
+ import gradio as gr
+ import torch
+ from loguru import logger
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+ from transformers.generation import GenerationConfig
+
+ from example_list import css, example_list
+
+ if not torch.cuda.is_available():
+     raise gr.Error("No CUDA, can't continue...")
+
+ os.environ["TZ"] = "Asia/Shanghai"
+ try:
+     time.tzset()  # type: ignore # pylint: disable=no-member
+ except Exception:
+     # Windows
+     logger.warning("Windows, can't run time.tzset()")
+
+ model_name = "Qwen/Qwen-7B-Chat"
+ tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
+
+ n_gpus = torch.cuda.device_count()
+ try:
+     _ = f"{int(torch.cuda.mem_get_info()[0]/1024**3)-2}GB"
+ except AssertionError:
+     _ = 0
+ max_memory = {i: _ for i in range(n_gpus)}
+
+ del sys
+ # logger.remove()  # to turn on trace
+ # logger.add(sys.stderr, level="TRACE")
+ # logger.trace(f"{chat_history=}")
+
+
+ def gen_model(model_name: str):
+     model = AutoModelForCausalLM.from_pretrained(
+         model_name,
+         trust_remote_code=True,
+         device_map="auto",
+         load_in_4bit=True,
+         max_memory=max_memory,
+         fp16=True,
+         torch_dtype=torch.float16,
+         bnb_4bit_quant_type="nf4",
+         bnb_4bit_compute_dtype=torch.bfloat16,
+     )
+     model = model.eval()
+     model.generation_config = GenerationConfig.from_pretrained(
+         model_name,
+         trust_remote_code=True,
+     )
+     return model
+
+
+ def user_clear(message, chat_history):
+     """Gen a response, clear message in user textbox."""
+     logger.debug(f"{message=}")
+
+     try:
+         chat_history.append([message, ""])
+     except Exception:
+         # start a fresh history of [query, reply] pairs
+         chat_history = deque([[message, ""]], maxlen=5)
+
+     logger.trace(f"{chat_history=}")
+     return "", chat_history
+
+
+ def user(message, chat_history):
+     """Gen a response."""
+     logger.debug(f"{message=}")
+     logger.trace(f"{chat_history=}")
+
+     try:
+         chat_history.append([message, ""])
+     except Exception:
+         # start a fresh history of [query, reply] pairs
+         chat_history = deque([[message, ""]], maxlen=5)
+     return message, chat_history
+
+
+ # for rerun in tests
+ model = None
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ if not torch.cuda.is_available():
+     # raise gr.Error("GPU not available, can't run. Turn on GPU and retry")
+     raise SystemExit("GPU not available, can't run. Turn on GPU and retry")
+
+ model = gen_model(model_name)
+
+
+ def bot(chat_history, **kwargs):
+     try:
+         message = chat_history[-1][0]
+     except Exception as exc:
+         logger.error(f"{chat_history=}: {exc}")
+         return chat_history
+     logger.debug(f"{chat_history=}")
+     try:
+         _ = """
+         response, chat_history = model.chat(
+             tokenizer,
+             message,
+             history=chat_history,
+             temperature=0.7,
+             repetition_penalty=1.2,
+             # max_length=128,
+         )
+         """
+         logger.debug("run model.chat...")
+         model.generation_config.update(**kwargs)
+         response, chat_history = model.chat(
+             tokenizer,
+             message,
+             chat_history[:-1],
+             # **kwargs,
+         )
+         del response
+         return chat_history
+     except Exception as exc:
+         logger.error(exc)
+         chat_history[-1] = [message, str(exc)]  # surface the error as the reply
+         return chat_history
+
+
+ def bot_stream(chat_history, **kwargs):
+     logger.trace(f"{chat_history=}")
+     logger.trace(f"{kwargs=}")
+
+     try:
+         message = chat_history[-1][0]
+     except Exception as exc:
+         logger.error(f"{chat_history=}: {exc}")
+         raise gr.Error(f"{chat_history=}")
+         # yield chat_history
+
+     # for elm in model.chat_stream(tokenizer, message, chat_history):
+     model.generation_config.update(**kwargs)
+     response = ""
+     for elm in model.chat_stream(tokenizer, message, chat_history):
+         chat_history[-1] = [message, elm]
+         response = elm
+         yield chat_history
+     logger.debug(f"{response=}")
+     logger.debug(f"{model.generation_config=}")
+
+
+ SYSTEM_PROMPT = "You are a helpful assistant."
+ MAX_MAX_NEW_TOKENS = 2048  # sequence length 2048
+ MAX_NEW_TOKENS = 256
+
+
+ @dataclass
+ class Config:
+     max_new_tokens: int = MAX_NEW_TOKENS
+     repetition_penalty: float = 1.1
+     temperature: float = 1.0
+     top_k: int = 0
+     top_p: float = 0.9
+
+
+ # stats_default = SimpleNamespace(llm=model, system_prompt=SYSTEM_PROMPT, config=Config())
+ stats_default = SimpleNamespace(llm=None, system_prompt=SYSTEM_PROMPT, config=Config())
+
+
+ # input max_new_tokens temperature repetition_penalty top_k top_p system_prompt history
+ def api_fn(  # pylint: disable=too-many-arguments
+     input_text: Optional[str],
+     # max_length: int = 256,
+     max_new_tokens: int = stats_default.config.max_new_tokens,
+     temperature: float = stats_default.config.temperature,
+     repetition_penalty: float = stats_default.config.repetition_penalty,
+     top_k: int = stats_default.config.top_k,
+     top_p: float = stats_default.config.top_p,
+     system_prompt: Optional[str] = None,
+     history: Optional[List[str]] = None,
+ ):
+     if input_text is None:
+         input_text = ""
+     try:
+         input_text = str(input_text).strip()
+     except Exception as exc:
+         logger.error(exc)
+         input_text = ""
+     if not input_text:
+         return ""
+     if history is None:
+         history = []
+     try:
+         temperature = float(temperature)
+     except Exception:
+         temperature = stats_default.config.temperature
+
+     if system_prompt is None:
+         system_prompt = stats_default.system_prompt
+     # if max_length < 10: max_length = 4096
+     if max_new_tokens < 10:
+         max_new_tokens = stats_default.config.max_new_tokens
+     if top_p < 0.1 or top_p > 1:
+         top_p = stats_default.config.top_p
+     if temperature <= 0.5:
+         temperature = stats_default.config.temperature
+
+     _ = {
+         "max_new_tokens": max_new_tokens,
+         "temperature": temperature,
+         "repetition_penalty": repetition_penalty,
+         "top_k": top_k,
+         "top_p": top_p,
+     }
+     model.generation_config.update(**_)
+     try:
+         res, _ = model.chat(
+             tokenizer,
+             input_text,
+             history=history,
+             # max_length=max_length,
+             append_history=False,
+         )
+         # logger.debug(f"{res=} \n{_=}")
+     except Exception as exc:
+         logger.error(f"{exc=}")
+         res = str(exc)
+
+     logger.debug(f"api {res=}")
+     logger.debug(f"api {model.generation_config=}")
+
+     return res
+
+
+ theme = gr.themes.Soft(text_size="sm")
+ with gr.Blocks(
+     theme=theme,
+     title=model_name.lower(),
+     css=css,
+ ) as block:
+     stats = gr.State(stats_default)
+
+     # would this reset model?
+     model.generation_config = GenerationConfig.from_pretrained(
+         model_name,
+         trust_remote_code=True,
+     )
+     config = asdict(stats.value.config)
+
+     def bot_stream_state(chat_history):
+         logger.trace(f"{chat_history=}")
+         yield from bot_stream(chat_history, **config)
+
+     with gr.Accordion("🎈 Info", open=False):
+         gr.Markdown(
+             dedent(
+                 f"""
+                 ## {model_name.lower()}
+
+                 * Temperature range: 0.51 and up; a higher temperature means more randomness. A temperature around 1.1 is suggested for chatting and creative writing, 0.51-1.0 for summarizing and translation.
+                 * Set `repetition_penalty` to 2.1 or higher for a chatty conversation (more unpredictable and undesirable output). Lower it to 1.1 or less if more focused answers are desired (for example for translations or fact-oriented queries).
+                 * A smaller `top_k` will probably result in smoother sentences.
+                 (`top_k=0` is effectively the same as a very large `top_k`.) Consult the `transformers` documentation for more details.
+                 * An API is available at https://mikeee-qwen-7b-chat.hf.space/ that can be queried, e.g., in python
+                 ```python
+                 from gradio_client import Client
+
+                 client = Client("https://mikeee-qwen-7b-chat.hf.space/")
+
+                 result = client.predict(
+                     "你好!",  # user prompt
+                     256,  # max_new_tokens
+                     1.2,  # temperature
+                     1.1,  # repetition_penalty
+                     0,  # top_k
+                     0.9,  # top_p
+                     "You are a helpful assistant.",  # system_prompt
+                     None,  # history
+                     api_name="/api"
+                 )
+                 print(result)
+                 ```
+                 or in javascript
+                 ```js
+                 import {{ client }} from "@gradio/client";
+
+                 const app = await client("https://mikeee-qwen-7b-chat.hf.space/");
+                 const result = await app.predict("api", [...]);
+                 console.log(result.data);
+                 ```
+                 Check the documentation and examples by clicking `Use via API` at the very bottom of [https://huggingface.co/spaces/mikeee/qwen-7b-chat](https://huggingface.co/spaces/mikeee/qwen-7b-chat).
+
+                 <p></p>
+                 Most examples were written for another model, so you may want to adapt them
+                 or try some related prompts. The system prompt can be changed
+                 under Advanced Options as well."""
+             ),
+             elem_classes="xsmall",
+         )
+
+     chatbot = gr.Chatbot(height=500, value=deque([], maxlen=5))  # type: ignore
+
+     with gr.Row():
+         with gr.Column(scale=5):
+             msg = gr.Textbox(
+                 label="Chat Message Box",
+                 placeholder="Ask me anything (press Shift+Enter or click Submit to send)",
+                 show_label=False,
+                 # container=False,
+                 lines=4,
+                 max_lines=30,
+                 show_copy_button=True,
+                 # ).style(container=False)
+             )
+         with gr.Column(scale=1, min_width=50):
+             with gr.Row():
+                 submit = gr.Button("Submit", elem_classes="xsmall")
+                 stop = gr.Button("Stop", visible=True)
+                 clear = gr.Button("Clear History", visible=True)
+
+     msg_submit_event = msg.submit(
+         # fn=conversation.user_turn,
+         fn=user,
+         inputs=[msg, chatbot],
+         outputs=[msg, chatbot],
+         queue=True,
+         show_progress="full",
+         # api_name=None,
+     ).then(bot_stream_state, chatbot, chatbot, queue=True)
+     submit_click_event = submit.click(
+         # fn=lambda x, y: ("",) + user(x, y)[1:],  # clear msg
+         fn=user_clear,  # clear msg
+         inputs=[msg, chatbot],
+         outputs=[msg, chatbot],
+         queue=True,
+         show_progress="full",
+         # api_name=None,
+     ).then(bot_stream_state, chatbot, chatbot, queue=True)
+     stop.click(
+         fn=None,
+         inputs=None,
+         outputs=None,
+         cancels=[msg_submit_event, submit_click_event],
+         queue=False,
+     )
+     clear.click(lambda: None, None, chatbot, queue=False)
+
+     with gr.Accordion(label="Advanced Options", open=False):
+         system_prompt = gr.Textbox(
+             label="System prompt",
+             value=stats_default.system_prompt,
+             lines=3,
+             visible=True,
+         )
+         max_new_tokens = gr.Slider(
+             label="Max new tokens",
+             minimum=1,
+             maximum=MAX_MAX_NEW_TOKENS,
+             step=1,
+             value=stats_default.config.max_new_tokens,
+         )
+         repetition_penalty = gr.Slider(
+             label="Repetition penalty",
+             minimum=0.1,
+             maximum=40.0,
+             step=0.1,
+             value=stats_default.config.repetition_penalty,
+         )
+         temperature = gr.Slider(
+             label="Temperature",
+             minimum=0.51,
+             maximum=40.0,
+             step=0.1,
+             value=stats_default.config.temperature,
+         )
+         top_p = gr.Slider(
+             label="Top-p (nucleus sampling)",
+             minimum=0.05,
+             maximum=1.0,
+             step=0.05,
+             value=stats_default.config.top_p,
+         )
+         top_k = gr.Slider(
+             label="Top-k",
+             minimum=0,
+             maximum=1000,
+             step=1,
+             value=stats_default.config.top_k,
+         )
+
+         def system_prompt_fn(system_prompt):
+             stats.value.system_prompt = system_prompt
+             logger.debug(f"{stats.value.system_prompt=}")
+
+         def max_new_tokens_fn(max_new_tokens):
+             stats.value.config.max_new_tokens = max_new_tokens
+             logger.debug(f"{stats.value.config.max_new_tokens=}")
+
+         def repetition_penalty_fn(repetition_penalty):
+             stats.value.config.repetition_penalty = repetition_penalty
+             logger.debug(f"{stats.value=}")
+
+         def temperature_fn(temperature):
+             stats.value.config.temperature = temperature
+             logger.debug(f"{stats.value=}")
+
+         def top_p_fn(top_p):
+             stats.value.config.top_p = top_p
+             logger.debug(f"{stats.value=}")
+
+         def top_k_fn(top_k):
+             stats.value.config.top_k = top_k
+             logger.debug(f"{stats.value=}")
+
+         system_prompt.change(system_prompt_fn, system_prompt)
+         max_new_tokens.change(max_new_tokens_fn, max_new_tokens)
+         repetition_penalty.change(repetition_penalty_fn, repetition_penalty)
+         temperature.change(temperature_fn, temperature)
+         top_p.change(top_p_fn, top_p)
+         top_k.change(top_k_fn, top_k)
+
+         def reset_fn(stats_):
+             logger.debug("reset_fn")
+             stats_ = gr.State(stats_default)
+             logger.debug(f"{stats_.value=}")
+             return (
+                 stats_,
+                 stats_default.system_prompt,
+                 stats_default.config.max_new_tokens,
+                 stats_default.config.repetition_penalty,
+                 stats_default.config.temperature,
+                 stats_default.config.top_p,
+                 stats_default.config.top_k,
+             )
+
+         reset_btn = gr.Button("Reset")
+         reset_btn.click(
+             reset_fn,
+             stats,
+             [
+                 stats,
+                 system_prompt,
+                 max_new_tokens,
+                 repetition_penalty,
+                 temperature,
+                 top_p,
+                 top_k,
+             ],
+         )
+
+     with gr.Accordion("Example inputs", open=True):
+         etext = """In America, where cars are an important part of the national psyche, a decade ago people had suddenly started to drive less, which had not happened since the oil shocks of the 1970s. """
+         examples = gr.Examples(
+             examples=example_list,
+             inputs=[msg],
+             examples_per_page=60,
+         )
+     with gr.Accordion("Disclaimer", open=False):
+         _ = model_name.lower()
+         gr.Markdown(
+             f"Disclaimer: {_} can produce factually incorrect output, and should not be relied on to produce "
+             f"factually accurate information. {_} was trained on various public datasets; while great efforts "
+             "have been taken to clean the pretraining data, it is possible that this model could generate lewd, "
+             "biased, or otherwise offensive outputs.",
+             elem_classes=["disclaimer"],
+         )
+
+     with gr.Accordion("For Chat/Translation API", open=False, visible=False):
+         input_text = gr.Text()
+         api_history = gr.Chatbot(value=[])
+         api_btn = gr.Button("Go", variant="primary")
+         out_text = gr.Text()
+
+     # api_fn args order
+     # input_text max_new_tokens temperature repetition_penalty top_k top_p system_prompt history
+     api_btn.click(
+         api_fn,
+         [
+             input_text,
+             max_new_tokens,
+             temperature,
+             repetition_penalty,
+             top_k,
+             top_p,
+             system_prompt,
+             api_history,  # don't know how to pass this in gradio_client.Client calls
+         ],
+         out_text,
+         api_name="api",
+     )
+
+
+ if __name__ == "__main__":
+     logger.info("Just record start time")
+     block.queue(max_size=8).launch(debug=True)
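
For reference, here is a condensed sketch of the load-and-chat flow that app.py wraps in the Gradio UI. It mirrors `gen_model()` and the `model.chat` signature quoted in the module docstring, and assumes a CUDA GPU with transformers==4.31.0 and bitsandbytes installed (this is a sketch, not part of the commit).

```python
# Minimal sketch of the same load-and-chat flow as gen_model()/bot() above.
# Assumes a CUDA GPU, transformers==4.31.0, and bitsandbytes for the 4-bit load.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation import GenerationConfig

model_name = "Qwen/Qwen-7B-Chat"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    trust_remote_code=True,
    device_map="auto",
    load_in_4bit=True,                      # NF4 4-bit quantization via bitsandbytes
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
).eval()
model.generation_config = GenerationConfig.from_pretrained(
    model_name, trust_remote_code=True
)

# model.chat returns (response, updated_history); history is a list of (query, reply) pairs.
response, history = model.chat(tokenizer, "你好!", history=None)
print(response)
```

The Space's callbacks (`bot_stream`, `api_fn`) wrap essentially this call, updating `model.generation_config` from the slider values first.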
example_list.py ADDED
@@ -0,0 +1,56 @@
+ """Define example_list and css."""
+ # pylint: disable=invalid-name, line-too-long,
+ css = """
+ .importantButton {
+     background: linear-gradient(45deg, #7e0570,#5d1c99, #6e00ff) !important;
+     border: none !important;
+ }
+ .importantButton:hover {
+     background: linear-gradient(45deg, #ff00e0,#8500ff, #6e00ff) !important;
+     border: none !important;
+ }
+ .disclaimer {font-variant-caps: all-small-caps; font-size: xx-small;}
+ .xsmall {font-size: x-small;}
+ """
+
+ etext = """In America, where cars are an important part of the national psyche, a decade ago people had suddenly started to drive less, which had not happened since the oil shocks of the 1970s. """
+ example_list = [
+     ["What NFL team won the Super Bowl in the year Justin Bieber was born?"],
+     [
+         "What NFL team won the Super Bowl in the year Justin Bieber was born? Think step by step."
+     ],
+     ["How to pick a lock? Provide detailed steps."],
+     [
+         "If it takes 10 hours to dry 10 clothes, assuming all the clothes are hung together at the same time for drying, then how long will it take to dry a cloth?"
+     ],
+     [
+         "If it takes 10 hours to dry 10 clothes, assuming all the clothes are hung together at the same time for drying, then how long will it take to dry 23 clothes? Think step by step."
+     ],
+     ["is infinity + 1 bigger than infinity?"],
+     ["Explain the plot of Cinderella in a sentence."],
+     [
+         "How long does it take to become proficient in French, and what are the best methods for retaining information?"
+     ],
+     ["What are some common mistakes to avoid when writing code?"],
+     ["Build a prompt to generate a beautiful portrait of a horse"],
+     ["Suggest four metaphors to describe the benefits of AI"],
+     ["Write a pop song about leaving home for the sandy beaches."],
+     ["Write a summary demonstrating my ability to tame lions"],
+     ["鲁迅和周树人什么关系"],
+     ["从前有一头牛,这头牛后面有什么?"],
+     ["正无穷大加一大于正无穷大吗?"],
+     ["正无穷大加正无穷大大于正无穷大吗?"],
+     ["-2的平方根等于什么"],
+     ["树上有5只鸟,猎人开枪打死了一只。树上还有几只鸟?"],
+     ["树上有11只鸟,猎人开枪打死了一只。树上还有几只鸟?提示:需考虑鸟可能受惊吓飞走。"],
+     ["鲁迅和周树人什么关系 用英文回答"],
+     ["以红楼梦的行文风格写一张委婉的请假条。不少于320字。"],
+     [f"{etext} 翻成中文,列出3个版本"],
+     [f"{etext} \n 翻成中文,保留原意,但使用文学性的语言。不要写解释。列出3个版本"],
+     ["js 判断一个数是不是质数"],
+     ["js 实现python 的 range(10)"],
+     ["js 实现python 的 [*(range(10)]"],
+     ["假定 1 + 2 = 4, 试求 7 + 8"],
+     ["Erkläre die Handlung von Cinderella in einem Satz."],
+     ["Erkläre die Handlung von Cinderella in einem Satz. Auf Deutsch"],
+ ]
requirements.txt ADDED
@@ -0,0 +1,19 @@
+ transformers==4.31.0
+ accelerate
+ tiktoken
+ einops
+
+ # flash-attention
+ # git clone -b v1.0.8 https://github.com/Dao-AILab/flash-attention
+ # cd flash-attention && pip install .
+ # pip install csrc/layer_norm
+ # pip install csrc/rotary
+
+ torch  # 2.0.1
+ safetensors
+ bitsandbytes
+ transformers_stream_generator
+ scipy
+
+ loguru
+ about-time
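
An optional sanity check before launching app.py locally might look like the sketch below; it only uses the package names pinned in this requirements.txt and the same `torch.cuda.mem_get_info()` call app.py uses when computing `max_memory` (a sketch, not part of the Space).

```python
# Optional environment check before launching app.py; a sketch, not part of the Space.
import importlib.metadata as md

import torch

for pkg in ("transformers", "accelerate", "bitsandbytes", "tiktoken", "einops", "loguru"):
    try:
        print(f"{pkg}=={md.version(pkg)}")
    except md.PackageNotFoundError:
        print(f"{pkg} missing -- run `pip install -r requirements.txt`")

print("CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    free, total = torch.cuda.mem_get_info()  # same call app.py uses for max_memory
    print(f"free GPU memory: {free / 1024**3:.1f} / {total / 1024**3:.1f} GB")
```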