ondrejbiza commited on
Commit
1e7763d
·
1 Parent(s): 8530ae8

Fix global state bug.

Browse files
Files changed (1) hide show
  1. app.py +53 -47
app.py CHANGED
@@ -1,12 +1,14 @@
1
  import os
 
2
 
3
  from clu import checkpoint
 
4
  import gradio as gr
 
5
  import jax
6
  import jax.numpy as jnp
7
  import numpy as np
8
  from PIL import Image
9
- from huggingface_hub import snapshot_download
10
 
11
  from invariant_slot_attention.configs.clevr_with_masks.equiv_transl_scale import get_config
12
  from invariant_slot_attention.lib import utils
@@ -14,7 +16,6 @@ from invariant_slot_attention.lib import utils
14
 
15
  def load_model(config, checkpoint_dir):
16
  rng = jax.random.PRNGKey(42)
17
- rng, data_rng = jax.random.split(rng)
18
 
19
  # Initialize model
20
  model = utils.build_model_from_config(config.model)
@@ -55,10 +56,9 @@ def load_image(name):
55
  img = Image.open(f"images/{name}.png")
56
  img = img.crop((64, 29, 64 + 192, 29 + 192))
57
  img = img.resize((128, 128))
58
- img_ = np.array(img)
59
  img = np.array(img)[:, :, :3] / 255.
60
  img = jnp.array(img, dtype=jnp.float32)
61
- return img, img_
62
 
63
 
64
  download_path = snapshot_download(repo_id="ondrejbiza/isa")
@@ -68,8 +68,7 @@ model, state, rng = load_model(get_config(), checkpoint_dir)
68
 
69
  rng, init_rng = jax.random.split(rng, num=2)
70
 
71
- from flax import linen as nn
72
- from typing import Callable
73
  class DecoderWrapper(nn.Module):
74
  decoder: Callable[[], nn.Module]
75
  @nn.compact
@@ -77,17 +76,12 @@ class DecoderWrapper(nn.Module):
77
  return self.decoder()(slots, train)
78
  decoder_model = DecoderWrapper(decoder=model.decoder)
79
 
80
- slots = np.zeros((11, 64), dtype=np.float32)
81
- pos = np.zeros((11, 2), dtype=np.float32)
82
- scale = np.zeros((11, 2), dtype=np.float32)
83
- probs = np.zeros((11, 128, 128), dtype=np.float32)
84
-
85
  with gr.Blocks() as demo:
86
 
87
- # work in progress
88
- # with gr.Row():
89
- # gr_gallery = gr.Gallery(value=[f"images/img{i}.png" for i in range(1, 9)])
90
- # gr_gallery = gr_gallery.style(columns=[3], rows=[3], object_fit="contain", height="auto")
91
 
92
  with gr.Row():
93
 
@@ -116,89 +110,101 @@ with gr.Blocks() as demo:
116
  def update_image_and_segmentation(name, idx):
117
  idx = idx - 1
118
 
119
- img_input, _ = load_image(name)
120
  out = model.apply(
121
  {"params": state.params, **state.variables},
122
  video=img_input[None, None],
123
  rngs={"state_init": init_rng},
124
  train=False)
125
 
126
- probs[:] = nn.softmax(out["outputs"]["segmentation_logits"][0, 0, :, :, :, 0], axis=0)
127
  img = np.array(out["outputs"]["video"][0, 0])
128
  img = np.clip(img, 0, 1)
129
 
130
- slots_ = out["states"]
131
- slots[:] = slots_[0, 0, :, :-4]
132
- pos[:] = slots_[0, 0, :, -4: -2]
133
- scale[:] = slots_[0, 0, :, -2:]
134
 
135
  return (img * 255).astype(np.uint8), (probs[idx] * 255).astype(np.uint8), float(pos[idx, 0]), \
136
- float(pos[idx, 1]), float(scale[idx, 0]), float(scale[idx, 1])
137
 
138
  gr_choose_image.change(
139
  fn=update_image_and_segmentation,
140
  inputs=[gr_choose_image, gr_slot_slider],
141
- outputs=[gr_image_1, gr_image_2, gr_x_slider, gr_y_slider, gr_sx_slider, gr_sy_slider]
 
142
  )
143
 
144
- def update_sliders(idx):
145
  idx = idx - 1 # 1-indexing to 0-indexing
146
- return (probs[idx] * 255).astype(np.uint8), float(pos[idx, 0]), \
147
- float(pos[idx, 1]), float(scale[idx, 0]), float(scale[idx, 1])
148
 
149
  gr_slot_slider.change(
150
  fn=update_sliders,
151
- inputs=gr_slot_slider,
152
  outputs=[gr_image_2, gr_x_slider, gr_y_slider, gr_sx_slider, gr_sy_slider]
153
  )
154
 
155
- def update_pos_x(idx, val):
156
- pos[idx - 1, 0] = val
157
- def update_pos_y(idx, val):
158
- pos[idx - 1, 1] = val
159
- def update_scale_x(idx, val):
160
- scale[idx - 1, 0] = val
161
- def update_scale_y(idx, val):
162
- scale[idx - 1, 1] = val
 
 
 
 
 
 
 
163
 
164
  gr_x_slider.change(
165
  fn=update_pos_x,
166
- inputs=[gr_slot_slider, gr_x_slider]
 
167
  )
168
  gr_y_slider.change(
169
  fn=update_pos_y,
170
- inputs=[gr_slot_slider, gr_y_slider]
 
171
  )
172
  gr_sx_slider.change(
173
  fn=update_scale_x,
174
- inputs=[gr_slot_slider, gr_sx_slider]
 
175
  )
176
  gr_sy_slider.change(
177
  fn=update_scale_y,
178
- inputs=[gr_slot_slider, gr_sy_slider]
 
179
  )
180
 
181
- def render(idx):
182
  idx = idx - 1
183
 
184
- slots_ = np.concatenate([slots, pos, scale], axis=-1)
185
- slots_ = jnp.array(slots_)
186
 
187
  out = decoder_model.apply(
188
  {"params": state.params, **state.variables},
189
- slots=slots_[None, None],
190
  train=False
191
  )
192
 
193
- probs[:] = nn.softmax(out["segmentation_logits"][0, 0, :, :, :, 0], axis=0)
194
  image = np.array(out["video"][0, 0])
195
  image = np.clip(image, 0, 1)
196
- return (image * 255).astype(np.uint8), (probs[idx] * 255).astype(np.uint8)
197
 
198
  gr_button.click(
199
  fn=render,
200
- inputs=gr_slot_slider,
201
- outputs=[gr_image_1, gr_image_2]
202
  )
203
 
204
  demo.launch()
 
1
  import os
2
+ from typing import Callable
3
 
4
  from clu import checkpoint
5
+ from flax import linen as nn
6
  import gradio as gr
7
+ from huggingface_hub import snapshot_download
8
  import jax
9
  import jax.numpy as jnp
10
  import numpy as np
11
  from PIL import Image
 
12
 
13
  from invariant_slot_attention.configs.clevr_with_masks.equiv_transl_scale import get_config
14
  from invariant_slot_attention.lib import utils
 
16
 
17
  def load_model(config, checkpoint_dir):
18
  rng = jax.random.PRNGKey(42)
 
19
 
20
  # Initialize model
21
  model = utils.build_model_from_config(config.model)
 
56
  img = Image.open(f"images/{name}.png")
57
  img = img.crop((64, 29, 64 + 192, 29 + 192))
58
  img = img.resize((128, 128))
 
59
  img = np.array(img)[:, :, :3] / 255.
60
  img = jnp.array(img, dtype=jnp.float32)
61
+ return img
62
 
63
 
64
  download_path = snapshot_download(repo_id="ondrejbiza/isa")
 
68
 
69
  rng, init_rng = jax.random.split(rng, num=2)
70
 
71
+
 
72
  class DecoderWrapper(nn.Module):
73
  decoder: Callable[[], nn.Module]
74
  @nn.compact
 
76
  return self.decoder()(slots, train)
77
  decoder_model = DecoderWrapper(decoder=model.decoder)
78
 
 
 
 
 
 
79
  with gr.Blocks() as demo:
80
 
81
+ local_slots = gr.State(np.zeros((11, 64), dtype=np.float32))
82
+ local_pos = gr.State(np.zeros((11, 2), dtype=np.float32))
83
+ local_scale = gr.State(np.zeros((11, 2), dtype=np.float32))
84
+ local_probs = gr.State(np.zeros((11, 128, 128), dtype=np.float32))
85
 
86
  with gr.Row():
87
 
 
110
  def update_image_and_segmentation(name, idx):
111
  idx = idx - 1
112
 
113
+ img_input = load_image(name)
114
  out = model.apply(
115
  {"params": state.params, **state.variables},
116
  video=img_input[None, None],
117
  rngs={"state_init": init_rng},
118
  train=False)
119
 
120
+ probs = np.array(nn.softmax(out["outputs"]["segmentation_logits"][0, 0, :, :, :, 0], axis=0))
121
  img = np.array(out["outputs"]["video"][0, 0])
122
  img = np.clip(img, 0, 1)
123
 
124
+ slots_ = np.array(out["states"])
125
+ slots = slots_[0, 0, :, :-4]
126
+ pos = slots_[0, 0, :, -4: -2]
127
+ scale = slots_[0, 0, :, -2:]
128
 
129
  return (img * 255).astype(np.uint8), (probs[idx] * 255).astype(np.uint8), float(pos[idx, 0]), \
130
+ float(pos[idx, 1]), float(scale[idx, 0]), float(scale[idx, 1]), probs, slots, pos, scale
131
 
132
  gr_choose_image.change(
133
  fn=update_image_and_segmentation,
134
  inputs=[gr_choose_image, gr_slot_slider],
135
+ outputs=[gr_image_1, gr_image_2, gr_x_slider, gr_y_slider, gr_sx_slider, gr_sy_slider,
136
+ local_probs, local_slots, local_pos, local_scale]
137
  )
138
 
139
+ def update_sliders(idx, local_probs, local_pos, local_scale):
140
  idx = idx - 1 # 1-indexing to 0-indexing
141
+ return (local_probs[idx] * 255).astype(np.uint8), float(local_pos[idx, 0]), \
142
+ float(local_pos[idx, 1]), float(local_scale[idx, 0]), float(local_scale[idx, 1])
143
 
144
  gr_slot_slider.change(
145
  fn=update_sliders,
146
+ inputs=[gr_slot_slider, local_probs, local_pos, local_scale],
147
  outputs=[gr_image_2, gr_x_slider, gr_y_slider, gr_sx_slider, gr_sy_slider]
148
  )
149
 
150
+ def update_pos_x(idx, val, local_pos):
151
+ local_pos[idx - 1, 0] = val
152
+ return local_pos
153
+
154
+ def update_pos_y(idx, val, local_pos):
155
+ local_pos[idx - 1, 1] = val
156
+ return local_pos
157
+
158
+ def update_scale_x(idx, val, local_scale):
159
+ local_scale[idx - 1, 0] = val
160
+ return local_scale
161
+
162
+ def update_scale_y(idx, val, local_scale):
163
+ local_scale[idx - 1, 1] = val
164
+ return local_scale
165
 
166
  gr_x_slider.change(
167
  fn=update_pos_x,
168
+ inputs=[gr_slot_slider, gr_x_slider, local_pos],
169
+ outputs=local_pos
170
  )
171
  gr_y_slider.change(
172
  fn=update_pos_y,
173
+ inputs=[gr_slot_slider, gr_y_slider, local_pos],
174
+ outputs=local_pos
175
  )
176
  gr_sx_slider.change(
177
  fn=update_scale_x,
178
+ inputs=[gr_slot_slider, gr_sx_slider, local_scale],
179
+ outputs=local_scale
180
  )
181
  gr_sy_slider.change(
182
  fn=update_scale_y,
183
+ inputs=[gr_slot_slider, gr_sy_slider, local_scale],
184
+ outputs=local_scale
185
  )
186
 
187
+ def render(idx, local_slots, local_pos, local_scale):
188
  idx = idx - 1
189
 
190
+ slots = np.concatenate([local_slots, local_pos, local_scale], axis=-1)
191
+ slots = jnp.array(slots)
192
 
193
  out = decoder_model.apply(
194
  {"params": state.params, **state.variables},
195
+ slots=slots[None, None],
196
  train=False
197
  )
198
 
199
+ probs = np.array(nn.softmax(out["segmentation_logits"][0, 0, :, :, :, 0], axis=0))
200
  image = np.array(out["video"][0, 0])
201
  image = np.clip(image, 0, 1)
202
+ return (image * 255).astype(np.uint8), (probs[idx] * 255).astype(np.uint8), probs
203
 
204
  gr_button.click(
205
  fn=render,
206
+ inputs=[gr_slot_slider, local_slots, local_pos, local_scale],
207
+ outputs=[gr_image_1, gr_image_2, local_probs]
208
  )
209
 
210
  demo.launch()