pooyanrg committed
Commit ad6a1d7 · 1 Parent(s): bd9e8a6
__pycache__/app.cpython-312.pyc ADDED
Binary file (7.98 kB)
 
__pycache__/fileservice.cpython-312.pyc ADDED
Binary file (1.98 kB)
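The two __pycache__ entries above are interpreter-generated bytecode; such files are usually kept out of version control with a standard .gitignore rule, e.g.:

    __pycache__/
    *.py[cod]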
 
app.py CHANGED
@@ -1,10 +1,10 @@
 import gradio as gr
 import numpy as np
 from PIL import Image, ImageDraw
-import torch
-from torchvision.transforms import Compose, Resize, ToTensor, Normalize
-from utils.model import init_model
-from utils.tokenization_clip import SimpleTokenizer as ClipTokenizer
+# import torch
+# from torchvision.transforms import Compose, Resize, ToTensor, Normalize
+# from utils.model import init_model
+# from utils.tokenization_clip import SimpleTokenizer as ClipTokenizer
 
 from fastapi.staticfiles import StaticFiles
 from fileservice import app
@@ -16,22 +16,22 @@ html_text = """
 </div>
 """
 
-def image_to_tensor(image_path):
-    image = Image.open(image_path).convert('RGB')
+# def image_to_tensor(image_path):
+#     image = Image.open(image_path).convert('RGB')
 
-    preprocess = Compose([
-        Resize([224, 224], interpolation=Image.BICUBIC),
-        lambda image: image.convert("RGB"),
-        ToTensor(),
-        Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
-    ])
-    image_data = preprocess(image)
+#     preprocess = Compose([
+#         Resize([224, 224], interpolation=Image.BICUBIC),
+#         lambda image: image.convert("RGB"),
+#         ToTensor(),
+#         Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
+#     ])
+#     image_data = preprocess(image)
 
-    return {'image': image_data}
+#     return {'image': image_data}
 
-def get_image_data(image_path):
-    image_input = image_to_tensor(image_path)
-    return image_input
+# def get_image_data(image_path):
+#     image_input = image_to_tensor(image_path)
+#     return image_input
 
 def get_intervention_vector(selected_cells_bef, selected_cells_aft):
     left = np.reshape(np.zeros((1, 14 * 14)), (14, 14))
@@ -59,83 +59,83 @@ def get_intervention_vector(selected_cells_bef, selected_cells_aft):
 
     return left_map, right_map
 
-def _get_rawimage(image_path):
-    # Pair x L x T x 3 x H x W
-    image = np.zeros((1, 3, 224,
-                      224), dtype=np.float)
+# def _get_rawimage(image_path):
+#     # Pair x L x T x 3 x H x W
+#     image = np.zeros((1, 3, 224,
+#                       224), dtype=np.float)
 
-    for i in range(1):
+#     for i in range(1):
 
-        raw_image_data = get_image_data(image_path)
-        raw_image_data = raw_image_data['image']
+#         raw_image_data = get_image_data(image_path)
+#         raw_image_data = raw_image_data['image']
 
-        image[i] = raw_image_data
+#         image[i] = raw_image_data
 
-    return image
+#     return image
 
 
-def greedy_decode(model, tokenizer, video, video_mask, gt_left_map, gt_right_map):
-    visual_output, left_map, right_map = model.get_sequence_visual_output(video, video_mask,
-                                                                          gt_left_map[:, 0, :].squeeze(), gt_right_map[:, 0, :].squeeze())
+# def greedy_decode(model, tokenizer, video, video_mask, gt_left_map, gt_right_map):
+#     visual_output, left_map, right_map = model.get_sequence_visual_output(video, video_mask,
+#                                                                           gt_left_map[:, 0, :].squeeze(), gt_right_map[:, 0, :].squeeze())
 
-    video_mask = torch.ones(visual_output.shape[0], visual_output.shape[1], device=visual_output.device).long()
-    input_caption_ids = torch.zeros(visual_output.shape[0], device=visual_output.device).data.fill_(tokenizer.vocab["<|startoftext|>"])
-    input_caption_ids = input_caption_ids.long().unsqueeze(1)
-    decoder_mask = torch.ones_like(input_caption_ids)
-    for i in range(32):
-        decoder_scores = model.decoder_caption(visual_output, video_mask, input_caption_ids, decoder_mask, get_logits=True)
-        next_words = decoder_scores[:, -1].max(1)[1].unsqueeze(1)
-        input_caption_ids = torch.cat([input_caption_ids, next_words], 1)
-        next_mask = torch.ones_like(next_words)
-        decoder_mask = torch.cat([decoder_mask, next_mask], 1)
-
-    return input_caption_ids[:, 1:].tolist(), left_map, right_map
+#     video_mask = torch.ones(visual_output.shape[0], visual_output.shape[1], device=visual_output.device).long()
+#     input_caption_ids = torch.zeros(visual_output.shape[0], device=visual_output.device).data.fill_(tokenizer.vocab["<|startoftext|>"])
+#     input_caption_ids = input_caption_ids.long().unsqueeze(1)
+#     decoder_mask = torch.ones_like(input_caption_ids)
+#     for i in range(32):
+#         decoder_scores = model.decoder_caption(visual_output, video_mask, input_caption_ids, decoder_mask, get_logits=True)
+#         next_words = decoder_scores[:, -1].max(1)[1].unsqueeze(1)
+#         input_caption_ids = torch.cat([input_caption_ids, next_words], 1)
+#         next_mask = torch.ones_like(next_words)
+#         decoder_mask = torch.cat([decoder_mask, next_mask], 1)
+
+#     return input_caption_ids[:, 1:].tolist(), left_map, right_map
 
 # Dummy prediction function
-def predict_image(image_bef, image_aft, selected_cells_bef, selected_cells_aft):
-    if image_bef is None:
-        return "No image provided", "", ""
-    if image_aft is None:
-        return "No image provided", "", ""
+# def predict_image(image_bef, image_aft, selected_cells_bef, selected_cells_aft):
+#     if image_bef is None:
+#         return "No image provided", "", ""
+#     if image_aft is None:
+#         return "No image provided", "", ""
 
 
-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+#     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
-    model = init_model('data/pytorch_model.pt', device)
+#     model = init_model('data/pytorch_model.pt', device)
 
-    tokenizer = ClipTokenizer()
+#     tokenizer = ClipTokenizer()
 
-    left_map, right_map = get_intervention_vector(selected_cells_bef, selected_cells_aft)
+#     left_map, right_map = get_intervention_vector(selected_cells_bef, selected_cells_aft)
 
-    left_map, right_map = torch.from_numpy(left_map).unsqueeze(0), torch.from_numpy(right_map).unsqueeze(0)
+#     left_map, right_map = torch.from_numpy(left_map).unsqueeze(0), torch.from_numpy(right_map).unsqueeze(0)
 
-    bef_image = torch.from_numpy(_get_rawimage(image_bef)).unsqueeze(1)
-    aft_image = torch.from_numpy(_get_rawimage(image_aft)).unsqueeze(1)
+#     bef_image = torch.from_numpy(_get_rawimage(image_bef)).unsqueeze(1)
+#     aft_image = torch.from_numpy(_get_rawimage(image_aft)).unsqueeze(1)
 
-    image_pair = torch.cat([bef_image, aft_image], 1)
+#     image_pair = torch.cat([bef_image, aft_image], 1)
 
-    image_mask = torch.from_numpy(np.ones(2, dtype=np.long)).unsqueeze(0)
+#     image_mask = torch.from_numpy(np.ones(2, dtype=np.long)).unsqueeze(0)
 
-    result_list, left_map, right_map = greedy_decode(model, tokenizer, image_pair, image_mask, left_map, right_map)
+#     result_list, left_map, right_map = greedy_decode(model, tokenizer, image_pair, image_mask, left_map, right_map)
 
 
-    decode_text_list = tokenizer.convert_ids_to_tokens(result_list[0])
-    if "<|endoftext|>" in decode_text_list:
-        SEP_index = decode_text_list.index("<|endoftext|>")
-        decode_text_list = decode_text_list[:SEP_index]
-    if "!" in decode_text_list:
-        PAD_index = decode_text_list.index("!")
-        decode_text_list = decode_text_list[:PAD_index]
-    decode_text = decode_text_list.strip()
-
-    # Generate dummy predictions
-    pred = f"{decode_text}"
+#     decode_text_list = tokenizer.convert_ids_to_tokens(result_list[0])
+#     if "<|endoftext|>" in decode_text_list:
+#         SEP_index = decode_text_list.index("<|endoftext|>")
+#         decode_text_list = decode_text_list[:SEP_index]
+#     if "!" in decode_text_list:
+#         PAD_index = decode_text_list.index("!")
+#         decode_text_list = decode_text_list[:PAD_index]
+#     decode_text = decode_text_list.strip()
+
+#     # Generate dummy predictions
+#     pred = f"{decode_text}"
 
-    # Include information about selected cells
-    selected_info_bef = f"{selected_cells_bef}" if selected_cells_bef else "No image patch was selected"
-    selected_info_aft = f"{selected_cells_aft}" if selected_cells_aft else "No image patch was selected"
+#     # Include information about selected cells
+#     selected_info_bef = f"{selected_cells_bef}" if selected_cells_bef else "No image patch was selected"
+#     selected_info_aft = f"{selected_cells_aft}" if selected_cells_aft else "No image patch was selected"
 
-    return pred, selected_info_bef, selected_info_aft
+#     return pred, selected_info_bef, selected_info_aft
 
 # Add grid to the image
 def add_grid_to_image(image_path, grid_size=14):
@@ -297,11 +297,11 @@ with gr.Blocks() as demo:
     html = gr.HTML(html_text)
 
     # Connect the predict button to the prediction function
-    predict_btn.click(
-        fn=predict_image,
-        inputs=[image_bef, image_aft, selected_cells_bef, selected_cells_aft],
-        outputs=[prediction, selected_info_bef, selected_info_aft]
-    )
+    # predict_btn.click(
+    #     fn=predict_image,
+    #     inputs=[image_bef, image_aft, selected_cells_bef, selected_cells_aft],
+    #     outputs=[prediction, selected_info_bef, selected_info_aft]
+    # )
 
     # image_bef.change(
    #     fn=None,
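Two latent bugs in the now-commented code are worth noting in case it is re-enabled later: np.zeros(..., dtype=np.float) and np.ones(2, dtype=np.long) use aliases that NumPy removed in 1.24, and decode_text_list.strip() is called on a list, which has no strip method. A minimal sketch of replacements, assuming CLIP-style BPE tokens with </w> end-of-word markers:

    import numpy as np

    image = np.zeros((1, 3, 224, 224), dtype=np.float64)  # was dtype=np.float, removed in NumPy 1.24
    image_mask = np.ones(2, dtype=np.int64)                # was dtype=np.long; int64 matches torch's .long()

    # lists have no .strip(); join the tokens, then expand the end-of-word markers
    decode_text_list = ["a</w>", "cat</w>"]                # toy tokens standing in for tokenizer output
    decode_text = "".join(decode_text_list).replace("</w>", " ").strip()  # -> "a cat"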
fileservice.py CHANGED
@@ -1,5 +1,7 @@
 from fastapi import FastAPI, Request, Response
 
+from pydantic import BaseModel, ConfigDict
+
 filenames = ["js/interactive_grid.js"]
 contents = "\n".join(
     [f"<script type='text/javascript' src='{x}'></script>" for x in filenames]
@@ -19,6 +21,9 @@ ga_script = """
 
 app = FastAPI()
 
+class CustomConfig(BaseModel):
+    model_config = ConfigDict(arbitrary_types_allowed=True)
+
 
 @app.middleware("http")
 async def insert_js(request: Request, call_next):
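The new CustomConfig relies on Pydantic v2's ConfigDict: arbitrary_types_allowed=True lets a model declare fields of types Pydantic cannot build a schema for, which are then validated by a plain isinstance() check. A minimal sketch of the behavior; the FramePayload subclass and its ndarray field are illustrative, not from this repo:

    import numpy as np
    from pydantic import BaseModel, ConfigDict

    class CustomConfig(BaseModel):
        model_config = ConfigDict(arbitrary_types_allowed=True)

    class FramePayload(CustomConfig):  # hypothetical subclass; model_config is inherited
        image: np.ndarray              # accepted only because of arbitrary_types_allowed

    FramePayload(image=np.zeros((2, 2)))  # validates via isinstance(); without the config,
                                          # defining this class raises PydanticSchemaGenerationError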
requirements.txt CHANGED
@@ -16,8 +16,8 @@ torch
 torchvision
 torchaudio
 tqdm==4.67.1
-fastapi>=0.100.0
-pydantic>=2.0.0
+fastapi
+pydantic
 uvicorn[standard]
 faiss-cpu==1.7.2
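Dropping the version floors leaves the build to whatever pip resolves on the day; if a future FastAPI or Pydantic major release breaks the Space, a bounded re-pin is the usual middle ground, e.g.:

    fastapi>=0.100,<1.0
    pydantic>=2.0,<3.0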