import asyncio
import functools
import re
import gradio as gr
import spaces
from transformers import AsyncTextIteratorStreamer, AutoModelForCausalLM, AutoTokenizer
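
# JavaScript injected into the Gradio page: observe elements inside ".auto-scroll"
# containers and keep their scrollable ancestor pinned to the bottom as text streams in.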
JS = """
() => {
// auto scroll .auto-scroll elements when text has changed
const observer = new MutationObserver((mutations) => {
mutations.forEach((mutation) => {
// find the parent element with .auto-scroll class and having the "overflow"
// style attribute to "auto"
let element = mutation.target;
while(element.parentElement !== null && element.parentElement.style.overflow !== "auto") {
element = element.parentElement;
}
if (element.parentElement === null) {
return;
}
element = element.parentElement;
element.scrollTop = element.scrollHeight;
});
})
document.querySelectorAll('.auto-scroll > *').forEach((elem) => {
console.log("observing", elem)
observer.observe(elem, {
childList: true,
characterData: true,
})
});
}
"""
model_name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype="auto",
    device_map="auto",
)
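
# Debug output: dump the model's attributes and configuration at startup.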
print(dir(model))
print(model.config)
tokenizer = AutoTokenizer.from_pretrained(model_name)

def reformat_math(text):
    """Fix MathJax delimiters to use the Gradio syntax.

    This is a workaround to display math formulas in Gradio. For now, I haven't
    found a way to make it work as expected with other latex_delimiters...
    """
    text = re.sub(r"\\\[\s*(.*?)\s*\\\]", r"$$\1$$", text, flags=re.DOTALL)
    text = re.sub(r"\\\(\s*(.*?)\s*\\\)", r"$\1$", text, flags=re.DOTALL)
    return text
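
# Example: reformat_math(r"\[ x^2 \]") returns "$$x^2$$",
# and reformat_math(r"\( \alpha \)") returns "$\alpha$".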

@spaces.GPU
def _generate(history):
    """Start generation in a background thread; return the loop, task, and streamer."""
    text = tokenizer.apply_chat_template(
        history,
        tokenize=False,
        add_generation_prompt=True,
    )
    model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
    streamer = AsyncTextIteratorStreamer(tokenizer, skip_special_tokens=True)

    # Run the blocking generate() call in the default executor so the async
    # chat handler can consume tokens from the streamer as they arrive.
    loop = asyncio.new_event_loop()
    task = loop.run_in_executor(
        None,
        functools.partial(
            model.generate,
            max_new_tokens=1024 * 128,
            streamer=streamer,
            **model_inputs,
        ),
    )
    return loop, task, streamer
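

# Stream the model output, routing text inside <think>...</think> tags to the
# reasoning panel and everything else to the chat window.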
async def chat(prompt, history):
    """Respond to a chat prompt."""
    message = {
        "role": "user",
        "content": prompt,
    }

    # build the messages list
    history = [] if history is None else history
    message_list = history + [message]

    loop, task, streamer = _generate(message_list)

    buffer = ""
    reasoning = ""
    thinking = False

    try:
        async for new_text in streamer:
            if task.done() or task.cancelled():
                print("Cancelled")
                break  # stop streaming if the task was cancelled
            if not thinking and "<think>" in new_text:
                thinking = True
                continue
            if thinking and "</think>" in new_text:
                thinking = False
                continue
            if thinking:
                reasoning += new_text
                heading = "# Reasoning\n\n"
                yield "I'm thinking, please wait a moment...", heading + reasoning
                continue
            buffer += new_text
            yield reformat_math(buffer), reasoning
    except asyncio.CancelledError:
        # this doesn't work; I haven't found a way to stop the generation thread
        print("Cancelled")
        streamer.on_finalized_text("cancelled", True)
        print("Signal sent")
        raise

    loop.close()
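

# Chatbot pane configured with LaTeX delimiters matching reformat_math's output.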
chat_bot = gr.Chatbot(
    latex_delimiters=[
        {"left": "$$", "right": "$$", "display": True},
        {"left": "$", "right": "$", "display": False},
    ],
    scale=1,
    type="messages",
)

with gr.Blocks(js=JS) as demo:
    reasoning = gr.Markdown(
        "# Reasoning\n\nWhile the model is reasoning, its thoughts will be displayed here.",
        label="Reasoning",
        show_label=True,
        container=True,
        elem_classes="auto-scroll",
        max_height="90vh",
        render=False,
    )
    with gr.Row(equal_height=True, height="90vh"):
        with gr.Column(scale=3):
            gr.ChatInterface(
                chat,
                type="messages",
                chatbot=chat_bot,
                title=str(model_name),
                description=(
                    f"*{model_name}* is a large language model "
                    "trained on a mixture of instruction and "
                    "conversational data."
                ),
                additional_outputs=[reasoning],
            )
        with gr.Column():
            reasoning.render()

if __name__ == "__main__":
    demo.queue().launch()