#!/usr/bin/env python
import os
from collections.abc import Iterator
from threading import Thread
import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, AutoModelForImageTextToText
# Markdown header rendered above the chat UI ("測試" is Chinese for "test").
DESCRIPTION = "# 測試"
if not torch.cuda.is_available():
    # The model is only loaded when CUDA is available, so tell visitors up front.
    # (The original literal was split across raw newlines and never terminated;
    # the intended line breaks are kept as escapes.)
    DESCRIPTION += "\n\nRunning on CPU 🥶 This demo does not work on CPU.\n"
# Limits for the "Max new tokens" slider: its upper bound and its initial value.
MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 1024
# Longest prompt (in tokens) fed to the model; longer conversations are cut
# down to their last MAX_INPUT_TOKEN_LENGTH tokens. Overridable via the
# MAX_INPUT_TOKEN_LENGTH environment variable.
MAX_INPUT_TOKEN_LENGTH = int(os.environ.get("MAX_INPUT_TOKEN_LENGTH", "4096"))
# Mistral-style Jinja chat template, passed to tokenizer.apply_chat_template.
# It reads `messages`, `bos_token`, `eos_token` and, optionally, `tools` and
# `enable_thinking` (enable_thinking=False appends "Answer directly").
#
# Kept as a raw string on purpose: the escapes inside it (\n, \', \") are meant
# for Jinja, which decodes them inside its own string literals. In a normal
# Python string, \" became a bare quote that closed the Jinja literal early, so
# the template failed to parse. The current date is now concatenated into the
# default system message instead of leaving a literal "{today}" placeholder.
CHAT_TEMPLATE = r"""{%- set today = strftime_now("%Y-%m-%d") %}
{%- set default_system_message = "Your knowledge base was last updated on 2023-10-01.\nThe current date is " + today + ".\n\nWhen you\'re not sure about some information or when the user\'s request requires up-to-date or specific data, you must use the available tools to fetch the information. Do not hesitate to use tools whenever they can provide a more accurate or complete response. If no relevant tools are available, then clearly state that you don\'t have the information and avoid making up anything.\nIf the user\'s question is not clear, ambiguous, or does not provide enough context for you to accurately answer the question, you do not try to answer it right away and you rather ask the user to clarify their request (e.g. \"What are some good restaurants around me?\" => \"Where are you?\" or \"When is the next flight to Tokyo\" => \"Where do you travel from?\").\nYou are always very attentive to dates, in particular you try to resolve dates and when asked about information at specific dates, you discard information that is at another date.\nYou follow these instructions in all languages, and always respond to the user in the language they use or request.\nNext sections describe the capabilities that you have.\n\n\n# MULTI-MODAL INSTRUCTIONS\n\nYou have the ability to read images, but you cannot generate images. You also cannot transcribe audio files or videos.\nYou cannot read nor transcribe audio files or videos." %}
{{- bos_token }}
{%- if messages[0]['role'] == 'system' %}
{%- if messages[0]['content'] is string %}
{%- set system_message = messages[0]['content'] %}
{%- set loop_messages = messages[1:] %}
{%- else %}
{%- set system_message = messages[0]['content'][0]['text'] %}
{%- set loop_messages = messages[1:] %}
{%- endif %}
{%- else %}
{%- set system_message = default_system_message %}
{%- set loop_messages = messages %}
{%- endif %}
{%- if not tools is defined %}
{%- set tools = none %}
{%- elif tools is not none %}
{%- set parallel_tool_prompt = "# TOOL CALLING INSTRUCTIONS\n\nYou may have access to tools that you can use to fetch information or perform actions. You must use these tools in the following situations:\n\n1. When the request requires up-to-date information.\n2. When the request requires specific data that you do not have in your knowledge base.\n3. When the request involves actions that you cannot perform without tools.\n\nAlways prioritize using tools to provide the most accurate and helpful response. If tools are not available, inform the user that you cannot perform the requested action at the moment." %}
{%- if system_message is defined %}
{%- set system_message = system_message + "\n\n" + parallel_tool_prompt %}
{%- else %}
{%- set system_message = parallel_tool_prompt %}
{%- endif %}
{%- endif %}
{{- '[SYSTEM_PROMPT]' + system_message + '[/SYSTEM_PROMPT]' }}
{%- set user_messages = loop_messages | selectattr("role", "equalto", "user") | list %}
{%- for message in loop_messages %}
{%- if message["role"] == "user" %}
{%- if tools is not none and (message == user_messages[-1]) %}
{{- "[AVAILABLE_TOOLS] [" }}
{%- for tool in tools %}
{%- set tool = tool.function %}
{{- '{"type": "function", "function": {' }}
{%- for key, val in tool.items() if key != "return" %}
{%- if val is string %}
{{- '"' + key + '": "' + val + '"' }}
{%- else %}
{{- '"' + key + '": ' + val|tojson }}
{%- endif %}
{%- if not loop.last %}
{{- ", " }}
{%- endif %}
{%- endfor %}
{{- "}}" }}
{%- if not loop.last %}
{{- ", " }}
{%- else %}
{{- "]" }}
{%- endif %}
{%- endfor %}
{{- "[/AVAILABLE_TOOLS]" }}
{%- endif %}
{%- if message['content'] is string %}
{{- '[INST]' + message['content'] + '[/INST]\n' }}
{%- else %}
{{- '[INST]' }}
{%- for block in message['content'] %}
{%- if block['type'] == 'text' %}
{{- block['text'] }}
{%- elif block['type'] == 'image' or block['type'] == 'image_url' %}
{{- '[IMG]' }}
{%- else %}
{{- raise_exception('Only text and image blocks are supported in message content!') }}
{%- endif %}
{%- endfor %}
{{- '[/INST]\n' }}
{%- endif %}
{%- if enable_thinking is defined %}
{%- if enable_thinking is false %}
{{- 'Answer directly\n' }}
{%- endif %}
{%- endif %}
{%- elif message["role"] == "tool_calls" or message.tool_calls is defined %}
{%- if message.tool_calls is defined %}
{%- set tool_calls = message.tool_calls %}
{%- else %}
{%- set tool_calls = message.content %}
{%- endif %}
{%- for tool_call in tool_calls %}
{{- "[TOOL_CALLS]" }}
{{- tool_call.function.name }}
{%- if not tool_call.id is defined or tool_call.id|length < 9 %}
{{- raise_exception("Tool call IDs should be alphanumeric strings with length >= 9! (1)" + tool_call.id) }}
{%- endif %}
{{- '[CALL_ID]' + tool_call.id[-9:] }}
{{- '[ARGS]' + tool_call.function.arguments|tojson }}
{%- endfor %}
{{- eos_token }}
{%- elif message['role'] == 'assistant' %}
{%- if message['content'] is string %}
{{- message['content'] + eos_token }}
{%- else %}
{{- message['content'][0]['text'] + eos_token }}
{%- endif %}
{%- elif message["role"] == "tool_results" or message["role"] == "tool" %}
{%- if message.content is defined and message.content.content is defined %}
{%- set content = message.content.content %}
{%- else %}
{%- set content = message.content %}
{%- endif %}
{{- '[TOOL_RESULTS] {"content": ' + content|string + ", " }}
{%- if not message.tool_call_id is defined or message.tool_call_id|length < 9 %}
{{- raise_exception("Tool call IDs should be alphanumeric strings with length >= 9! (2)" + message.tool_call_id) }}
{%- endif %}
{{- '"call_id": "' + message.tool_call_id[-9:] + '"}[/TOOL_RESULTS]' }}
{%- else %}
{{- raise_exception("Only user and assistant roles are supported, with the exception of an initial optional system message!") }}
{%- endif %}
{%- endfor %}"""
if torch.cuda.is_available():
    # Checkpoint to serve. Override it with the MODEL_ID environment variable
    # instead of editing code. Checkpoints tried earlier include
    # "mistralai/Mistral-Small-24B-Instruct-2501", "AlexHung29629/Draft1",
    # "AlexHung29629/tir_grpo" and "AlexHung29629/my_checkpoint_1".
    model_id = os.getenv("MODEL_ID", "AlexHung29629/add_vision_3")
    # bfloat16 weights; device_map="auto" lets transformers place them on the
    # available device(s).
    model = AutoModelForImageTextToText.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="auto")
    tokenizer = AutoTokenizer.from_pretrained(model_id)
# NOTE(review): without CUDA, `model` and `tokenizer` are never defined, so any
# call that needs them fails with NameError (DESCRIPTION warns users about this).
@spaces.GPU
def generate(
    message: str,
    chat_history: list[dict],
    chat_template: str,
    max_new_tokens: int = 1024,
    temperature: float = 0.6,
    top_p: float = 0.95,
    top_k: int = 50,
    repetition_penalty: float = 1.0,
) -> Iterator[str]:
    """Stream the model's reply to ``message``, continuing ``chat_history``.

    Args:
        message: The new user turn.
        chat_history: Earlier turns as ``{"role": ..., "content": ...}`` dicts
            (the ``gr.ChatInterface(type="messages")`` format).
        chat_template: Jinja chat template typed into the UI text area. The UI
            shows ``CHAT_TEMPLATE`` only as placeholder text, so an untouched
            box arrives here as ``""``. An empty or whitespace-only value
            therefore falls back to ``CHAT_TEMPLATE``; rendering an empty
            template would produce an empty prompt.
        max_new_tokens: Upper bound on the number of generated tokens.
        temperature, top_p, top_k, repetition_penalty: Sampling settings
            forwarded unchanged to ``model.generate``.

    Yields:
        The full text generated so far, re-emitted after every streamed chunk.

    NOTE(review): ``model`` and ``tokenizer`` are module globals that only exist
    when CUDA was available at import time.
    NOTE: CHAT_TEMPLATE also honours an ``enable_thinking`` flag; passing
    ``enable_thinking=False`` to ``apply_chat_template`` appends "Answer directly".
    """
    if not (chat_template and chat_template.strip()):
        chat_template = CHAT_TEMPLATE
    conversation = [*chat_history, {"role": "user", "content": message}]
    input_ids = tokenizer.apply_chat_template(conversation, chat_template=chat_template, return_tensors="pt")
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        # Keep only the most recent tokens. This drops the start of the prompt,
        # which with CHAT_TEMPLATE includes the BOS token and system prompt.
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
    input_ids = input_ids.to(model.device)
    # A single unpadded sequence: every position is a real token, so the mask is
    # all ones. Passing it explicitly keeps generate from having to infer it.
    attention_mask = torch.ones_like(input_ids)
    # skip_special_tokens=False keeps control tokens in the streamed text
    # (presumably deliberate, to inspect tool-call / EOS markers).
    streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=False)
    generate_kwargs = dict(
        input_ids=input_ids,
        attention_mask=attention_mask,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        num_beams=1,
        repetition_penalty=repetition_penalty,
    )
    # Generation runs on a worker thread; the streamer hands decoded text back here.
    worker = Thread(target=model.generate, kwargs=generate_kwargs)
    worker.start()
    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)
# Chat UI. ChatInterface calls generate(message, history, *additional_inputs),
# so the controls below must stay in the same order as generate's parameters:
# chat_template, max_new_tokens, temperature, top_p, top_k, repetition_penalty.
demo = gr.ChatInterface(
    fn=generate,
    type="messages",
    description=DESCRIPTION,
    css_paths="style.css",
    stop_btn=None,
    additional_inputs=[
        # CHAT_TEMPLATE is shown as placeholder (hint) text, not as the initial
        # value, so the box starts empty.
        gr.TextArea(placeholder=CHAT_TEMPLATE, label="Chat template"),
        *(
            gr.Slider(label=caption, minimum=low, maximum=high, step=increment, value=initial)
            for caption, low, high, increment, initial in (
                ("Max new tokens", 1, MAX_MAX_NEW_TOKENS, 1, DEFAULT_MAX_NEW_TOKENS),
                ("Temperature", 0.1, 4.0, 0.1, 0.3),
                ("Top-p (nucleus sampling)", 0.05, 1.0, 0.05, 0.95),
                ("Top-k", 1, 1000, 1, 40),
                ("Repetition penalty", 1.0, 2.0, 0.05, 1.0),
            )
        ),
    ],
    # Each example is a one-element row: the user message to prefill.
    # The first one asks (in Chinese) for three Tainan dishes.
    examples=[
        [prompt]
        for prompt in (
            "請列舉三個台南美食",
            "Can you explain briefly to me what is the Python programming language?",
            "Explain the plot of Cinderella in a sentence.",
            "How many hours does it take a man to eat a Helicopter?",
            "Write a 100-word article on 'Benefits of Open-Source in AI research'",
        )
    ],
)
if __name__ == "__main__":
    # Allow at most 20 pending events in the request queue, then start the server.
    # Blocks.queue() returns the same Blocks object, so launching it afterwards
    # is equivalent to chaining the two calls.
    queued_app = demo.queue(max_size=20)
    queued_app.launch()