import gradio as gr
from transformers import AutoProcessor, AutoModelForImageTextToText, TextIteratorStreamer
from transformers.image_utils import load_image
from threading import Thread
import torch
import numpy as np  # needed by the frame-selection branch of execute_tool
import pickle as pkl
import re
from PIL import Image
import json
import spaces
import os
from serve_constants import html_header, bibtext, learn_more_markdown, tos_markdown
cur_dir = os.path.dirname(os.path.abspath(__file__))

MODEL_ID = "TIGER-Lab/PixelReasoner-RL-v1"
processor = AutoProcessor.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    max_pixels=512 * 28 * 28,
)
model = AutoModelForImageTextToText.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
).to("cuda").eval()
def zoom(image, bbox_2d, padding=(0.1, 0.1)):
    """
    Crop the image based on the bounding box coordinates.
    """
    img_x, img_y = image.size
    # Cap the padding so it never exceeds 600 pixels on either axis.
    padding_tr = (600.0 / img_x, 600.0 / img_y)
    padding = (min(padding[0], padding_tr[0]), min(padding[1], padding_tr[1]))
    if bbox_2d[0] < 1 and bbox_2d[1] < 1 and bbox_2d[2] < 1 and bbox_2d[3] < 1:
        # Coordinates are already normalized to [0, 1]; just add the padding.
        normalized_bbox_2d = (float(bbox_2d[0]) - padding[0], float(bbox_2d[1]) - padding[1], float(bbox_2d[2]) + padding[0], float(bbox_2d[3]) + padding[1])
    else:
        # Pixel coordinates; normalize by the image size before padding.
        normalized_bbox_2d = (float(bbox_2d[0]) / img_x - padding[0], float(bbox_2d[1]) / img_y - padding[1], float(bbox_2d[2]) / img_x + padding[0], float(bbox_2d[3]) / img_y + padding[1])
    normalized_x1, normalized_y1, normalized_x2, normalized_y2 = normalized_bbox_2d
    normalized_x1 = min(max(0, normalized_x1), 1)
    normalized_y1 = min(max(0, normalized_y1), 1)
    normalized_x2 = min(max(0, normalized_x2), 1)
    normalized_y2 = min(max(0, normalized_y2), 1)
    cropped_img = image.crop((int(normalized_x1 * img_x), int(normalized_y1 * img_y), int(normalized_x2 * img_x), int(normalized_y2 * img_y)))
    w, h = cropped_img.size
    assert w > 28 and h > 28, f"Cropped image is too small: {w}x{h}"
    return cropped_img
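
# Illustrative use of `zoom` (hypothetical values, not executed here): for a 1000x800
# image and the pixel-space box [250, 200, 750, 600], the default 10% padding expands
# the crop window to roughly (150, 120, 850, 680), i.e. a ~700x560 region. Boxes whose
# coordinates are all below 1 are instead treated as already normalized to [0, 1].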
# Randomly subsampling an over-long frame selection (instead of asking the model to
# retry) is disabled in this demo; flip this flag to enable that behaviour.
do_controlled_rectify = False


def execute_tool(images, rawimages, args, toolname, is_video, function=None):
    if toolname == 'select_frames':
        # Frame-selection path; only reachable for video inputs, which this image-only demo does not expose.
        tgt = args['target_frames']
        if len(tgt) > 8:
            message = f"You have selected {len(tgt)} frames in total. Think again which frames you need to check in details (no more than 8 frames)"
            ##### controlled modification
            if do_controlled_rectify and np.random.uniform() < 0.75:
                if np.random.uniform() < 0.25:
                    tgt = tgt[:len(tgt) // 2]
                elif np.random.uniform() < 0.25 / 0.75:
                    tgt = tgt[-len(tgt) // 2:]
                elif np.random.uniform() < 0.25 / 0.5:
                    tgt = tgt[::2]
                else:
                    tgt = np.random.choice(tgt, size=len(tgt) // 2, replace=False)
                tgt = sorted(tgt)
                selected_frames = function(images[0], tgt)
                message = tgt
            else:
                selected_frames = []
        elif max(tgt) > len(images[0]):
            message = f"There are {len(images[0])} frames numbered in range [1,{len(images[0])}]. Your selection is out of range."
            selected_frames = []
        else:
            message = ""
            candidates = images[0]
            if not isinstance(candidates, list):
                candidates = [candidates]
            selected_frames = function(candidates, [x - 1 for x in tgt])  # video is always in the first item
        return selected_frames, message
    else:
        # Crop path (`crop_image` tool).
        tgt = args['target_image']
        if is_video:
            if len(images) == 1:
                # There is only the video; default the candidate images to its frames.
                video_frames = images[0]
                index = tgt - 1
                assert index < len(video_frames), f"Incorrect `target_image`. You can only select frames in the given video within [1,{len(video_frames)}]"
                image_to_crop = video_frames[index]
            else:
                # There are zoomed images after the video; images = [[video], img, img, img]
                cand_images = images[1:]
                index = tgt - 1
                assert index < len(cand_images), f"Incorrect `target_image`. You can only select a previous frame within [1,{len(cand_images)}]"
                image_to_crop = cand_images[index]
        else:
            index = tgt - 1
            assert index < len(images), f"Incorrect `target_image`. You can only select previous images within [1,{len(images)}]"
            if index < len(rawimages):
                tmp = rawimages[index]
            else:
                tmp = images[index]
            image_to_crop = tmp
        if function is None:
            function = zoom
        cropped_image = function(image_to_crop, args['bbox_2d'])
        return cropped_image
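
# Sketch of how the crop tool is invoked further down in model_inference (argument
# values are made-up examples): with `imgs` a list of PIL images already shown to the
# model, `execute_tool(imgs, imgs, {"bbox_2d": [10, 20, 400, 300], "target_image": 1},
# "crop_image", is_video=False)` returns a single cropped PIL image, whereas the
# `select_frames` branch instead returns a `(selected_frames, message)` tuple.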
tool_start = '<tool_call>'
tool_end = '</tool_call>'


def parse_last_tool(output_text):
    """Decode the JSON payload of the last <tool_call>...</tool_call> block in the model output."""
    return json.loads(output_text.split(tool_start)[-1].split(tool_end)[0])
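
# Illustrative example (values made up): given a model segment ending in
#   <tool_call>
#   {"name": "crop_image", "arguments": {"bbox_2d": [120, 80, 640, 480], "target_image": 1}}
#   </tool_call>
# parse_last_tool returns the decoded dict
#   {"name": "crop_image", "arguments": {"bbox_2d": [120, 80, 640, 480], "target_image": 1}}
# which model_inference below unpacks into `tool_name` and `tool_args`.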
@spaces.GPU
def model_inference(input_dict, history):
    text = input_dict["text"]
    files = input_dict["files"]
    """
    Create chat history
    Example history value:
    [
        [('pixel.png',), None],
        ['ignore this image. just say "hi" and nothing else', 'Hi!'],
        ['just say "hi" and nothing else', 'Hi!']
    ]
    """
    all_images = []
    current_message_images = []

    sysprompt = "<|im_start|>system\nYou are a helpful assistant.\n\n# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>\n{\"type\": \"function\", \"function\": {\"name\": \"crop_image\", \"description\": \"Zoom in on the image based on the bounding box coordinates.\", \"parameters\": {\"type\": \"object\", \"properties\": {\"bbox_2d\": {\"type\": \"array\", \"description\": \"coordinates for bounding box of the area you want to zoom in. minimum value is 0 and maximum value is the width/height of the image.\", \"items\": {\"type\": \"number\"}}, \"target_image\": {\"type\": \"number\", \"description\": \"The index of the image to crop. Index from 1 to the number of images. Choose 1 to operate on original image.\"}}, \"required\": [\"bbox_2d\", \"target_image\"]}}}\n{\"type\": \"function\", \"function\": {\"name\": \"select_frames\", \"description\": \"Select frames from a video.\", \"parameters\": {\"type\": \"object\", \"properties\": {\"target_frames\": {\"type\": \"array\", \"description\": \"List of frame indices to select from the video (no more than 8 frames in total).\", \"items\": {\"type\": \"integer\", \"description\": \"Frame index from 1 to 16.\"}}}, \"required\": [\"target_frames\"]}}}\n</tools>\n\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{\"name\": <function-name>, \"arguments\": <args-json-object>}\n</tool_call>"
    messages = [{
        "role": "user",
        "content": sysprompt
    }]
    hint = "\n\nGuidelines: Understand the given visual information and the user query. Determine if it is beneficial to employ the given visual operations (tools). For a video, we can look closer by `select_frames`. For an image, we can look closer by `crop_image`. Reason with the visual information step by step, and put your final answer within \\boxed{}."

    for val in history:
        if val[0]:
            if isinstance(val[0], str):
                messages.append({
                    "role": "user",
                    "content": [
                        *[{"type": "image", "image": image} for image in current_message_images],
                        {"type": "text", "text": val[0]},
                    ],
                })
                current_message_images = []
            else:
                # Load images. These will be appended to the first user text message that comes after.
                current_message_images = [load_image(image) for image in val[0]]
                all_images += current_message_images
        if val[1]:
            messages.append({"role": "assistant", "content": val[1]})

    imagelist = rawimagelist = current_message_images = [load_image(image) for image in files]
    all_images += current_message_images
    messages.append({
        "role": "user",
        "content": [
            *[{"type": "image", "image": image} for image in current_message_images],
            {"type": "text", "text": text + hint},
        ],
    })
    print(messages)
    complete_assistant_response_for_gradio = []
    while True:
        """
        Generate and stream text
        """
        prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        inputs = processor(
            text=[prompt],
            images=all_images if all_images else None,
            return_tensors="pt",
            padding=True,
        ).to("cuda")
        streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=False)
        generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=1024, temperature=0.1, top_p=0.95, top_k=50)
        thread = Thread(target=model.generate, kwargs=generation_kwargs)
        thread.start()

        current_model_output_segment = ""  # Text generated in this specific model call
        toolflag = False
        for new_text_chunk in streamer:
            current_model_output_segment += new_text_chunk
            # Yield the previously committed response parts plus the currently streaming segment.
            if tool_start in current_model_output_segment:
                toolflag = True
                tmp = current_model_output_segment.split(tool_start)[0]
                yield complete_assistant_response_for_gradio + [tmp + "\n\n<b>Planning Visual Operations ...</b>\n\n"]
            if not toolflag:
                yield complete_assistant_response_for_gradio + [current_model_output_segment]
        thread.join()

        # Process the full segment (e.g., remove <|im_end|>).
        processed_segment = current_model_output_segment.split("<|im_end|>", 1)[0] if "<|im_end|>" in current_model_output_segment else current_model_output_segment
        # Append this processed segment to the cumulative display list for Gradio.
        complete_assistant_response_for_gradio += [processed_segment + "\n\n"]
        yield complete_assistant_response_for_gradio  # Ensure the fully processed segment is yielded to Gradio

        # Check for a tool call in the *just generated* segment.
        qatext_for_tool_check = processed_segment
        require_tool = tool_end in qatext_for_tool_check and tool_start in qatext_for_tool_check
        if require_tool:
            tool_params = parse_last_tool(qatext_for_tool_check)
            tool_name = tool_params['name']
            tool_args = tool_params['arguments']
            complete_assistant_response_for_gradio += [f"\n<b>Executing Visual Operations ...</b> @{tool_name}({tool_args})\n\n"]
            yield complete_assistant_response_for_gradio  # Update Gradio display

            video_flag = False  # This demo only accepts images, so the video path is never taken.
            raw_result = execute_tool(imagelist, rawimagelist, tool_args, tool_name, is_video=video_flag)
            print(raw_result)
            proc_img = raw_result

            all_images += [proc_img]
            proc_img.save("tmp.png")
            display = [dict(text="", files=["tmp.png"])]
            complete_assistant_response_for_gradio = complete_assistant_response_for_gradio + display
            yield complete_assistant_response_for_gradio  # Update Gradio display

            new_piece = dict(role='user', content=[
                dict(type='text', text="\nHere is the cropped image (Image Size: {}x{}):".format(proc_img.size[0], proc_img.size[1])),
                dict(type='image', image=proc_img),
            ])
            messages.append(new_piece)
            complete_assistant_response_for_gradio += [f"\n<b>Analyzing Operation Result ...</b> @region(size={proc_img.size[0]}x{proc_img.size[1]})\n\n"]
            yield complete_assistant_response_for_gradio  # Update Gradio display
        else:
            break
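
# Informal note on the values yielded above: model_inference is a generator, so
# gr.ChatInterface streams each yield as an incremental update. Every yield is a list
# whose items are either markdown/HTML strings (reasoning text and status banners) or
# dicts of the form {"text": "", "files": ["tmp.png"]} that display the cropped region inline.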
with gr.Blocks() as demo:
    examples = [
        [
            {
                "text": "What kind of restaurant is it?",
                "files": [
                    "1.jpg"
                ],
            }
        ]
    ]
    gr.HTML(html_header)
    # image_op_display = gr.Image(label="Visual Operation Result", type="pil", height=480, show_download_button=True, interactive=False)
    gr.ChatInterface(
        fn=model_inference,
        description="# **Pixel Reasoner**",
        chatbot=gr.Chatbot(label="Conversation", layout="bubble", bubble_full_width=False, show_copy_button=True, height=600),
        examples=examples,
        # fill_height=True,
        textbox=gr.MultimodalTextbox(label="Query Input", file_types=["image"], file_count="multiple"),
        stop_btn="Stop Generation",
        multimodal=True,
        cache_examples=False,
    )
    gr.Markdown(tos_markdown)
    gr.Markdown(learn_more_markdown)
    gr.Markdown(bibtext)

demo.launch(debug=True, share=False)