pooyanrg committed on
Commit ffbbba2 · 1 Parent(s): 26cf243
Files changed (1)
  1. app.py +90 -90
  1. app.py +90 -90
app.py CHANGED
@@ -1,10 +1,10 @@
 import gradio as gr
 import numpy as np
 from PIL import Image, ImageDraw
-# import torch
-# from torchvision.transforms import Compose, Resize, ToTensor, Normalize
-# from utils.model import init_model
-# from utils.tokenization_clip import SimpleTokenizer as ClipTokenizer
+import torch
+from torchvision.transforms import Compose, Resize, ToTensor, Normalize
+from utils.model import init_model
+from utils.tokenization_clip import SimpleTokenizer as ClipTokenizer
 
 from fastapi.staticfiles import StaticFiles
 from fileservice import app
@@ -16,22 +16,22 @@ html_text = """
 </div>
 """
 
-# def image_to_tensor(image_path):
-#     image = Image.open(image_path).convert('RGB')
+def image_to_tensor(image_path):
+    image = Image.open(image_path).convert('RGB')
 
-#     preprocess = Compose([
-#         Resize([224, 224], interpolation=Image.BICUBIC),
-#         lambda image: image.convert("RGB"),
-#         ToTensor(),
-#         Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
-#     ])
-#     image_data = preprocess(image)
+    preprocess = Compose([
+        Resize([224, 224], interpolation=Image.BICUBIC),
+        lambda image: image.convert("RGB"),
+        ToTensor(),
+        Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),  # CLIP mean/std
+    ])
+    image_data = preprocess(image)
 
-#     return {'image': image_data}
+    return {'image': image_data}
 
-# def get_image_data(image_path):
-#     image_input = image_to_tensor(image_path)
-#     return image_input
+def get_image_data(image_path):
+    image_input = image_to_tensor(image_path)
+    return image_input
 
 def get_intervention_vector(selected_cells_bef, selected_cells_aft):
     left = np.reshape(np.zeros((1, 14 * 14)), (14, 14))
@@ -59,83 +59,83 @@ def get_intervention_vector(selected_cells_bef, selected_cells_aft):
 
     return left_map, right_map
 
-# def _get_rawimage(image_path):
-#     # Pair x L x T x 3 x H x W
-#     image = np.zeros((1, 3, 224,
-#                       224), dtype=np.float)
+def _get_rawimage(image_path):
+    # Pair x L x T x 3 x H x W
+    image = np.zeros((1, 3, 224,
+                      224), dtype=np.float64)  # np.float was removed from NumPy; use an explicit dtype
 
-#     for i in range(1):
+    for i in range(1):
 
-#         raw_image_data = get_image_data(image_path)
-#         raw_image_data = raw_image_data['image']
+        raw_image_data = get_image_data(image_path)
+        raw_image_data = raw_image_data['image']
 
-#         image[i] = raw_image_data
+        image[i] = raw_image_data
 
-#     return image
+    return image
 
 
-# def greedy_decode(model, tokenizer, video, video_mask, gt_left_map, gt_right_map):
-#     visual_output, left_map, right_map = model.get_sequence_visual_output(video, video_mask,
-#                                                                           gt_left_map[:, 0, :].squeeze(), gt_right_map[:, 0, :].squeeze())
+def greedy_decode(model, tokenizer, video, video_mask, gt_left_map, gt_right_map):
+    visual_output, left_map, right_map = model.get_sequence_visual_output(video, video_mask,
+                                                                          gt_left_map[:, 0, :].squeeze(), gt_right_map[:, 0, :].squeeze())
 
-#     video_mask = torch.ones(visual_output.shape[0], visual_output.shape[1], device=visual_output.device).long()
-#     input_caption_ids = torch.zeros(visual_output.shape[0], device=visual_output.device).data.fill_(tokenizer.vocab["<|startoftext|>"])
-#     input_caption_ids = input_caption_ids.long().unsqueeze(1)
-#     decoder_mask = torch.ones_like(input_caption_ids)
-#     for i in range(32):
-#         decoder_scores = model.decoder_caption(visual_output, video_mask, input_caption_ids, decoder_mask, get_logits=True)
-#         next_words = decoder_scores[:, -1].max(1)[1].unsqueeze(1)
-#         input_caption_ids = torch.cat([input_caption_ids, next_words], 1)
-#         next_mask = torch.ones_like(next_words)
-#         decoder_mask = torch.cat([decoder_mask, next_mask], 1)
-
-#     return input_caption_ids[:, 1:].tolist(), left_map, right_map
+    video_mask = torch.ones(visual_output.shape[0], visual_output.shape[1], device=visual_output.device).long()
+    input_caption_ids = torch.zeros(visual_output.shape[0], device=visual_output.device).data.fill_(tokenizer.vocab["<|startoftext|>"])
+    input_caption_ids = input_caption_ids.long().unsqueeze(1)
+    decoder_mask = torch.ones_like(input_caption_ids)
+    for i in range(32):  # greedy decoding, capped at 32 tokens
+        decoder_scores = model.decoder_caption(visual_output, video_mask, input_caption_ids, decoder_mask, get_logits=True)
+        next_words = decoder_scores[:, -1].max(1)[1].unsqueeze(1)
+        input_caption_ids = torch.cat([input_caption_ids, next_words], 1)
+        next_mask = torch.ones_like(next_words)
+        decoder_mask = torch.cat([decoder_mask, next_mask], 1)
+
+    return input_caption_ids[:, 1:].tolist(), left_map, right_map
 
 # Dummy prediction function
-# def predict_image(image_bef, image_aft, selected_cells_bef, selected_cells_aft):
-#     if image_bef is None:
-#         return "No image provided", "", ""
-#     if image_aft is None:
-#         return "No image provided", "", ""
+def predict_image(image_bef, image_aft, selected_cells_bef, selected_cells_aft):
+    if image_bef is None:
+        return "No image provided", "", ""
+    if image_aft is None:
+        return "No image provided", "", ""
 
 
-#     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
-#     model = init_model('data/pytorch_model.pt', device)
+    model = init_model('data/pytorch_model.pt', device)
 
-#     tokenizer = ClipTokenizer()
+    tokenizer = ClipTokenizer()
 
-#     left_map, right_map = get_intervention_vector(selected_cells_bef, selected_cells_aft)
+    left_map, right_map = get_intervention_vector(selected_cells_bef, selected_cells_aft)
 
-#     left_map, right_map = torch.from_numpy(left_map).unsqueeze(0), torch.from_numpy(right_map).unsqueeze(0)
+    left_map, right_map = torch.from_numpy(left_map).unsqueeze(0), torch.from_numpy(right_map).unsqueeze(0)
 
-#     bef_image = torch.from_numpy(_get_rawimage(image_bef)).unsqueeze(1)
-#     aft_image = torch.from_numpy(_get_rawimage(image_aft)).unsqueeze(1)
+    bef_image = torch.from_numpy(_get_rawimage(image_bef)).unsqueeze(1)
+    aft_image = torch.from_numpy(_get_rawimage(image_aft)).unsqueeze(1)
 
-#     image_pair = torch.cat([bef_image, aft_image], 1)
+    image_pair = torch.cat([bef_image, aft_image], 1)
 
-#     image_mask = torch.from_numpy(np.ones(2, dtype=np.long)).unsqueeze(0)
+    image_mask = torch.from_numpy(np.ones(2, dtype=np.int64)).unsqueeze(0)  # np.long no longer exists; int64 is equivalent
 
-#     result_list, left_map, right_map = greedy_decode(model, tokenizer, image_pair, image_mask, left_map, right_map)
+    result_list, left_map, right_map = greedy_decode(model, tokenizer, image_pair, image_mask, left_map, right_map)
 
 
-#     decode_text_list = tokenizer.convert_ids_to_tokens(result_list[0])
-#     if "<|endoftext|>" in decode_text_list:
-#         SEP_index = decode_text_list.index("<|endoftext|>")
-#         decode_text_list = decode_text_list[:SEP_index]
-#     if "!" in decode_text_list:
-#         PAD_index = decode_text_list.index("!")
-#         decode_text_list = decode_text_list[:PAD_index]
-#     decode_text = decode_text_list.strip()
-
-#     # Generate dummy predictions
-#     pred = f"{decode_text}"
+    decode_text_list = tokenizer.convert_ids_to_tokens(result_list[0])
+    if "<|endoftext|>" in decode_text_list:
+        SEP_index = decode_text_list.index("<|endoftext|>")
+        decode_text_list = decode_text_list[:SEP_index]
+    if "!" in decode_text_list:
+        PAD_index = decode_text_list.index("!")
+        decode_text_list = decode_text_list[:PAD_index]
+    decode_text = " ".join(decode_text_list).strip()  # a list has no .strip(); join the tokens first
+
+    # Final caption from greedy decoding
+    pred = f"{decode_text}"
 
-#     # Include information about selected cells
-#     selected_info_bef = f"{selected_cells_bef}" if selected_cells_bef else "No image patch was selected"
-#     selected_info_aft = f"{selected_cells_aft}" if selected_cells_aft else "No image patch was selected"
+    # Include information about selected cells
+    selected_info_bef = f"{selected_cells_bef}" if selected_cells_bef else "No image patch was selected"
+    selected_info_aft = f"{selected_cells_aft}" if selected_cells_aft else "No image patch was selected"
 
-#     return pred, selected_info_bef, selected_info_aft
+    return pred, selected_info_bef, selected_info_aft
 
 # Add grid to the image
 def add_grid_to_image(image_path, grid_size=14):
@@ -297,25 +297,25 @@ with gr.Blocks() as demo:
     html = gr.HTML(html_text)
 
     # Connect the predict button to the prediction function
-    # predict_btn.click(
-    #     fn=predict_image,
-    #     inputs=[image_bef, image_aft, selected_cells_bef, selected_cells_aft],
-    #     outputs=[prediction, selected_info_bef, selected_info_aft]
-    # )
-
-    # image_bef.change(
-    #     fn=None,
-    #     inputs=[image_bef],
-    #     outputs=[],
-    #     js="(image) => { initializeEditor(); importBackground(image); return []; }",
-    # )
-
-    # image_aft.change(
-    #     fn=None,
-    #     inputs=[image_aft],
-    #     outputs=[],
-    #     js="(image) => { initializeEditor(); importBackground(image); return []; }",
-    # )
+    predict_btn.click(
+        fn=predict_image,
+        inputs=[image_bef, image_aft, selected_cells_bef, selected_cells_aft],
+        outputs=[prediction, selected_info_bef, selected_info_aft]
+    )
+
+    image_bef.change(
+        fn=None,
+        inputs=[image_bef],
+        outputs=[],
+        js="(image) => { initializeEditor(); importBackground(image); return []; }",
+    )
+
+    image_aft.change(
+        fn=None,
+        inputs=[image_aft],
+        outputs=[],
+        js="(image) => { initializeEditor(); importBackground(image); return []; }",
+    )
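For orientation, the intervention maps that get_intervention_vector feeds into the model are built on a 14 x 14 patch grid (196 cells per image), initialized exactly as in app.py with np.reshape(np.zeros((1, 14 * 14)), (14, 14)). A minimal sketch of the masking idea follows; the cells_to_mask helper is illustrative only, not part of app.py, and assumes selections arrive as (row, col) pairs:

import numpy as np

GRID = 14  # app.py works on a 14x14 grid of image patches (196 cells)

def cells_to_mask(cells):
    # Mark each selected (row, col) patch with 1.0; all other patches stay 0.0.
    mask = np.reshape(np.zeros((1, GRID * GRID)), (GRID, GRID))
    for row, col in cells:
        mask[row, col] = 1.0
    return mask

print(cells_to_mask([(0, 0), (13, 13)]).sum())  # 2.0 -> two patches flagged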
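Similarly, a hedged smoke test of the restored inference path outside the Gradio UI might look like the sketch below; it assumes the checkpoint at data/pytorch_model.pt and the utils package are in place, and the two image paths are placeholders:

# Hypothetical local check -- "before.png" and "after.png" are placeholder files.
# Importing app builds (but does not launch) the Gradio Blocks interface.
from app import predict_image

caption, info_bef, info_aft = predict_image("before.png", "after.png", [], [])
print("caption:", caption)
print(info_bef)  # "No image patch was selected" when no cells are passed
print(info_aft)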