import gradio as gr from transformers import AutoProcessor, AutoModelForImageTextToText, TextIteratorStreamer from transformers.image_utils import load_image from threading import Thread import torch import pickle as pkl import re from PIL import Image import json import spaces import os from serve_constants import html_header, bibtext, learn_more_markdown, tos_markdown cur_dir = os.path.dirname(os.path.abspath(__file__)) MODEL_ID = "TIGER-Lab/PixelReasoner-RL-v1" processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True, max_pixels=512*28*28) model = AutoModelForImageTextToText.from_pretrained( MODEL_ID, trust_remote_code=True, torch_dtype=torch.bfloat16 ).to("cuda").eval() def zoom(image, bbox_2d,padding=(0.1,0.1)): """ Crop the image based on the bounding box coordinates. image: a PIL image (anything exposing .size and .crop). bbox_2d: (x1, y1, x2, y2); if all four values are < 1 they are read as fractions of the image width/height, otherwise as absolute pixel coordinates. padding: (px, py) fraction of the image width/height added on each side of the box, capped so it never exceeds 600 pixels per axis. The padded box is clamped to the image bounds before cropping. Returns the cropped image. Raises AssertionError if the crop is 28 pixels or smaller in either dimension. NOTE(review): a normalized box with any coordinate exactly 1.0 fails the < 1 test and is treated as pixels - probably should be <= 1; the assert is also stripped under python -O. """ img_x, img_y = image.size padding_tr = (600.0/img_x,600.0/img_y) padding = (min(padding[0],padding_tr[0]),min(padding[1],padding_tr[1])) if bbox_2d[0] < 1 and bbox_2d[1] < 1 and bbox_2d[2] < 1 and bbox_2d[3] < 1: normalized_bbox_2d = (float(bbox_2d[0])-padding[0], float(bbox_2d[1])-padding[1], float(bbox_2d[2])+padding[0], float(bbox_2d[3])+padding[1]) else: normalized_bbox_2d = (float(bbox_2d[0])/img_x-padding[0], float(bbox_2d[1])/img_y-padding[1], float(bbox_2d[2])/img_x+padding[0], float(bbox_2d[3])/img_y+padding[1]) normalized_x1, normalized_y1, normalized_x2, normalized_y2 = normalized_bbox_2d normalized_x1 =min(max(0, normalized_x1), 1) normalized_y1 =min(max(0, normalized_y1), 1) normalized_x2 =min(max(0, normalized_x2), 1) normalized_y2 =min(max(0, normalized_y2), 1) cropped_img = image.crop((int(normalized_x1*img_x), int(normalized_y1*img_y), int(normalized_x2*img_x), int(normalized_y2*img_y))) w, h = cropped_img.size assert w > 28 and h > 28, f"Cropped image is too small: {w}x{h}" return cropped_img def execute_tool(images, rawimages, args, toolname, is_video, function=None): """ Execute one visual tool call parsed from the model output and return its result. For toolname == 'select_frames': args['target_frames'] is a list of 1-based frame numbers into images[0], and the return value is (selected_frames, message). With more than 8 frames the selection is empty and message is a hint for the model, unless the controlled-rectify path subsamples the list, in which case message is the subsampled list; a frame number above len(images[0]) gives an empty selection and an out-of-range hint; otherwise message is an empty string. Any other toolname reads args['target_image'] (1-based); the rest of that branch is not visible here. function is called as function(frames, indices) on the select_frames path, so it must not be None there. rawimages is not used in the visible portion. NOTE(review): np and do_controlled_rectify are not defined in the visible part of this file (numpy is never imported), so the more-than-8 branch presumably raises NameError - confirm. NOTE(review): the rectify path passes the 1-based tgt to function while the normal path converts to 0-based (x-1) indices - likely an off-by-one. """ if toolname=='select_frames': tgt = args['target_frames'] if len(tgt)>8: message = f"You 
have selected {len(tgt)} frames in total. Think again which frames you need to check in details (no more than 8 frames)" # [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16] ##### controlled modification if do_controlled_rectify and np.random.uniform()<0.75: if np.random.uniform()<0.25: tgt = tgt[:len(tgt)//2] elif np.random.uniform()<0.25/0.75: tgt = tgt[-len(tgt)//2:] elif np.random.uniform()<0.25/0.5: tgt = tgt[::2] else: tgt = np.random.choice(tgt, size=len(tgt)//2, replace=False) tgt = sorted(tgt) selected_frames = function(images[0], tgt) message = tgt else: selected_frames = [] # selected_frames = function(images[0], [x-1 for x in tgt][::2]) # video is always in the first item elif max(tgt)>len(images[0]): message = f"There are {len(images[0])} frames numbered in range [1,{len(images[0])}]. Your selection is out of range." selected_frames = [] else: message = "" candidates = images[0] if not isinstance(candidates, list): candidates = [candidates] selected_frames = function(candidates, [x-1 for x in tgt]) # video is always in the first item return selected_frames, message else: tgt = args['target_image'] if is_video: if len(images)==1: # there is only # we default the candidate images into video frames video_frames = images[0] index = tgt - 1 assert index\n\nFor each function call, return a json object with function name and arguments within XML tags:\n\n{\"name\": , \"arguments\": }\n" messages = [{ "role": "user", "content": sysprompt }] hint = "\n\nGuidelines: Understand the given visual information and the user query. Determine if it is beneficial to employ the given visual operations (tools). For a video, we can look closer by `select_frames`. For an image, we can look closer by `crop_image_normalized`. Reason with the visual information step by step, and put your final answer within \\boxed{}." 
for val in history: if val[0]: if isinstance(val[0], str): messages.append({ "role": "user", "content": [ *[{"type": "image", "image": image} for image in current_message_images], {"type": "text", "text": val[0]}, ], }) current_message_images = [] else: # Load messages. These will be appended to the first user text message that comes after current_message_images = [load_image(image) for image in val[0]] all_images += current_message_images if val[1]: messages.append({"role": "assistant", "content": val[1]}) imagelist = rawimagelist = current_message_images = [load_image(image) for image in files] all_images += current_message_images messages.append({ "role": "user", "content": [ *[{"type": "image", "image": image} for image in current_message_images], {"type": "text", "text": text+hint}, ], }) print(messages) complete_assistant_response_for_gradio = [] while True: """ Generate and stream text """ prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) inputs = processor( text=[prompt], images=all_images if all_images else None, return_tensors="pt", padding=True, ).to("cuda") streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=False) generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=1024, temperature=0.1, top_p=0.95, top_k=50) # generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=1024, do_sample=False, num_beams=1) thread = Thread(target=model.generate, kwargs=generation_kwargs) thread.start() current_model_output_segment = "" # Text generated in this specific model call toolflag = False for new_text_chunk in streamer: current_model_output_segment += new_text_chunk # Yield the sum of previously committed full response parts + current streaming segment # yield complete_assistant_response_for_gradio + current_model_output_segment if tool_start in current_model_output_segment: toolflag = True tmp = current_model_output_segment.split(tool_start)[0] yield 
complete_assistant_response_for_gradio + [tmp+"\n\nPlanning Visual Operations ...\n\n"] if not toolflag: yield complete_assistant_response_for_gradio + [current_model_output_segment] thread.join() # Process the full segment (e.g., remove <|im_end|>) processed_segment = current_model_output_segment.split("<|im_end|>", 1)[0] if "<|im_end|>" in current_model_output_segment else current_model_output_segment # Append this processed segment to the cumulative display string for Gradio complete_assistant_response_for_gradio += [processed_segment + "\n\n"] yield complete_assistant_response_for_gradio # Ensure the fully processed segment is yielded to Gradio # Check for tool call in the *just generated* segment qatext_for_tool_check = processed_segment require_tool = tool_end in qatext_for_tool_check and tool_start in qatext_for_tool_check if require_tool: tool_params = parse_last_tool(qatext_for_tool_check) tool_name = tool_params['name'] tool_args = tool_params['arguments'] # complete_assistant_response_for_gradio += f"\nExecuting Visual Operations ... @{tool_name}({tool_args})\n\n" complete_assistant_response_for_gradio += [f"\nExecuting Visual Operations ... 
@{tool_name}({tool_args})\n\n"] yield complete_assistant_response_for_gradio # Update Gradio display video_flag = False print(f"candidate images", all_images) raw_result = execute_tool(all_images, all_images, tool_args, tool_name, is_video=video_flag) print(raw_result) proc_img = raw_result all_images += [proc_img] proc_img.save("tmp.png") display = [dict(text="", files=["tmp.png"])] complete_assistant_response_for_gradio = complete_assistant_response_for_gradio + display yield complete_assistant_response_for_gradio # Update Gradio display new_piece = dict(role='user', content=[ dict(type='text', text="\nHere is the cropped image (Image Size: {}x{}):".format(proc_img.size[0], proc_img.size[1])), dict(type='image', image=proc_img) ] ) messages.append(new_piece) # complete_assistant_response_for_gradio += f"\nAnalyzing Operation Result ... @region(size={proc_img.size[0]}x{proc_img.size[1]})\n\n" complete_assistant_response_for_gradio += [f"\nAnalyzing Operation Result ... @region(size={proc_img.size[0]}x{proc_img.size[1]})\n\n"] yield complete_assistant_response_for_gradio # Update Gradio display else: break with gr.Blocks() as demo: examples = [ [ { "text": "What kind of restaurant is it?", "files": [ "1.jpg" ] } ] ] gr.HTML(html_header) # image_op_display = gr.Image(label="Visual Operation Result", type="pil", height=480, show_download_button=True, interactive=False) gr.ChatInterface( fn=model_inference, description="# **Pixel Reasoner**", chatbot=gr.Chatbot(label="Conversation", layout="bubble", bubble_full_width=False, show_copy_button=True, height=600), examples=examples, # fill_height=True, textbox=gr.MultimodalTextbox(label="Query Input", file_types=["image"], file_count="multiple"), stop_btn="Stop Generation", multimodal=True, cache_examples=False, ) gr.Markdown(tos_markdown) gr.Markdown(learn_more_markdown) gr.Markdown(bibtext) demo.launch(debug=True, share=False)