File size: 10,878 Bytes
8824f88
 
 
749d210
8824f88
 
 
 
 
158f5d9
8824f88
92abb65
8824f88
 
 
 
 
 
7cb6017
8824f88
a058d44
 
2ef161a
 
a058d44
 
 
 
 
 
 
2ef161a
 
 
 
a058d44
 
 
 
 
 
 
 
 
 
2ef161a
a058d44
2ef161a
a058d44
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2ef161a
a058d44
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2ef161a
a058d44
2ef161a
 
 
8824f88
533d81f
2ccd67f
533d81f
 
 
 
2ccd67f
8824f88
 
 
 
 
749d210
019333f
8824f88
 
9d50662
8824f88
9d50662
8824f88
749d210
8824f88
9d50662
 
0737a9d
 
34353a1
0737a9d
8824f88
9d50662
8824f88
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
749d210
8824f88
 
b1780e2
8824f88
 
 
 
 
 
 
 
 
 
 
 
e98e7d4
8824f88
 
 
 
 
 
e98e7d4
8824f88
 
 
 
 
 
e98e7d4
8824f88
 
 
 
 
 
e98e7d4
8824f88
 
 
 
8d547a3
8824f88
 
 
 
 
749d210
 
 
8824f88
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
#!/usr/bin/env python

import os
from collections.abc import Iterator
from threading import Thread

import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, AutoModelForImageTextToText

# Markdown shown above the chat UI; a CPU-only host gets an extra warning line.
DESCRIPTION = "# 測試" + (
    ""
    if torch.cuda.is_available()
    else "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
)

# Upper bound and default for the "Max new tokens" slider.
MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 1024
# Prompt length (in tokens) above which older context is trimmed; overridable via env var.
MAX_INPUT_TOKEN_LENGTH = int(os.environ.get("MAX_INPUT_TOKEN_LENGTH", "4096"))

# Jinja chat template: wraps the system prompt in [SYSTEM_PROMPT]...[/SYSTEM_PROMPT],
# user turns in [INST]...[/INST], and opens a "<think>" block after every user
# turn (pre-closed with "Answer directly</think>" when enable_thinking is false).
#
# This is a *raw* string on purpose: the backslash escapes inside the Jinja
# string literals (\n, \', \") must reach Jinja unchanged so Jinja decodes them.
# In a normal Python string, \" collapses to a bare " that terminates the Jinja
# string literal early and makes the template fail to compile.
CHAT_TEMPLATE = r"""{%- set today = strftime_now("%Y-%m-%d") %}
{%- set default_system_message = "Your knowledge base was last updated on 2023-10-01.\nThe current date is " + today + ".\n\nWhen you\'re not sure about some information or when the user\'s request requires up-to-date or specific data, you must use the available tools to fetch the information. Do not hesitate to use tools whenever they can provide a more accurate or complete response. If no relevant tools are available, then clearly state that you don\'t have the information and avoid making up anything.\nIf the user\'s question is not clear, ambiguous, or does not provide enough context for you to accurately answer the question, you do not try to answer it right away and you rather ask the user to clarify their request (e.g. \"What are some good restaurants around me?\" => \"Where are you?\" or \"When is the next flight to Tokyo\" => \"Where do you travel from?\").\nYou are always very attentive to dates, in particular you try to resolve dates and when asked about information at specific dates, you discard information that is at another date.\nYou follow these instructions in all languages, and always respond to the user in the language they use or request.\nNext sections describe the capabilities that you have.\n\n\n# MULTI-MODAL INSTRUCTIONS\n\nYou have the ability to read images, but you cannot generate images. You also cannot transcribe audio files or videos.\nYou cannot read nor transcribe audio files or videos." %}
{{- bos_token }}
{%- if messages[0]['role'] == 'system' %}
    {%- if messages[0]['content'] is string %}
        {%- set system_message = messages[0]['content'] %}
        {%- set loop_messages = messages[1:] %}
    {%- else %}
        {%- set system_message = messages[0]['content'][0]['text'] %}
        {%- set loop_messages = messages[1:] %}
    {%- endif %}
{%- else %}
    {%- set system_message = default_system_message %}
    {%- set loop_messages = messages %}
{%- endif %}
{%- if not tools is defined %}
    {%- set tools = none %}
{%- elif tools is not none %}
    {%- set parallel_tool_prompt = "# TOOL CALLING INSTRUCTIONS\n\nYou may have access to tools that you can use to fetch information or perform actions. You must use these tools in the following situations:\n\n1. When the request requires up-to-date information.\n2. When the request requires specific data that you do not have in your knowledge base.\n3. When the request involves actions that you cannot perform without tools.\n\nAlways prioritize using tools to provide the most accurate and helpful response. If tools are not available, inform the user that you cannot perform the requested action at the moment." %}
    {%- if system_message is defined %}
        {%- set system_message = system_message + "\n\n" + parallel_tool_prompt %}
    {%- else %}
        {%- set system_message = parallel_tool_prompt %}
    {%- endif %}
{%- endif %}
{{- '[SYSTEM_PROMPT]' + system_message + '[/SYSTEM_PROMPT]' }}
{%- set user_messages = loop_messages | selectattr("role", "equalto", "user") | list %}
{%- for message in loop_messages %}
    {%- if message["role"] == "user" %}
        {%- if tools is not none and (message == user_messages[-1]) %}
            {{- "[AVAILABLE_TOOLS] [" }}
            {%- for tool in tools %}
                {%- set tool = tool.function %}
                {{- '{"type": "function", "function": {' }}
                {%- for key, val in tool.items() if key != "return" %}
                    {%- if val is string %}
                        {{- '"' + key + '": "' + val + '"' }}
                    {%- else %}
                        {{- '"' + key + '": ' + val|tojson }}
                    {%- endif %}
                    {%- if not loop.last %}
                        {{- ", " }}
                    {%- endif %}
                {%- endfor %}
                {{- "}}" }}
                {%- if not loop.last %}
                    {{- ", " }}
                {%- else %}
                    {{- "]" }}
                {%- endif %}
            {%- endfor %}
            {{- "[/AVAILABLE_TOOLS]" }}
        {%- endif %}
        {%- if message['content'] is string %}
            {{- '[INST]' + message['content'] + '[/INST]<think>\n' }}
        {%- else %}
            {{- '[INST]' }}
            {%- for block in message['content'] %}
                    {%- if block['type'] == 'text' %}
                            {{- block['text'] }}
                    {%- elif block['type'] == 'image' or block['type'] == 'image_url' %}
                            {{- '[IMG]' }}
                        {%- else %}
                            {{- raise_exception('Only text and image blocks are supported in message content!') }}
                        {%- endif %}
                {%- endfor %}
            {{- '[/INST]<think>\n' }}
        {%- endif %}
        {%- if enable_thinking is defined %}
            {%- if enable_thinking is false %}
                {{- 'Answer directly</think>\n' }}
            {%- endif %}
        {%- endif %}
    {%- elif message["role"] == "tool_calls" or message.tool_calls is defined %}
        {%- if message.tool_calls is defined %}
            {%- set tool_calls = message.tool_calls %}
        {%- else %}
            {%- set tool_calls = message.content %}
        {%- endif %}
        
        {%- for tool_call in tool_calls %}
            {{- "[TOOL_CALLS]" }}
            {{- tool_call.function.name }}
            {%- if not tool_call.id is defined or tool_call.id|length < 9 %}
                {{- raise_exception("Tool call IDs should be alphanumeric strings with length >= 9! (1)" + (tool_call.id|string)) }}
            {%- endif %}
            {{- '[CALL_ID]' + tool_call.id[-9:] }}
            {{- '[ARGS]' + tool_call.function.arguments|tojson }}
        {%- endfor %}
        {{- eos_token }}
    {%- elif message['role'] == 'assistant' %}
        {%- if message['content'] is string %}
            {{- message['content'] + eos_token }}
        {%- else %}
            {{- message['content'][0]['text'] + eos_token }}
        {%- endif %}
    {%- elif message["role"] == "tool_results" or message["role"] == "tool" %}
        {%- if message.content is defined and message.content.content is defined %}
            {%- set content = message.content.content %}
        {%- else %}
            {%- set content = message.content %}
        {%- endif %}
        {{- '[TOOL_RESULTS] {"content": ' + content|string + ", " }}
        {%- if not message.tool_call_id is defined or message.tool_call_id|length < 9 %}
            {{- raise_exception("Tool call IDs should be alphanumeric strings with length >= 9! (2)" + (message.tool_call_id|string)) }}
        {%- endif %}
        {{- '"call_id": "' + message.tool_call_id[-9:] + '"}[/TOOL_RESULTS]' }}
    {%- else %}
        {{- raise_exception("Only user and assistant roles are supported, with the exception of an initial optional system message!") }}
    {%- endif %}
{%- endfor %}"""

# Weights are only loaded on a GPU host; on CPU-only machines `model` and
# `tokenizer` are never defined (the DESCRIPTION already warns about this).
# Other checkpoints tried during development, swap into `model_id` as needed:
#   mistralai/Mistral-Small-24B-Instruct-2501
#   AlexHung29629/Draft1
#   AlexHung29629/tir_grpo
#   AlexHung29629/my_checkpoint_1
if torch.cuda.is_available():
    model_id = "AlexHung29629/add_vision_3"
    model = AutoModelForImageTextToText.from_pretrained(
        model_id,
        device_map="auto",
        torch_dtype=torch.bfloat16,
    )
    tokenizer = AutoTokenizer.from_pretrained(model_id)


@spaces.GPU
def generate(
    message: str,
    chat_history: list[dict],
    chat_template: str,
    max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
    temperature: float = 0.6,
    top_p: float = 0.95,
    top_k: int = 50,
    repetition_penalty: float = 1.0,
) -> Iterator[str]:
    """Stream the model's reply to ``message``.

    The conversation (history plus the new user turn) is rendered with a Jinja
    chat template and tokenized. It is then handed to ``model.generate`` on a
    background thread, and decoded text is read back through a
    ``TextIteratorStreamer``.

    Args:
        message: Text of the new user turn.
        chat_history: Previous turns as role/content dicts.
        chat_template: Jinja template text from the UI. An empty or
            whitespace-only value falls back to ``CHAT_TEMPLATE``, which the
            UI shows as the text box's placeholder.
        max_new_tokens: Upper bound on the number of generated tokens.
        temperature, top_p, top_k, repetition_penalty: Sampling settings
            forwarded to ``model.generate``.

    Yields:
        The full reply decoded so far, growing with each streamed chunk.
    """
    # A blank template box arrives as "". transformers uses an explicit
    # chat_template string verbatim, so "" would render an empty prompt;
    # use the module's default template instead.
    if not (chat_template and chat_template.strip()):
        chat_template = CHAT_TEMPLATE

    conversation = [*chat_history, {"role": "user", "content": message}]

    # Pass enable_thinking=False here to pre-close the <think> block
    # (the template then emits "Answer directly</think>").
    input_ids = tokenizer.apply_chat_template(conversation, chat_template=chat_template, return_tensors="pt")
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        # Keep only the most recent tokens. NOTE: this also drops the BOS token
        # and system prompt that the template places at the very start.
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
    input_ids = input_ids.to(model.device)

    # Special tokens are kept in the decoded text (skip_special_tokens=False).
    streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=False)
    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        num_beams=1,
        repetition_penalty=repetition_penalty,
    )
    # model.generate blocks, so it runs on a worker thread while this generator
    # drains the streamer. If generation raises, no more text arrives and
    # iteration fails once the streamer's 60 s timeout elapses.
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)


def _build_demo() -> gr.ChatInterface:
    """Create the chat UI that streams replies from ``generate``.

    The extra controls are listed in the same order as ``generate``'s
    parameters after ``chat_history``. Gradio passes additional inputs to the
    handler positionally, in list order.
    """
    template_box = gr.TextArea(placeholder=CHAT_TEMPLATE, label="Chat template")
    token_slider = gr.Slider(
        label="Max new tokens",
        minimum=1,
        maximum=MAX_MAX_NEW_TOKENS,
        step=1,
        value=DEFAULT_MAX_NEW_TOKENS,
    )
    temperature_slider = gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=0.3)
    top_p_slider = gr.Slider(label="Top-p (nucleus sampling)", minimum=0.05, maximum=1.0, step=0.05, value=0.95)
    top_k_slider = gr.Slider(label="Top-k", minimum=1, maximum=1000, step=1, value=40)
    penalty_slider = gr.Slider(label="Repetition penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.0)

    # One-click example prompts (each example is a single-message row).
    sample_prompts = (
        "請列舉三個台南美食",
        "Can you explain briefly to me what is the Python programming language?",
        "Explain the plot of Cinderella in a sentence.",
        "How many hours does it take a man to eat a Helicopter?",
        "Write a 100-word article on 'Benefits of Open-Source in AI research'",
    )

    return gr.ChatInterface(
        fn=generate,
        additional_inputs=[
            template_box,
            token_slider,
            temperature_slider,
            top_p_slider,
            top_k_slider,
            penalty_slider,
        ],
        stop_btn=None,
        examples=[[prompt] for prompt in sample_prompts],
        type="messages",
        description=DESCRIPTION,
        css_paths="style.css",
    )


demo = _build_demo()

if __name__ == "__main__":
    # Cap pending requests at 20, then start the web server.
    app = demo.queue(max_size=20)
    app.launch()