"""Run Qwen-7B-Chat.

transformers 4.31.0

import torch
torch.cuda.empty_cache()

model.chat(
    tokenizer: transformers.tokenization_utils.PreTrainedTokenizer,
    query: str,
    history: Optional[List[Tuple[str, str]]],
    system: str = 'You are a helpful assistant.',
    append_history: bool = True,
    stream: Optional[bool] = <object object at 0x7f905797ec20>,
    stop_words_ids: Optional[List[List[int]]] = None,
    **kwargs,
) -> Tuple[str, List[Tuple[str, str]]]

model.generation_config
GenerationConfig {
    "chat_format": "chatml",
    "do_sample": true,
    "eos_token_id": 151643,
    "max_new_tokens": 512,
    "max_window_size": 6144,
    "pad_token_id": 151643,
    "top_k": 0,
    "top_p": 0.5,
    "transformers_version": "4.31.0",
    "trust_remote_code": true
}
"""

import gc
import os
import sys
import time
from collections import deque
from dataclasses import asdict, dataclass
from textwrap import dedent
from types import SimpleNamespace
from typing import List, Optional

import gradio as gr
import torch
from loguru import logger
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation import GenerationConfig

from example_list import css, example_list

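
# A minimal sketch of the `model.chat` API documented in the module docstring
# above. It is never called by this app; the prompt, the function name, and the
# idea of loading a second copy of the model here are purely illustrative.
def _chat_api_sketch() -> None:
    name = "Qwen/Qwen-7B-Chat"
    tok = AutoTokenizer.from_pretrained(name, trust_remote_code=True)
    mdl = AutoModelForCausalLM.from_pretrained(
        name, trust_remote_code=True, device_map="auto"
    ).eval()
    mdl.generation_config = GenerationConfig.from_pretrained(
        name, trust_remote_code=True
    )
    # history=None starts a fresh conversation; chat() returns the reply plus
    # the updated history as a list of (query, reply) pairs.
    reply, history = mdl.chat(tok, "Hello!", history=None)
    print(reply)
    print(history)
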
if not torch.cuda.is_available():
    raise gr.Error("No CUDA device available, can't continue...")

os.environ["TZ"] = "Asia/Shanghai"
try:
    time.tzset()
except Exception:
    # time.tzset() is not available on Windows
    logger.warning("Windows, can't run time.tzset()")

model_name = "Qwen/Qwen-7B-Chat"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

# leave roughly 2 GB of headroom per GPU when sharding the model
n_gpus = torch.cuda.device_count()
try:
    _ = f"{int(torch.cuda.mem_get_info()[0] / 1024**3) - 2}GB"
except AssertionError:
    _ = 0
max_memory = {i: _ for i in range(n_gpus)}

del sys


def gen_model(model_name: str):
    """Load Qwen-7B-Chat in 4-bit (NF4) on the available GPU(s)."""
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        trust_remote_code=True,
        device_map="auto",
        load_in_4bit=True,
        max_memory=max_memory,
        fp16=True,
        torch_dtype=torch.float16,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )
    model = model.eval()
    model.generation_config = GenerationConfig.from_pretrained(
        model_name,
        trust_remote_code=True,
    )
    return model

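
# A sketch of an alternative loader (not used by this app) that expresses the
# same 4-bit NF4 setup through transformers' BitsAndBytesConfig instead of the
# bare from_pretrained kwargs used in gen_model above. The function name is
# illustrative; it assumes the bitsandbytes package is installed.
def gen_model_bnb(model_name: str):
    from transformers import BitsAndBytesConfig

    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        trust_remote_code=True,
        device_map="auto",
        max_memory=max_memory,
        quantization_config=bnb_config,
    ).eval()
    model.generation_config = GenerationConfig.from_pretrained(
        model_name,
        trust_remote_code=True,
    )
    return model
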
def user_clear(message, chat_history):
    """Append the user message to the history and clear the textbox."""
    logger.debug(f"{message=}")
    try:
        chat_history.append([message, ""])
    except Exception:
        chat_history = deque([[message, ""]], maxlen=5)
    logger.trace(f"{chat_history=}")
    return "", chat_history


def user(message, chat_history):
    """Append the user message to the history (the textbox content is kept)."""
    logger.debug(f"{message=}")
    logger.trace(f"{chat_history=}")
    try:
        chat_history.append([message, ""])
    except Exception:
        chat_history = deque([[message, ""]], maxlen=5)
    return message, chat_history


model = None
gc.collect()
torch.cuda.empty_cache()

if not torch.cuda.is_available():
    raise SystemExit("GPU not available, can't run. Turn on GPU and retry.")

model = gen_model(model_name)


def bot(chat_history, **kwargs):
    """Generate a full (non-streaming) reply for the last user message."""
    try:
        message = chat_history[-1][0]
    except Exception as exc:
        logger.error(f"{chat_history=}: {exc}")
        return chat_history
    logger.debug(f"{chat_history=}")
    try:
        _ = """
        response, chat_history = model.chat(
            tokenizer,
            message,
            history=chat_history,
            temperature=0.7,
            repetition_penalty=1.2,
            # max_length=128,
        )
        """
        logger.debug("run model.chat...")
        model.generation_config.update(**kwargs)
        response, chat_history = model.chat(
            tokenizer,
            message,
            chat_history[:-1],
        )
        # the returned history already contains the new turn
        del response
        return chat_history
    except Exception as exc:
        logger.error(exc)
        # surface the error as the bot reply instead of dropping it
        chat_history[-1] = [message, str(exc)]
        return chat_history


def bot_stream(chat_history, **kwargs):
    """Stream a reply for the last user message, yielding partial histories."""
    logger.trace(f"{chat_history=}")
    logger.trace(f"{kwargs=}")

    try:
        message = chat_history[-1][0]
    except Exception as exc:
        logger.error(f"{chat_history=}: {exc}")
        raise gr.Error(f"{chat_history=}")

    model.generation_config.update(**kwargs)
    response = ""
    # feed only the completed turns as history, not the pending [message, ""]
    for elm in model.chat_stream(tokenizer, message, chat_history[:-1]):
        chat_history[-1] = [message, elm]
        response = elm
        yield chat_history
    logger.debug(f"{response=}")
    logger.debug(f"{model.generation_config=}")


SYSTEM_PROMPT = "You are a helpful assistant."
MAX_MAX_NEW_TOKENS = 2048
MAX_NEW_TOKENS = 256


@dataclass
class Config:
    max_new_tokens: int = MAX_NEW_TOKENS
    repetition_penalty: float = 1.1
    temperature: float = 1.0
    top_k: int = 0
    top_p: float = 0.9


stats_default = SimpleNamespace(llm=None, system_prompt=SYSTEM_PROMPT, config=Config())


def api_fn(
    input_text: Optional[str],
    max_new_tokens: int = stats_default.config.max_new_tokens,
    temperature: float = stats_default.config.temperature,
    repetition_penalty: float = stats_default.config.repetition_penalty,
    top_k: int = stats_default.config.top_k,
    top_p: float = stats_default.config.top_p,
    system_prompt: Optional[str] = None,
    history: Optional[List[str]] = None,
):
    """Answer a single query for the hidden API tab (history is not kept)."""
    if input_text is None:
        input_text = ""
    try:
        input_text = str(input_text).strip()
    except Exception as exc:
        logger.error(exc)
        input_text = ""
    if not input_text:
        return ""
    if history is None:
        history = []
    try:
        temperature = float(temperature)
    except Exception:
        temperature = stats_default.config.temperature

    if system_prompt is None:
        system_prompt = stats_default.system_prompt

    # fall back to defaults for out-of-range values
    if max_new_tokens < 10:
        max_new_tokens = stats_default.config.max_new_tokens
    if top_p < 0.1 or top_p > 1:
        top_p = stats_default.config.top_p
    if temperature <= 0.5:
        temperature = stats_default.config.temperature

    _ = {
        "max_new_tokens": max_new_tokens,
        "temperature": temperature,
        "repetition_penalty": repetition_penalty,
        "top_k": top_k,
        "top_p": top_p,
    }
    model.generation_config.update(**_)
    try:
        res, _ = model.chat(
            tokenizer,
            input_text,
            history=history,
            system=system_prompt,
            append_history=False,
        )
    except Exception as exc:
        logger.error(f"{exc=}")
        res = str(exc)

    logger.debug(f"api {res=}")
    logger.debug(f"api {model.generation_config=}")

    return res


theme = gr.themes.Soft(text_size="sm")
with gr.Blocks(
    theme=theme,
    title=model_name.lower(),
    css=css,
) as block:
    stats = gr.State(stats_default)

    # reload the stock generation config so chat starts from a known state
    model.generation_config = GenerationConfig.from_pretrained(
        model_name,
        trust_remote_code=True,
    )

    def bot_stream_state(chat_history):
        """Run bot_stream with the current Advanced Options settings."""
        logger.trace(f"{chat_history=}")
        # read the (possibly slider-updated) config at call time
        config = asdict(stats.value.config)
        yield from bot_stream(chat_history, **config)

with gr.Accordion("🎈 Info", open=False): |
|
gr.Markdown( |
|
dedent( |
|
f""" |
|
## {model_name.lower()} |
|
|
|
* temperature range: .51 and up; higher temperature implies more randomness. Suggested temperature for chatting and creative writing is around 1.1 while it should be set to 0.51-1.0 for summarizing and translation. |
|
* Set `repetition_penalty` to 2.1 or higher for a chatty conversation (more unpredictable and undesirable output). Lower it to 1.1 or smaller if more focused anwsers are desired (for example for translations or fact-oriented queries). |
|
* Smaller `top_k` probably will result in smoothier sentences. |
|
(`top_k=0` is equivalent to `top_k` equal to very very big though.) Consult `transformers` documentation for more details. |
|
* An API is available at https://mikeee-qwen-7b-chat.hf.space/ that can be queried, e.g., in python |
|
```python |
|
from gradio_client import Client |
|
|
|
client = Client("https://mikeee-qwen-7b-chat.hf.space/") |
|
|
|
result = client.predict( |
|
"你好!", # user prompt |
|
256, # max_new_tokens |
|
1.2, # temperature |
|
1.1, # repetition_penalty |
|
0, # top_k |
|
0.9, # top_p |
|
"You are a helpful assistant.", # system_prompt |
|
None, # history |
|
api_name="/api" |
|
) |
|
print(result) |
|
``` |
|
or in javascript |
|
```js |
|
import {{ client }} from "@gradio/client"; |
|
|
|
const app = await client("https://mikeee-qwen-7b-chat.hf.space/"); |
|
const result = await app.predict("api", [...]); |
|
console.log(result.data); |
|
``` |
|
Check documentation and examples by clicking `Use via API` at the very bottom of [https://huggingface.co/spaces/mikeee/qwen-7b-chat](https://huggingface.co/spaces/mikeee/qwen-7b-chat). |
|
|
|
<p></p> |
|
Most examples are meant for another model. |
|
You probably should try to test |
|
some related prompts. System prompt can be changed in Advaned Options as well.""" |
|
), |
|
elem_classes="xsmall", |
|
) |
|
|
|
    chatbot = gr.Chatbot(height=500, value=deque([], maxlen=5))

    with gr.Row():
        with gr.Column(scale=5):
            msg = gr.Textbox(
                label="Chat Message Box",
                placeholder="Ask me anything (press Shift+Enter or click Submit to send)",
                show_label=False,
                lines=4,
                max_lines=30,
                show_copy_button=True,
            )
        with gr.Column(scale=1, min_width=50):
            with gr.Row():
                submit = gr.Button("Submit", elem_classes="xsmall")
                stop = gr.Button("Stop", visible=True)
                clear = gr.Button("Clear History", visible=True)

    msg_submit_event = msg.submit(
        fn=user,
        inputs=[msg, chatbot],
        outputs=[msg, chatbot],
        queue=True,
        show_progress="full",
    ).then(bot_stream_state, chatbot, chatbot, queue=True)
    submit_click_event = submit.click(
        fn=user_clear,
        inputs=[msg, chatbot],
        outputs=[msg, chatbot],
        queue=True,
        show_progress="full",
    ).then(bot_stream_state, chatbot, chatbot, queue=True)
    stop.click(
        fn=None,
        inputs=None,
        outputs=None,
        cancels=[msg_submit_event, submit_click_event],
        queue=False,
    )
    clear.click(lambda: None, None, chatbot, queue=False)

    with gr.Accordion(label="Advanced Options", open=False):
        system_prompt = gr.Textbox(
            label="System prompt",
            value=stats_default.system_prompt,
            lines=3,
            visible=True,
        )
        max_new_tokens = gr.Slider(
            label="Max new tokens",
            minimum=1,
            maximum=MAX_MAX_NEW_TOKENS,
            step=1,
            value=stats_default.config.max_new_tokens,
        )
        repetition_penalty = gr.Slider(
            label="Repetition penalty",
            minimum=0.1,
            maximum=40.0,
            step=0.1,
            value=stats_default.config.repetition_penalty,
        )
        temperature = gr.Slider(
            label="Temperature",
            minimum=0.51,
            maximum=40.0,
            step=0.1,
            value=stats_default.config.temperature,
        )
        top_p = gr.Slider(
            label="Top-p (nucleus sampling)",
            minimum=0.05,
            maximum=1.0,
            step=0.05,
            value=stats_default.config.top_p,
        )
        top_k = gr.Slider(
            label="Top-k",
            minimum=0,
            maximum=1000,
            step=1,
            value=stats_default.config.top_k,
        )

        def system_prompt_fn(system_prompt):
            stats.value.system_prompt = system_prompt
            logger.debug(f"{stats.value.system_prompt=}")

        def max_new_tokens_fn(max_new_tokens):
            stats.value.config.max_new_tokens = max_new_tokens
            logger.debug(f"{stats.value.config.max_new_tokens=}")

        def repetition_penalty_fn(repetition_penalty):
            stats.value.config.repetition_penalty = repetition_penalty
            logger.debug(f"{stats.value=}")

        def temperature_fn(temperature):
            stats.value.config.temperature = temperature
            logger.debug(f"{stats.value=}")

        def top_p_fn(top_p):
            stats.value.config.top_p = top_p
            logger.debug(f"{stats.value=}")

        def top_k_fn(top_k):
            stats.value.config.top_k = top_k
            logger.debug(f"{stats.value=}")

        system_prompt.change(system_prompt_fn, system_prompt)
        max_new_tokens.change(max_new_tokens_fn, max_new_tokens)
        repetition_penalty.change(repetition_penalty_fn, repetition_penalty)
        temperature.change(temperature_fn, temperature)
        top_p.change(top_p_fn, top_p)
        top_k.change(top_k_fn, top_k)

        def reset_fn(stats_):
            logger.debug("reset_fn")
            # build fresh defaults: stats_default's config may have been
            # mutated in place by the slider callbacks above
            default = SimpleNamespace(
                llm=None, system_prompt=SYSTEM_PROMPT, config=Config()
            )
            stats_ = gr.State(default)
            logger.debug(f"{stats_.value=}")
            return (
                stats_,
                default.system_prompt,
                default.config.max_new_tokens,
                default.config.repetition_penalty,
                default.config.temperature,
                default.config.top_p,
                default.config.top_k,
            )

        reset_btn = gr.Button("Reset")
        reset_btn.click(
            reset_fn,
            stats,
            [
                stats,
                system_prompt,
                max_new_tokens,
                repetition_penalty,
                temperature,
                top_p,
                top_k,
            ],
        )

with gr.Accordion("Example inputs", open=True): |
|
etext = """In America, where cars are an important part of the national psyche, a decade ago people had suddenly started to drive less, which had not happened since the oil shocks of the 1970s. """ |
|
examples = gr.Examples( |
|
examples=example_list, |
|
inputs=[msg], |
|
examples_per_page=60, |
|
) |
|
with gr.Accordion("Disclaimer", open=False): |
|
_ = model_name.lower() |
|
gr.Markdown( |
|
f"Disclaimer: {_} can produce factually incorrect output, and should not be relied on to produce " |
|
f"factually accurate information. {_} was trained on various public datasets; while great efforts " |
|
"have been taken to clean the pretraining data, it is possible that this model could generate lewd, " |
|
"biased, or otherwise offensive outputs.", |
|
elem_classes=["disclaimer"], |
|
) |
|
|
|
with gr.Accordion("For Chat/Translation API", open=False, visible=False): |
|
input_text = gr.Text() |
|
api_history = gr.Chatbot(value=[]) |
|
api_btn = gr.Button("Go", variant="primary") |
|
out_text = gr.Text() |
|
|
|
|
|
|
|
api_btn.click( |
|
api_fn, |
|
[ |
|
input_text, |
|
max_new_tokens, |
|
temperature, |
|
repetition_penalty, |
|
top_k, |
|
top_p, |
|
system_prompt, |
|
api_history, |
|
], |
|
out_text, |
|
api_name="api", |
|
) |
|
|
|
|
|
if __name__ == "__main__":
    logger.info("Just record start time")
    block.queue(max_size=8).launch(debug=True)