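# Gradio demo for TAB (Transformer Attention Bottleneck): generates a caption
# describing the change between a before/after image pair, optionally
# conditioned on user-selected patches of a 14x14 grid over either image.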
import gradio as gr
import ast
import numpy as np
from PIL import Image, ImageDraw
import torch
from torchvision.transforms import Compose, Resize, ToTensor, Normalize
from utils.model import init_model
from utils.tokenization_clip import SimpleTokenizer as ClipTokenizer
from fastapi.staticfiles import StaticFiles
from fileservice import app

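# CLIP-style preprocessing: resize to 224x224 and normalize with the CLIP
# channel means and standard deviations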
def image_to_tensor(image_path):
    image = Image.open(image_path).convert('RGB')

    preprocess = Compose([
        Resize([224, 224], interpolation=Image.BICUBIC),
        lambda image: image.convert("RGB"),
        ToTensor(),
        Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
    ])
    image_data = preprocess(image)

    return {'image': image_data}

def get_image_data(image_path):
    image_input = image_to_tensor(image_path)
    return image_input

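# The browser-side JS sends the selected-cell mask as a stringified list;
# ast.literal_eval parses it safely (literals only, unlike eval)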
def parse_bool_string(s):
    try:
        bool_list = ast.literal_eval(s)
        if not isinstance(bool_list, list):
            raise ValueError("The input string must represent a list.")
        return bool_list
    except (SyntaxError, ValueError) as e:
        raise ValueError(f"Invalid input string: {e}")

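# Build one intervention map per image. Each map has 1 + 14*14 entries:
# index 0 is a "no intervention" flag, set when the user selected no patches;
# the remaining 196 entries mirror the flattened 14x14 patch grid.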
def get_intervention_vector(selected_cells_bef, selected_cells_aft):
    first_ = True
    second_ = True

    left_map = np.zeros((1, 14 * 14 + 1))
    right_map = np.zeros((1, 14 * 14 + 1))

    left_map[0, 1:] = np.reshape(selected_cells_bef, (1, 14 * 14))
    right_map[0, 1:] = np.reshape(selected_cells_aft, (1, 14 * 14))

    if np.count_nonzero(selected_cells_bef) == 0:
        left_map[0, 0] = 1.0
        first_ = False
    if np.count_nonzero(selected_cells_aft) == 0:
        right_map[0, 0] = 1.0
        second_ = False

    return left_map, right_map, first_, second_

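# Load a single image as a (1, 3, 224, 224) float array ready for the model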
def _get_rawimage(image_path):
    # 1 x 3 x H x W (the before/after pair is assembled later in predict_image)
    image = np.zeros((1, 3, 224, 224), dtype=np.float64)

    for i in range(1):
        raw_image_data = get_image_data(image_path)
        raw_image_data = raw_image_data['image']

        image[i] = raw_image_data

    return image

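# Greedy autoregressive decoding: start from <|startoftext|> and repeatedly
# append the argmax token for at most 32 steps. Also returns the left/right
# attention maps produced alongside the visual features.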
def greedy_decode(model, tokenizer, video, video_mask, gt_left_map, gt_right_map):
    visual_output, left_map, right_map = model.get_sequence_visual_output(
        video, video_mask,
        gt_left_map[:, 0, :].squeeze(), gt_right_map[:, 0, :].squeeze())

    video_mask = torch.ones(visual_output.shape[0], visual_output.shape[1], device=visual_output.device).long()
    input_caption_ids = torch.zeros(visual_output.shape[0], device=visual_output.device).data.fill_(tokenizer.vocab["<|startoftext|>"])
    input_caption_ids = input_caption_ids.long().unsqueeze(1)
    decoder_mask = torch.ones_like(input_caption_ids)

    for i in range(32):
        decoder_scores = model.decoder_caption(visual_output, video_mask, input_caption_ids, decoder_mask, get_logits=True)
        next_words = decoder_scores[:, -1].max(1)[1].unsqueeze(1)
        input_caption_ids = torch.cat([input_caption_ids, next_words], 1)
        next_mask = torch.ones_like(next_words)
        decoder_mask = torch.cat([decoder_mask, next_mask], 1)

    return input_caption_ids[:, 1:].tolist(), left_map, right_map

# Run the captioning model on a before/after image pair, with optional
# patch-level attention interventions from the selected grid cells
def predict_image(image_bef, image_aft, json_data_bef, json_data_aft):
    if image_bef is None:
        return "No image provided", "", ""
    if image_aft is None:
        return "No image provided", "", ""

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = init_model('data/pytorch_model.pt', device)
    tokenizer = ClipTokenizer()

    selected_cells_bef = np.asarray(parse_bool_string(json_data_bef), np.int32)
    selected_cells_aft = np.asarray(parse_bool_string(json_data_aft), np.int32)

    left_map, right_map, first_, second_ = get_intervention_vector(selected_cells_bef, selected_cells_aft)
    left_map, right_map = torch.from_numpy(left_map).unsqueeze(0), torch.from_numpy(right_map).unsqueeze(0)

    bef_image = torch.from_numpy(_get_rawimage(image_bef)).unsqueeze(1)
    aft_image = torch.from_numpy(_get_rawimage(image_aft)).unsqueeze(1)
    image_pair = torch.cat([bef_image, aft_image], 1)
    image_mask = torch.from_numpy(np.ones(2, dtype=np.int64)).unsqueeze(0)

    result_list, left_map, right_map = greedy_decode(model, tokenizer, image_pair, image_mask, left_map, right_map)

    # Truncate at the end-of-text token, then at the pad token ("!")
    decode_text_list = tokenizer.convert_ids_to_tokens(result_list[0])
    if "<|endoftext|>" in decode_text_list:
        SEP_index = decode_text_list.index("<|endoftext|>")
        decode_text_list = decode_text_list[:SEP_index]
    if "!" in decode_text_list:
        PAD_index = decode_text_list.index("!")
        decode_text_list = decode_text_list[:PAD_index]
    decode_text = " ".join(decode_text_list).strip()

    pred = f"{decode_text}"

    # Include information about selected cells
    i, j = np.nonzero(selected_cells_bef)
    selected_info_bef = f"{list(zip(i, j))}" if first_ else "No image patch was selected"
    i, j = np.nonzero(selected_cells_aft)
    selected_info_aft = f"{list(zip(i, j))}" if second_ else "No image patch was selected"

    return pred, selected_info_bef, selected_info_aft

# Add a grid_size x grid_size white grid to the image
def add_grid_to_image(image_path, grid_size=14):
    if image_path is None:
        return None, None, None
    image = Image.open(image_path)
    w, h = image.size
    image = image.convert('RGBA')
    draw = ImageDraw.Draw(image)

    x_positions = np.linspace(0, w, grid_size + 1)
    y_positions = np.linspace(0, h, grid_size + 1)

    # Draw the vertical lines
    for x in x_positions[1:-1]:
        line = ((x, 0), (x, h))
        draw.line(line, fill='white')

    # Draw the horizontal lines
    for y in y_positions[1:-1]:
        line = ((0, y), (w, y))
        draw.line(line, fill='white')

    return image, h, w

# Handle cell selection
def handle_click(image, evt: gr.SelectData, selected_cells, image_path):
    if image is None:
        return None, []

    grid_size = 14
    image, h, w = add_grid_to_image(image_path, grid_size)

    x_positions = np.linspace(0, w, grid_size + 1)
    y_positions = np.linspace(0, h, grid_size + 1)

    # Calculate which cell was clicked; 'row' indexes the x axis and 'col'
    # the y axis, matching how the overlay rectangles are drawn below
    row, col = 0, 0
    for index, x in enumerate(x_positions[:-1]):
        if x <= evt.index[0] <= x_positions[index + 1]:
            row = index
    for index, y in enumerate(y_positions[:-1]):
        if y <= evt.index[1] <= y_positions[index + 1]:
            col = index
    cell_idx = (row, col)

    # Toggle selection
    if cell_idx in selected_cells:
        selected_cells.remove(cell_idx)
    else:
        selected_cells.append(cell_idx)

    # Add a semi-transparent overlay for the selected cells
    highlight_layer = Image.new('RGBA', (w, h), (0, 0, 0, 0))  # fully transparent layer
    highlight_draw = ImageDraw.Draw(highlight_layer)

    # Light green at 40% opacity (alpha 102 of 255)
    light_green = (144, 238, 144, 102)

    for (row, col) in selected_cells:
        cell_top_left = (x_positions[row], y_positions[col])
        cell_bottom_right = (x_positions[row + 1], y_positions[col + 1])
        highlight_draw.rectangle([cell_top_left, cell_bottom_right], fill=light_green, outline='white')

    result_img = Image.alpha_composite(image.convert('RGBA'), highlight_layer)
    return result_img, selected_cells

# Process example images: overlay the selection grid on both
def process_example(image_path_bef, image_path_aft):
    image_bef_grid, _, _ = add_grid_to_image(image_path_bef, 14)
    image_aft_grid, _, _ = add_grid_to_image(image_path_aft, 14)

    return image_bef_grid, image_aft_grid


def get_image_size(image_path):
    w, h = Image.open(image_path).convert('RGB').size
    return w, h

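# --- Gradio UI ---
# Left column: before/after image inputs and the Predict button. Right
# column: two <canvas> elements on which the scripts served from /js draw
# the grid and record patch selections; the selections are read back via
# read_js_Data_bef()/read_js_Data_aft() when Predict is clicked.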
with gr.Blocks() as demo:
    gr.Markdown("# TAB: Transformer Attention Bottleneck")

    # Instructions
    gr.Markdown("""
    ## Instructions:
    1. Upload an image or select one from the examples
    2. Click on grid cells to select/deselect them
    3. Click the 'Predict' button to get model predictions
    """)

    height = gr.State(value=320)
    width = gr.State(value=480)
    sel_attn_bef = gr.Textbox("", visible=False)
    sel_attn_aft = gr.Textbox("", visible=False)

    with gr.Row():
        with gr.Column(scale=1):
            # Input components with grid overlay
            image_bef = gr.Image(type="filepath", visible=True)
            image_aft = gr.Image(type="filepath", visible=True)
            predict_btn = gr.Button("Predict")
        with gr.Column(scale=1):
            html_text = """
            <div id="container">
                <canvas id="before" style="width: 100%; height: auto;"></canvas><img id="canvas-before" style="display:none;"/>
            </div>
            <br>
            <div id="container">
                <canvas id="after" style="width: 100%; height: auto;"></canvas><img id="canvas-after" style="display:none;"/>
            </div>
            """
            html = gr.HTML(html_text)

    with gr.Row():
        with gr.Column(scale=1):
            # Example images
            examples = gr.Examples(
                examples=[["data/images/CLEVR_default_000572.png", "data/images/CLEVR_semantic_000572.png"],
                          ["data/images/CLEVR_default_003339.png", "data/images/CLEVR_semantic_003339.png"]],
                inputs=[image_bef, image_aft],
                outputs=[width, height],
                label="Example Images",
                fn=get_image_size,
                examples_per_page=5
            )
        with gr.Column(scale=1):
            # Output components
            prediction = gr.Textbox(label="Predicted caption")
            selected_info_bef = gr.Textbox(label="Selected patches on the before image")
            selected_info_aft = gr.Textbox(label="Selected patches on the after image")

    # Connect the predict button to the prediction function
    predict_btn.click(
        fn=predict_image,
        inputs=[image_bef, image_aft, sel_attn_bef, sel_attn_aft],
        outputs=[prediction, selected_info_bef, selected_info_aft],
        _js="(image_bef, image_aft, sel_attn_bef, sel_attn_aft) => { return [image_bef, image_aft, read_js_Data_bef(), read_js_Data_aft()]; }"
    )

    image_bef.change(
        fn=None,
        inputs=[image_bef],
        outputs=[],
        _js="(image_bef) => { importBackgroundBefore(image_bef); initializeEditorBefore(); return []; }",
    )
    image_aft.change(
        fn=None,
        inputs=[image_aft],
        outputs=[],
        _js="(image_aft) => { importBackgroundAfter(image_aft); initializeEditorAfter(); return []; }",
    )

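# Serve the client-side editor scripts and mount the Gradio UI on the
# FastAPI app imported from fileservice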
app.mount("/js", StaticFiles(directory="js"), name="js")
gr.mount_gradio_app(app, demo, path="/")
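
# Launch with an ASGI server, e.g. `uvicorn app:app --host 0.0.0.0 --port 7860`
# (assuming this file is named app.py; adjust the module name otherwise).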