| """ | |
| Run qwen 7b chat. | |
| transformers 4.31.0 | |
| import torch | |
| torch.cuda.empty_cache() | |
| """ | |
| # pylint: disable=line-too-long, invalid-name, no-member, redefined-outer-name, missing-function-docstring, missing-class-docstring, broad-except, | |
| import gc | |
| import os | |
| import time | |
| from collections import deque | |
| from dataclasses import asdict, dataclass | |
| from types import SimpleNamespace | |
| import gradio as gr | |
| import torch | |
| from loguru import logger | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
| from transformers.generation import GenerationConfig | |
| from example_list import css, example_list | |
if not torch.cuda.is_available():
    raise gr.Error("No CUDA available, can't continue...")

os.environ["TZ"] = "Asia/Shanghai"
try:
    time.tzset()  # type: ignore # pylint: disable=no-member
except Exception:
    # Windows: time.tzset() is not available
    logger.warning("Windows, can't run time.tzset()")
model_name = "Qwen/Qwen-7B-Chat"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

n_gpus = torch.cuda.device_count()
try:
    _ = f"{int(torch.cuda.mem_get_info()[0]/1024**3)-2}GB"
except AssertionError:
    _ = 0
max_memory = {i: _ for i in range(n_gpus)}
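# max_memory maps each CUDA device index to a per-GPU budget string understood
# by device_map="auto" (accelerate), e.g. {0: "14GB"} on a single 16 GB card;
# the figure above is whatever mem_get_info() reports minus 2 GB of headroom.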


def gen_model(model_name: str):
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        trust_remote_code=True,
        device_map="auto",
        load_in_4bit=True,
        max_memory=max_memory,
        fp16=True,
        torch_dtype=torch.float16,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )
    model = model.eval()
    model.generation_config = GenerationConfig.from_pretrained(
        model_name,
        trust_remote_code=True,
    )
    return model
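

# A roughly equivalent loader that spells out the 4-bit settings with an
# explicit BitsAndBytesConfig; this is a sketch for reference only (the helper
# name gen_model_bnb is not part of the original app, and gen_model() above is
# what this Space actually uses).
def gen_model_bnb(model_name: str):
    from transformers import BitsAndBytesConfig  # available in transformers 4.31

    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        trust_remote_code=True,
        device_map="auto",
        max_memory=max_memory,
        quantization_config=bnb_config,
    )
    return model.eval()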


def user_sub(message, chat_history):
    """Gen a response, clear message in user textbox."""
    logger.debug(f"{message=}")
    # logger.remove()  # to turn on trace
    # logger.add(sys.stderr, level="INFO")
    logger.trace(f"{chat_history=}")
    try:
        chat_history.append([message, ""])
    except Exception:
        # chat_history holds [user, bot] pairs; start a fresh bounded history
        chat_history = deque([[message, ""]], maxlen=5)
    return "", chat_history


def user(message, chat_history):
    """Gen a response."""
    logger.debug(f"{message=}")
    logger.trace(f"{chat_history=}")
    try:
        chat_history.append([message, ""])
    except Exception:
        chat_history = deque([[message, ""]], maxlen=5)
    return message, chat_history


# for rerun in tests
model = None
gc.collect()
torch.cuda.empty_cache()

model = gen_model(model_name)
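
# Qwen-7B-Chat's remote code exposes model.chat(tokenizer, query, history)
# returning (response, updated_history) and model.chat_stream(tokenizer, query,
# history) returning a generator of partial responses; both are used below
# (see the Qwen model card for the exact signatures).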


def bot(chat_history, **kwargs):
    try:
        message = chat_history[-1][0]
    except Exception as exc:
        logger.error(f"{chat_history=}: {exc}")
        return chat_history
    logger.debug(f"{chat_history=}")
    try:
        _ = """
        response, chat_history = model.chat(
            tokenizer,
            message,
            history=chat_history,
            temperature=0.7,
            repetition_penalty=1.2,
            # max_length=128,
        )
        """
        logger.debug("run model.chat...")
        response, chat_history = model.chat(
            tokenizer,
            message,
            chat_history[:-1],
            **kwargs,
        )
        del response
        return chat_history
    except Exception as exc:
        logger.error(exc)
        # surface the error as the bot reply for the pending turn
        chat_history[-1] = [message, str(exc)]
        return chat_history


SYSTEM_PROMPT = "You are a helpful assistant."
MAX_MAX_NEW_TOKENS = 1024
MAX_NEW_TOKENS = 128


@dataclass  # asdict() below requires a dataclass instance
class Config:
    max_new_tokens: int = 64
    repetition_penalty: float = 1.1
    temperature: float = 1.0
    top_k: int = 0
    top_p: float = 0.9
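
# These fields mirror HF GenerationConfig sampling parameters; with Config as a
# dataclass, asdict(Config()) produces a plain dict such as
# {"max_new_tokens": 64, "repetition_penalty": 1.1, "temperature": 1.0,
#  "top_k": 0, "top_p": 0.9}, which generation_config.update(**...) consumes
# below. (In transformers sampling, top_k=0 is commonly used to disable top-k
# filtering.)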

stats_default = SimpleNamespace(llm=model, system_prompt=SYSTEM_PROMPT, config=Config())

theme = gr.themes.Soft(text_size="sm")

with gr.Blocks(
    theme=theme,
    title=model_name.lower(),
    css=css,
) as block:
    stats = gr.State(stats_default)
    if not torch.cuda.is_available():
        raise gr.Error("GPU not available, can't run. Turn on GPU and restart")
    model_ = stats.value.llm
    config = stats.value.config
    model_.generation_config.update(**asdict(config))

    def bot_stream(chat_history):
        try:
            message = chat_history[-1][0]
        except Exception as exc:
            logger.error(f"{chat_history=}: {exc}")
            raise gr.Error(f"{chat_history=}")
            # yield chat_history
        # pass only the previous turns as history; the current turn is `message`
        for elm in model.chat_stream(tokenizer, message, chat_history[:-1]):
            chat_history[-1] = [message, elm]
            yield chat_history
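
    # Note: the model's chat_stream() appears to yield progressively longer
    # partial responses, so overwriting chat_history[-1] on every iteration is
    # what produces the token-by-token streaming effect in the Chatbot.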
| with gr.Accordion("🎈 Info", open=False): | |
| gr.Markdown( | |
| f"""<h5><center>{model_name.lower()}</center></h4> | |
| Set `repetition_penalty` to 2.1 or higher for a chatty conversation. Lower it to 1.1 or smaller if more focused anwsers are desired (for example for translations or fact-oriented queries). Smaller `top_k` probably will result in smoothies sentences. | |
| Consult `transformers` documentation for more details. | |
| Most examples are meant for another model. | |
| You probably should try to test | |
| some related prompts.""", | |
| elem_classes="xsmall", | |
| ) | |
    chatbot = gr.Chatbot(height=500, value=deque([], maxlen=5))  # type: ignore
    with gr.Row():
        with gr.Column(scale=5):
            msg = gr.Textbox(
                label="Chat Message Box",
                placeholder="Ask me anything (press Shift+Enter or click Submit to send)",
                show_label=False,
                # container=False,
                lines=4,
                max_lines=30,
                show_copy_button=True,
                # ).style(container=False)
            )
        with gr.Column(scale=1, min_width=50):
            with gr.Row():
                submit = gr.Button("Submit", elem_classes="xsmall")
                stop = gr.Button("Stop", visible=True)
                clear = gr.Button("Clear History", visible=True)

    msg_submit_event = msg.submit(
        # fn=conversation.user_turn,
        fn=user_sub,
        inputs=[msg, chatbot],
        outputs=[msg, chatbot],
        queue=True,
        show_progress="full",
        # api_name=None,
    ).then(bot_stream, chatbot, chatbot, queue=True)
    submit_click_event = submit.click(
        # fn=lambda x, y: ("",) + user(x, y)[1:],  # clear msg
        fn=user,  # clear msg
        inputs=[msg, chatbot],
        outputs=[msg, chatbot],
        queue=True,
        show_progress="full",
        # api_name=None,
    ).then(bot_stream, chatbot, chatbot, queue=True)
    stop.click(
        fn=None,
        inputs=None,
        outputs=None,
        cancels=[msg_submit_event, submit_click_event],
        queue=False,
    )
    clear.click(lambda: None, None, chatbot, queue=False)

    with gr.Accordion(label="Advanced Options", open=False):
        system_prompt = gr.Textbox(
            label="System prompt",
            value=stats_default.system_prompt,
            lines=3,
            visible=True,
        )
        max_new_tokens = gr.Slider(
            label="Max new tokens",
            minimum=1,
            maximum=MAX_MAX_NEW_TOKENS,
            step=1,
            value=stats_default.config.max_new_tokens,
        )
        repetition_penalty = gr.Slider(
            label="Repetition penalty",
            minimum=0.1,
            maximum=40.0,
            step=0.1,
            value=stats_default.config.repetition_penalty,
        )
        temperature = gr.Slider(
            label="Temperature",
            minimum=0.1,
            maximum=40.0,
            step=0.1,
            value=stats_default.config.temperature,
        )
        top_p = gr.Slider(
            label="Top-p (nucleus sampling)",
            minimum=0.05,
            maximum=1.0,
            step=0.05,
            value=stats_default.config.top_p,
        )
        top_k = gr.Slider(
            label="Top-k",
            minimum=0,
            maximum=1000,
            step=1,
            value=stats_default.config.top_k,
        )

        def system_prompt_fn(system_prompt):
            stats.value.system_prompt = system_prompt
            logger.debug(f"{stats.value.system_prompt=}")

        def max_new_tokens_fn(max_new_tokens):
            stats.value.config.max_new_tokens = max_new_tokens
            logger.debug(f"{stats.value.config.max_new_tokens=}")

        def repetition_penalty_fn(repetition_penalty):
            stats.value.config.repetition_penalty = repetition_penalty
            logger.debug(f"{stats.value=}")

        def temperature_fn(temperature):
            stats.value.config.temperature = temperature
            logger.debug(f"{stats.value=}")

        def top_p_fn(top_p):
            stats.value.config.top_p = top_p
            logger.debug(f"{stats.value=}")

        def top_k_fn(top_k):
            stats.value.config.top_k = top_k
            logger.debug(f"{stats.value=}")

        system_prompt.change(system_prompt_fn, system_prompt)
        max_new_tokens.change(max_new_tokens_fn, max_new_tokens)
        repetition_penalty.change(repetition_penalty_fn, repetition_penalty)
        temperature.change(temperature_fn, temperature)
        top_p.change(top_p_fn, top_p)
        top_k.change(top_k_fn, top_k)

        def reset_fn(stats_):
            logger.debug("reset_fn")
            stats_ = gr.State(stats_default)
            logger.debug(f"{stats_.value=}")
            return (
                stats_,
                stats_default.system_prompt,
                stats_default.config.max_new_tokens,
                stats_default.config.repetition_penalty,
                stats_default.config.temperature,
                stats_default.config.top_p,
                stats_default.config.top_k,
            )

        reset_btn = gr.Button("Reset")
        reset_btn.click(
            reset_fn,
            stats,
            [
                stats,
                system_prompt,
                max_new_tokens,
                repetition_penalty,
                temperature,
                top_p,
                top_k,
            ],
        )
| with gr.Accordion("Example inputs", open=True): | |
| etext = """In America, where cars are an important part of the national psyche, a decade ago people had suddenly started to drive less, which had not happened since the oil shocks of the 1970s. """ | |
| examples = gr.Examples( | |
| examples=example_list, | |
| inputs=[msg], | |
| examples_per_page=60, | |
| ) | |
| with gr.Accordion("Disclaimer", open=False): | |
| _ = model_name.lower() | |
| gr.Markdown( | |
| f"Disclaimer: {_} can produce factually incorrect output, and should not be relied on to produce " | |
| f"factually accurate information. {_} was trained on various public datasets; while great efforts " | |
| "have been taken to clean the pretraining data, it is possible that this model could generate lewd, " | |
| "biased, or otherwise offensive outputs.", | |
| elem_classes=["disclaimer"], | |
| ) | |

if __name__ == "__main__":
    block.queue(max_size=8).launch(debug=True)
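    # For a local run outside HF Spaces, one might instead bind to all
    # interfaces or request a share link (standard Blocks.launch kwargs):
    # block.queue(max_size=8).launch(server_name="0.0.0.0", share=True, debug=True)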