Spaces:
Sleeping
Sleeping
File size: 10,878 Bytes
8824f88 749d210 8824f88 158f5d9 8824f88 92abb65 8824f88 7cb6017 8824f88 a058d44 2ef161a a058d44 2ef161a a058d44 2ef161a a058d44 2ef161a a058d44 2ef161a a058d44 2ef161a a058d44 2ef161a 8824f88 533d81f 2ccd67f 533d81f 2ccd67f 8824f88 749d210 019333f 8824f88 9d50662 8824f88 9d50662 8824f88 749d210 8824f88 9d50662 0737a9d 34353a1 0737a9d 8824f88 9d50662 8824f88 749d210 8824f88 b1780e2 8824f88 e98e7d4 8824f88 e98e7d4 8824f88 e98e7d4 8824f88 e98e7d4 8824f88 8d547a3 8824f88 749d210 8824f88 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 |
#!/usr/bin/env python
import os
from collections.abc import Iterator
from threading import Thread
import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, AutoModelForImageTextToText
# Markdown shown at the top of the chat UI.
DESCRIPTION = """# 測試"""

# Append a notice to the page description when no CUDA device is visible.
if not torch.cuda.is_available():
    DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"

# Upper bound and initial value of the "Max new tokens" slider.
MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 1024
# Prompts longer than this many tokens are trimmed from the left before
# generation; overridable through the environment variable of the same name.
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
# Jinja chat template offered in the UI as the placeholder of the
# "Chat template" box.
#
# This MUST be a raw string: the Jinja source relies on backslash escapes
# (\n, \', \") inside its own string literals. In a normal Python string the
# \" sequences would collapse to bare quotes, terminating the Jinja literal
# early and making the template fail to compile. Jinja decodes the escapes
# itself when it lexes the template.
#
# NOTE(review): `today` is computed on the first line but the default system
# message contains the literal text "{today}" rather than interpolating it —
# confirm whether that is intended before changing the prompt.
CHAT_TEMPLATE = r"""{%- set today = strftime_now("%Y-%m-%d") %}
{%- set default_system_message = "Your knowledge base was last updated on 2023-10-01.\nThe current date is {today}.\n\nWhen you\'re not sure about some information or when the user\'s request requires up-to-date or specific data, you must use the available tools to fetch the information. Do not hesitate to use tools whenever they can provide a more accurate or complete response. If no relevant tools are available, then clearly state that you don\'t have the information and avoid making up anything.\nIf the user\'s question is not clear, ambiguous, or does not provide enough context for you to accurately answer the question, you do not try to answer it right away and you rather ask the user to clarify their request (e.g. \"What are some good restaurants around me?\" => \"Where are you?\" or \"When is the next flight to Tokyo\" => \"Where do you travel from?\").\nYou are always very attentive to dates, in particular you try to resolve dates and when asked about information at specific dates, you discard information that is at another date.\nYou follow these instructions in all languages, and always respond to the user in the language they use or request.\nNext sections describe the capabilities that you have.\n\n\n# MULTI-MODAL INSTRUCTIONS\n\nYou have the ability to read images, but you cannot generate images. You also cannot transcribe audio files or videos.\nYou cannot read nor transcribe audio files or videos." %}
{{- bos_token }}
{%- if messages[0]['role'] == 'system' %}
{%- if messages[0]['content'] is string %}
{%- set system_message = messages[0]['content'] %}
{%- set loop_messages = messages[1:] %}
{%- else %}
{%- set system_message = messages[0]['content'][0]['text'] %}
{%- set loop_messages = messages[1:] %}
{%- endif %}
{%- else %}
{%- set system_message = default_system_message %}
{%- set loop_messages = messages %}
{%- endif %}
{%- if not tools is defined %}
{%- set tools = none %}
{%- elif tools is not none %}
{%- set parallel_tool_prompt = "# TOOL CALLING INSTRUCTIONS\n\nYou may have access to tools that you can use to fetch information or perform actions. You must use these tools in the following situations:\n\n1. When the request requires up-to-date information.\n2. When the request requires specific data that you do not have in your knowledge base.\n3. When the request involves actions that you cannot perform without tools.\n\nAlways prioritize using tools to provide the most accurate and helpful response. If tools are not available, inform the user that you cannot perform the requested action at the moment." %}
{%- if system_message is defined %}
{%- set system_message = system_message + "\n\n" + parallel_tool_prompt %}
{%- else %}
{%- set system_message = parallel_tool_prompt %}
{%- endif %}
{%- endif %}
{{- '[SYSTEM_PROMPT]' + system_message + '[/SYSTEM_PROMPT]' }}
{%- set user_messages = loop_messages | selectattr("role", "equalto", "user") | list %}
{%- for message in loop_messages %}
{%- if message["role"] == "user" %}
{%- if tools is not none and (message == user_messages[-1]) %}
{{- "[AVAILABLE_TOOLS] [" }}
{%- for tool in tools %}
{%- set tool = tool.function %}
{{- '{"type": "function", "function": {' }}
{%- for key, val in tool.items() if key != "return" %}
{%- if val is string %}
{{- '"' + key + '": "' + val + '"' }}
{%- else %}
{{- '"' + key + '": ' + val|tojson }}
{%- endif %}
{%- if not loop.last %}
{{- ", " }}
{%- endif %}
{%- endfor %}
{{- "}}" }}
{%- if not loop.last %}
{{- ", " }}
{%- else %}
{{- "]" }}
{%- endif %}
{%- endfor %}
{{- "[/AVAILABLE_TOOLS]" }}
{%- endif %}
{%- if message['content'] is string %}
{{- '[INST]' + message['content'] + '[/INST]<think>\n' }}
{%- else %}
{{- '[INST]' }}
{%- for block in message['content'] %}
{%- if block['type'] == 'text' %}
{{- block['text'] }}
{%- elif block['type'] == 'image' or block['type'] == 'image_url' %}
{{- '[IMG]' }}
{%- else %}
{{- raise_exception('Only text and image blocks are supported in message content!') }}
{%- endif %}
{%- endfor %}
{{- '[/INST]<think>\n' }}
{%- endif %}
{%- if enable_thinking is defined %}
{%- if enable_thinking is false %}
{{- 'Answer directly</think>\n' }}
{%- endif %}
{%- endif %}
{%- elif message["role"] == "tool_calls" or message.tool_calls is defined %}
{%- if message.tool_calls is defined %}
{%- set tool_calls = message.tool_calls %}
{%- else %}
{%- set tool_calls = message.content %}
{%- endif %}
{%- for tool_call in tool_calls %}
{{- "[TOOL_CALLS]" }}
{{- tool_call.function.name }}
{%- if not tool_call.id is defined or tool_call.id|length < 9 %}
{{- raise_exception("Tool call IDs should be alphanumeric strings with length >= 9! (1)" + tool_call.id) }}
{%- endif %}
{{- '[CALL_ID]' + tool_call.id[-9:] }}
{{- '[ARGS]' + tool_call.function.arguments|tojson }}
{%- endfor %}
{{- eos_token }}
{%- elif message['role'] == 'assistant' %}
{%- if message['content'] is string %}
{{- message['content'] + eos_token }}
{%- else %}
{{- message['content'][0]['text'] + eos_token }}
{%- endif %}
{%- elif message["role"] == "tool_results" or message["role"] == "tool" %}
{%- if message.content is defined and message.content.content is defined %}
{%- set content = message.content.content %}
{%- else %}
{%- set content = message.content %}
{%- endif %}
{{- '[TOOL_RESULTS] {"content": ' + content|string + ", " }}
{%- if not message.tool_call_id is defined or message.tool_call_id|length < 9 %}
{{- raise_exception("Tool call IDs should be alphanumeric strings with length >= 9! (2)" + message.tool_call_id) }}
{%- endif %}
{{- '"call_id": "' + message.tool_call_id[-9:] + '"}[/TOOL_RESULTS]' }}
{%- else %}
{{- raise_exception("Only user and assistant roles are supported, with the exception of an initial optional system message!") }}
{%- endif %}
{%- endfor %}"""
# Load the model and tokenizer only when a CUDA device is present; on a
# CPU-only host `model` and `tokenizer` stay undefined.
if torch.cuda.is_available():
    # Other checkpoints that have been tried here:
    #   mistralai/Mistral-Small-24B-Instruct-2501, AlexHung29629/Draft1,
    #   AlexHung29629/tir_grpo, AlexHung29629/my_checkpoint_1
    model_id = "AlexHung29629/add_vision_3"
    model = AutoModelForImageTextToText.from_pretrained(
        model_id,
        torch_dtype=torch.bfloat16,
        device_map="auto",
    )
    tokenizer = AutoTokenizer.from_pretrained(model_id)
@spaces.GPU
def generate(
    message: str,
    chat_history: list[dict],
    chat_template: str,
    max_new_tokens: int = 1024,
    temperature: float = 0.6,
    top_p: float = 0.95,
    top_k: int = 50,
    repetition_penalty: float = 1.0,
) -> Iterator[str]:
    """Stream a sampled reply to ``message`` given the earlier ``chat_history``.

    Args:
        message: The new user turn.
        chat_history: Earlier turns as message dicts; the new user turn is
            appended to them to form the conversation.
        chat_template: Jinja chat template used to render the prompt. A blank
            (empty or whitespace-only) value selects the tokenizer's own
            template.
        max_new_tokens: Forwarded to ``model.generate``.
        temperature: Forwarded to ``model.generate``.
        top_p: Forwarded to ``model.generate``.
        top_k: Forwarded to ``model.generate``.
        repetition_penalty: Forwarded to ``model.generate``.

    Yields:
        The reply text accumulated so far, once per streamed chunk.
    """
    conversation = [*chat_history, {"role": "user", "content": message}]

    # The UI textbox only carries a placeholder, so an untouched box arrives
    # blank. A blank string would be used as the template itself (rendering an
    # empty prompt); only None makes apply_chat_template use the tokenizer's
    # built-in template.
    if not chat_template or not chat_template.strip():
        chat_template = None

    # Extra keyword arguments (e.g. enable_thinking=False) can be added to this
    # call to make them available to the template.
    input_ids = tokenizer.apply_chat_template(
        conversation,
        chat_template=chat_template,
        return_tensors="pt",
    )

    # Keep only the most recent tokens when the prompt exceeds the limit.
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
    input_ids = input_ids.to(model.device)

    # Special tokens are deliberately kept in the streamed text.
    streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=False)
    generate_kwargs = {
        "input_ids": input_ids,
        "streamer": streamer,
        "max_new_tokens": max_new_tokens,
        "do_sample": True,
        "top_p": top_p,
        "top_k": top_k,
        "temperature": temperature,
        "num_beams": 1,
        "repetition_penalty": repetition_penalty,
    }

    # Generation runs on a worker thread so this generator can consume the
    # streamer while tokens are still being produced.
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)
# Chat UI bound to generate(). The additional inputs are listed in the same
# order as generate()'s parameters after (message, chat_history).
demo = gr.ChatInterface(
    fn=generate,
    type="messages",
    description=DESCRIPTION,
    css_paths="style.css",
    stop_btn=None,
    additional_inputs=[
        # CHAT_TEMPLATE is supplied as the placeholder, not as the value.
        gr.TextArea(label="Chat template", placeholder=CHAT_TEMPLATE),
        gr.Slider(label="Max new tokens", minimum=1, maximum=MAX_MAX_NEW_TOKENS, step=1, value=DEFAULT_MAX_NEW_TOKENS),
        gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=0.3),
        gr.Slider(label="Top-p (nucleus sampling)", minimum=0.05, maximum=1.0, step=0.05, value=0.95),
        gr.Slider(label="Top-k", minimum=1, maximum=1000, step=1, value=40),
        gr.Slider(label="Repetition penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.0),
    ],
    # One-click example prompts (message only).
    examples=[
        ["請列舉三個台南美食"],
        ["Can you explain briefly to me what is the Python programming language?"],
        ["Explain the plot of Cinderella in a sentence."],
        ["How many hours does it take a man to eat a Helicopter?"],
        ["Write a 100-word article on 'Benefits of Open-Source in AI research'"],
    ],
)
# Script entry point: serve the UI with a request queue capped at 20 entries.
if __name__ == "__main__":
    demo.queue(max_size=20).launch()
|