#!/usr/bin/env python

import os
from collections.abc import Iterator
from threading import Thread

import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
DESCRIPTION = "# Mistral-Small-24B-Instruct-2501"

if not torch.cuda.is_available():
    DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"

MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 1024
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
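
# Jinja chat template for the demo: when the conversation has no system turn it
# injects a default reasoning-oriented system prompt, wraps system/user turns in
# Mistral's [SYSTEM_PROMPT] / [INST] control tokens, and appends <think> after
# each user turn so the model opens its reply with an inner-monologue block.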
CHAT_TEMPLATE = """{%- set default_system_message = "A user will ask you to solve a task. You should first draft your thinking process (inner monologue) until you have derived the final answer. Afterwards, write a self-contained summary of your thoughts (i.e. your summary should be succinct but contain all the critical steps you needed to reach the conclusion). You should use Markdown and Latex to format your response. Write both your thoughts and summary in the same language as the task posed by the user.\n\nYour thinking process must follow the template below:\n<think>\nYour thoughts or/and draft, like working through an exercise on scratch paper. Be as casual and as long as you want until you are confident to generate a correct answer.\n</think>\n\nHere, provide a concise summary that reflects your reasoning and presents a clear final answer to the user.\n\nProblem:" %}
{{- bos_token }}
{%- if messages[0]['role'] == 'system' %}
    {%- set system_message = messages[0]['content'] %}
    {%- set loop_messages = messages[1:] %}
{%- else %}
    {%- set system_message = default_system_message %}
    {%- set loop_messages = messages %}
{%- endif %}
{{- '[SYSTEM_PROMPT]' + system_message + '[/SYSTEM_PROMPT]' }}
{%- for message in loop_messages %}
    {%- if message['role'] == 'user' %}
        {{- '[INST]' + message['content'] + '[/INST]<think>' }}
    {%- elif message['role'] == 'system' %}
        {{- '[SYSTEM_PROMPT]' + message['content'] + '[/SYSTEM_PROMPT]' }}
    {%- elif message['role'] == 'assistant' %}
        {{- message['content'] + eos_token }}
    {%- else %}
        {{- raise_exception('Only user, system and assistant roles are supported!') }}
    {%- endif %}
{%- endfor %}"""
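
# Model weights come from the AlexHung29629/fix_magistra3 checkpoint, while the
# tokenizer is taken from the base Mistral-Small-24B-Instruct-2501 repo. Both
# are loaded only when CUDA is available; on CPU the demo only shows the
# warning added to DESCRIPTION above.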
if torch.cuda.is_available():
    model_id = "mistralai/Mistral-Small-24B-Instruct-2501"
    model = AutoModelForCausalLM.from_pretrained(
        "AlexHung29629/fix_magistra3", torch_dtype=torch.bfloat16, device_map="auto"
    )
    tokenizer = AutoTokenizer.from_pretrained(model_id)
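
# Assumption: the otherwise unused `spaces` import suggests this Space targets
# ZeroGPU, where the generation function is normally decorated with
# `spaces.GPU` to request a GPU for each call; the decorator is added here on
# that assumption (it has no effect outside a ZeroGPU environment).
@spaces.GPU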
def generate(
    message: str,
    chat_history: list[dict],
    max_new_tokens: int = 1024,
    temperature: float = 0.6,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
) -> Iterator[str]:
    # Build the full conversation, render it with the custom chat template, and
    # keep only the most recent MAX_INPUT_TOKEN_LENGTH tokens.
    conversation = [*chat_history, {"role": "user", "content": message}]
    input_ids = tokenizer.apply_chat_template(conversation, chat_template=CHAT_TEMPLATE, return_tensors="pt")
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
    input_ids = input_ids.to(model.device)

    # Stream decoded text as it is generated; special tokens are kept so the
    # <think> block stays visible in the chat.
    streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=False)
    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        num_beams=1,
        repetition_penalty=repetition_penalty,
    )
    # Run generation in a background thread and yield the accumulated text on
    # every streamed chunk.
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)
demo = gr.ChatInterface(
    fn=generate,
    additional_inputs=[
        gr.Slider(
            label="Max new tokens",
            minimum=1,
            maximum=MAX_MAX_NEW_TOKENS,
            step=1,
            value=DEFAULT_MAX_NEW_TOKENS,
        ),
        gr.Slider(
            label="Temperature",
            minimum=0.1,
            maximum=4.0,
            step=0.1,
            value=0.3,
        ),
        gr.Slider(
            label="Top-p (nucleus sampling)",
            minimum=0.05,
            maximum=1.0,
            step=0.05,
            value=0.95,
        ),
        gr.Slider(
            label="Top-k",
            minimum=1,
            maximum=1000,
            step=1,
            value=40,
        ),
        gr.Slider(
            label="Repetition penalty",
            minimum=1.0,
            maximum=2.0,
            step=0.05,
            value=1.0,
        ),
    ],
    stop_btn=None,
    examples=[
        ["Hello there! How are you doing?"],
        ["Can you explain briefly to me what is the Python programming language?"],
        ["Explain the plot of Cinderella in a sentence."],
        ["How many hours does it take a man to eat a Helicopter?"],
        ["Write a 100-word article on 'Benefits of Open-Source in AI research'"],
    ],
    type="messages",
    description=DESCRIPTION,
    css_paths="style.css",
)
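
# Queue incoming requests (at most 20 waiting) and launch the app.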
if __name__ == "__main__":
    demo.queue(max_size=20).launch()