fix

app.py CHANGED
@@ -1,10 +1,10 @@
 import gradio as gr
 import numpy as np
 from PIL import Image, ImageDraw
-
-
-
-
+import torch
+from torchvision.transforms import Compose, Resize, ToTensor, Normalize
+from utils.model import init_model
+from utils.tokenization_clip import SimpleTokenizer as ClipTokenizer

 from fastapi.staticfiles import StaticFiles
 from fileservice import app
@@ -16,22 +16,22 @@ html_text = """
 </div>
 """

-
-
+def image_to_tensor(image_path):
+    image = Image.open(image_path).convert('RGB')

-
-
-
-
-
-
-
+    preprocess = Compose([
+        Resize([224, 224], interpolation=Image.BICUBIC),
+        lambda image: image.convert("RGB"),
+        ToTensor(),
+        Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
+    ])
+    image_data = preprocess(image)

-
+    return {'image': image_data}

-
-
-
+def get_image_data(image_path):
+    image_input = image_to_tensor(image_path)
+    return image_input

 def get_intervention_vector(selected_cells_bef, selected_cells_aft):
     left = np.reshape(np.zeros((1, 14 * 14)), (14, 14))
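The `Normalize` constants in `image_to_tensor` are the standard CLIP channel means and standard deviations, so this reproduces CLIP's input preprocessing. A quick sanity check of the shape it produces (a hypothetical local test, not part of the Space):

# Hypothetical check: any input image should come out of image_to_tensor
# as a (3, 224, 224) float32 tensor, ready for a CLIP-style encoder.
from PIL import Image

Image.new("RGB", (640, 480), "gray").save("example.jpg")
tensor = image_to_tensor("example.jpg")["image"]
print(tensor.shape)  # torch.Size([3, 224, 224])
print(tensor.dtype)  # torch.float32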
@@ -59,83 +59,83 @@ def get_intervention_vector(selected_cells_bef, selected_cells_aft):

     return left_map, right_map

-
-#
-
-
+def _get_rawimage(image_path):
+    # Pair x L x T x 3 x H x W
+    image = np.zeros((1, 3, 224,
+                      224), dtype=np.float64)

-
+    for i in range(1):

-
-
+        raw_image_data = get_image_data(image_path)
+        raw_image_data = raw_image_data['image']

-
+        image[i] = raw_image_data

-
+    return image


-
-
-
+def greedy_decode(model, tokenizer, video, video_mask, gt_left_map, gt_right_map):
+    visual_output, left_map, right_map = model.get_sequence_visual_output(video, video_mask,
+        gt_left_map[:, 0, :].squeeze(), gt_right_map[:, 0, :].squeeze())

-
-
-
-
-
-
-
-
-
-
-
-
+    video_mask = torch.ones(visual_output.shape[0], visual_output.shape[1], device=visual_output.device).long()
+    input_caption_ids = torch.zeros(visual_output.shape[0], device=visual_output.device).data.fill_(tokenizer.vocab["<|startoftext|>"])
+    input_caption_ids = input_caption_ids.long().unsqueeze(1)
+    decoder_mask = torch.ones_like(input_caption_ids)
+    for i in range(32):
+        decoder_scores = model.decoder_caption(visual_output, video_mask, input_caption_ids, decoder_mask, get_logits=True)
+        next_words = decoder_scores[:, -1].max(1)[1].unsqueeze(1)
+        input_caption_ids = torch.cat([input_caption_ids, next_words], 1)
+        next_mask = torch.ones_like(next_words)
+        decoder_mask = torch.cat([decoder_mask, next_mask], 1)
+
+    return input_caption_ids[:, 1:].tolist(), left_map, right_map

 # Dummy prediction function
-
-
-
-
-
+def predict_image(image_bef, image_aft, selected_cells_bef, selected_cells_aft):
+    if image_bef is None:
+        return "No image provided", "", ""
+    if image_aft is None:
+        return "No image provided", "", ""


-
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

-
+    model = init_model('data/pytorch_model.pt', device)

-
+    tokenizer = ClipTokenizer()

-
+    left_map, right_map = get_intervention_vector(selected_cells_bef, selected_cells_aft)

-
+    left_map, right_map = torch.from_numpy(left_map).unsqueeze(0), torch.from_numpy(right_map).unsqueeze(0)

-
-
+    bef_image = torch.from_numpy(_get_rawimage(image_bef)).unsqueeze(1)
+    aft_image = torch.from_numpy(_get_rawimage(image_aft)).unsqueeze(1)

-
+    image_pair = torch.cat([bef_image, aft_image], 1)

-
+    image_mask = torch.from_numpy(np.ones(2, dtype=np.int64)).unsqueeze(0)

-
+    result_list, left_map, right_map = greedy_decode(model, tokenizer, image_pair, image_mask, left_map, right_map)


-
-
-
-
-
-
-
-
-
-#
-
+    decode_text_list = tokenizer.convert_ids_to_tokens(result_list[0])
+    if "<|endoftext|>" in decode_text_list:
+        SEP_index = decode_text_list.index("<|endoftext|>")
+        decode_text_list = decode_text_list[:SEP_index]
+    if "!" in decode_text_list:
+        PAD_index = decode_text_list.index("!")
+        decode_text_list = decode_text_list[:PAD_index]
+    decode_text = " ".join(decode_text_list).strip()
+
+    # Assemble the predicted caption
+    pred = f"{decode_text}"

-#
-
-
+    # Include information about selected cells
+    selected_info_bef = f"{selected_cells_bef}" if selected_cells_bef else "No image patch was selected"
+    selected_info_aft = f"{selected_cells_aft}" if selected_cells_aft else "No image patch was selected"

-
+    return pred, selected_info_bef, selected_info_aft

 # Add grid to the image
 def add_grid_to_image(image_path, grid_size=14):
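A note on the decoding and post-processing above: generation always runs 32 greedy steps from `<|startoftext|>`, then the token list is cut at the first `<|endoftext|>`; `"!"` is token id 0 in the CLIP BPE vocabulary, so it also serves as padding here. A self-contained sketch of that clean-up with the tokenizer stubbed out (the real `ClipTokenizer` API is only assumed from this diff):

# Standalone sketch of the caption clean-up done in predict_image.
# CLIP's BPE marks word ends with "</w>"; the tokens below are illustrative.
def postprocess(tokens):
    for stop in ("<|endoftext|>", "!"):  # end-of-text marker, then padding
        if stop in tokens:
            tokens = tokens[:tokens.index(stop)]
    return "".join(tokens).replace("</w>", " ").strip()

print(postprocess(["a</w>", "red</w>", "car</w>", "<|endoftext|>", "!"]))  # -> "a red car"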
@@ -297,25 +297,25 @@ with gr.Blocks() as demo:
     html = gr.HTML(html_text)

     # Connect the predict button to the prediction function
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    predict_btn.click(
+        fn=predict_image,
+        inputs=[image_bef, image_aft, selected_cells_bef, selected_cells_aft],
+        outputs=[prediction, selected_info_bef, selected_info_aft]
+    )
+
+    image_bef.change(
+        fn=None,
+        inputs=[image_bef],
+        outputs=[],
+        js="(image) => { initializeEditor(); importBackground(image); return []; }",
+    )
+
+    image_aft.change(
+        fn=None,
+        inputs=[image_aft],
+        outputs=[],
+        js="(image) => { initializeEditor(); importBackground(image); return []; }",
+    )


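In the two `.change` hooks, `fn=None` combined with a `js` argument makes Gradio run the handler entirely in the browser, so `initializeEditor()` and `importBackground(image)` execute client-side with no Python round trip. A minimal sketch of the same pattern (component name is a placeholder):

# Minimal sketch of a client-side-only Gradio event handler: with fn=None
# and a js string, only the JavaScript runs; no Python callback fires.
import gradio as gr

with gr.Blocks() as sketch:
    img = gr.Image(type="filepath")
    img.change(
        fn=None,
        inputs=[img],
        outputs=[],
        js="(image) => { console.log('changed:', image); return []; }",
    )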