toshas commited on
Commit
0eead1f
·
1 Parent(s): bd654b5

fix user components not receiving updates, improve gallery passthrough of 16bit pngs, make thumbnails square

Browse files
gradio_dualvision/app_template.py CHANGED
@@ -32,6 +32,7 @@ from PIL import Image
32
  from gradio.components.base import Component
33
 
34
  from .gradio_patches.examples import Examples
 
35
  from .gradio_patches.imagesliderplus import ImageSliderPlus
36
  from .gradio_patches.radio import Radio
37
 
@@ -99,6 +100,8 @@ class DualVisionApp(gr.Blocks):
99
  self.key_original_image = key_original_image
100
  self.slider_position = slider_position
101
  self.input_keys = None
 
 
102
  self.left_selector_visible = left_selector_visible
103
  self.advanced_settings_can_be_half_width = advanced_settings_can_be_half_width
104
  if spaces_zero_gpu_enabled:
@@ -353,7 +356,12 @@ class DualVisionApp(gr.Blocks):
353
  )
354
  if any(k not in results_settings for k in self.input_keys):
355
  raise gr.Error(f"Mismatching setgings keys")
356
- results_settings = {k: results_settings[k] for k in self.input_keys}
 
 
 
 
 
357
 
358
  results_dict = {
359
  self.key_original_image: image_in,
@@ -440,7 +448,7 @@ class DualVisionApp(gr.Blocks):
440
  """
441
  self.make_header()
442
 
443
- results_state = gr.Gallery(visible=False, format="png")
444
 
445
  image_slider = self.make_slider()
446
 
@@ -598,6 +606,11 @@ class DualVisionApp(gr.Blocks):
598
  with gr.Row():
599
  btn_clear, btn_submit = self.make_buttons()
600
  self.input_keys = list(user_components.keys())
 
 
 
 
 
601
  return user_components, btn_clear, btn_submit
602
 
603
  def make_buttons(self):
 
32
  from gradio.components.base import Component
33
 
34
  from .gradio_patches.examples import Examples
35
+ from .gradio_patches.gallery import Gallery
36
  from .gradio_patches.imagesliderplus import ImageSliderPlus
37
  from .gradio_patches.radio import Radio
38
 
 
100
  self.key_original_image = key_original_image
101
  self.slider_position = slider_position
102
  self.input_keys = None
103
+ self.input_cls = None
104
+ self.input_kwargs = None
105
  self.left_selector_visible = left_selector_visible
106
  self.advanced_settings_can_be_half_width = advanced_settings_can_be_half_width
107
  if spaces_zero_gpu_enabled:
 
356
  )
357
  if any(k not in results_settings for k in self.input_keys):
358
  raise gr.Error(f"Mismatching setgings keys")
359
+ results_settings = {
360
+ k: cls(**ctor_args, value=results_settings[k])
361
+ for k, cls, ctor_args in zip(
362
+ self.input_keys, self.input_cls, self.input_kwargs
363
+ )
364
+ }
365
 
366
  results_dict = {
367
  self.key_original_image: image_in,
 
448
  """
449
  self.make_header()
450
 
451
+ results_state = Gallery(visible=False)
452
 
453
  image_slider = self.make_slider()
454
 
 
606
  with gr.Row():
607
  btn_clear, btn_submit = self.make_buttons()
608
  self.input_keys = list(user_components.keys())
609
+ self.input_cls = list(v.__class__ for v in user_components.values())
610
+ self.input_kwargs = [
611
+ {k: v for k, v in c.constructor_args.items() if k not in ("value")}
612
+ for c in user_components.values()
613
+ ]
614
  return user_components, btn_clear, btn_submit
615
 
616
  def make_buttons(self):
gradio_dualvision/gradio_patches/gallery.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from concurrent.futures import ThreadPoolExecutor
4
+ from gradio.components.gallery import (
5
+ GalleryImageType,
6
+ CaptionedGalleryImageType,
7
+ GalleryImage,
8
+ GalleryData,
9
+ )
10
+ from pathlib import Path
11
+ from urllib.parse import urlparse
12
+
13
+ import gradio
14
+ import numpy as np
15
+ import PIL.Image
16
+ from gradio_client.utils import is_http_url_like
17
+
18
+ from gradio import processing_utils, utils, wasm_utils
19
+ from gradio.data_classes import FileData
20
+
21
+
22
+ class Gallery(gradio.Gallery):
23
+ def postprocess(
24
+ self,
25
+ value: list[GalleryImageType | CaptionedGalleryImageType] | None,
26
+ ) -> GalleryData:
27
+ """
28
+ This is a patched version of the original function, wherein the format for PIL is computed based on the data type:
29
+ format = "png" if img.mode == "I;16" else "webp"
30
+ """
31
+ if value is None:
32
+ return GalleryData(root=[])
33
+ output = []
34
+
35
+ def _save(img):
36
+ url = None
37
+ caption = None
38
+ orig_name = None
39
+ if isinstance(img, (tuple, list)):
40
+ img, caption = img
41
+ if isinstance(img, np.ndarray):
42
+ file = processing_utils.save_img_array_to_cache(
43
+ img, cache_dir=self.GRADIO_CACHE, format=self.format
44
+ )
45
+ file_path = str(utils.abspath(file))
46
+ elif isinstance(img, PIL.Image.Image):
47
+ format = "png" if img.mode == "I;16" else "webp"
48
+ file = processing_utils.save_pil_to_cache(
49
+ img, cache_dir=self.GRADIO_CACHE, format=format
50
+ )
51
+ file_path = str(utils.abspath(file))
52
+ elif isinstance(img, str):
53
+ file_path = img
54
+ if is_http_url_like(img):
55
+ url = img
56
+ orig_name = Path(urlparse(img).path).name
57
+ else:
58
+ url = None
59
+ orig_name = Path(img).name
60
+ elif isinstance(img, Path):
61
+ file_path = str(img)
62
+ orig_name = img.name
63
+ else:
64
+ raise ValueError(f"Cannot process type as image: {type(img)}")
65
+ return GalleryImage(
66
+ image=FileData(path=file_path, url=url, orig_name=orig_name),
67
+ caption=caption,
68
+ )
69
+
70
+ if wasm_utils.IS_WASM:
71
+ for img in value:
72
+ output.append(_save(img))
73
+ else:
74
+ with ThreadPoolExecutor() as executor:
75
+ for o in executor.map(_save, value):
76
+ output.append(o)
77
+ return GalleryData(root=output)
gradio_dualvision/gradio_patches/imagesliderplus.py CHANGED
@@ -49,7 +49,7 @@ class ImageSliderPlus(ImageSlider):
49
  data_model = ImageSliderPlusData
50
 
51
  def as_example(self, value):
52
- return self.process_example_dims(value, 256)
53
 
54
  def _format_image(self, im: Image):
55
  if self.type != "filepath":
@@ -117,15 +117,23 @@ class ImageSliderPlus(ImageSlider):
117
  return out_0, out_1
118
 
119
  @staticmethod
120
- def resize_and_save(image_path: str, max_dim: int) -> str:
121
  img = Image.open(image_path).convert("RGB")
 
 
 
 
 
 
 
 
122
  img.thumbnail((max_dim, max_dim))
123
  temp_file = tempfile.NamedTemporaryFile(suffix=".webp", delete=False)
124
  img.save(temp_file.name, "WEBP")
125
  return temp_file.name
126
 
127
  def process_example_dims(
128
- self, input_data: tuple[str | Path | None] | None, max_dim: Optional[int] = None
129
  ) -> image_tuple:
130
  if input_data is None:
131
  return None
@@ -134,8 +142,8 @@ class ImageSliderPlus(ImageSlider):
134
  return input_data[0]
135
  if max_dim is not None:
136
  input_data = (
137
- self.resize_and_save(input_data[0], max_dim),
138
- self.resize_and_save(input_data[1], max_dim),
139
  )
140
  return (
141
  self.move_resource_to_block_cache(input_data[0]),
 
49
  data_model = ImageSliderPlusData
50
 
51
  def as_example(self, value):
52
+ return self.process_example_dims(value, 256, True)
53
 
54
  def _format_image(self, im: Image):
55
  if self.type != "filepath":
 
117
  return out_0, out_1
118
 
119
  @staticmethod
120
+ def resize_and_save(image_path: str, max_dim: int, square: bool = False) -> str:
121
  img = Image.open(image_path).convert("RGB")
122
+ if square:
123
+ width, height = img.size
124
+ min_side = min(width, height)
125
+ left = (width - min_side) // 2
126
+ top = (height - min_side) // 2
127
+ right = left + min_side
128
+ bottom = top + min_side
129
+ img = img.crop((left, top, right, bottom))
130
  img.thumbnail((max_dim, max_dim))
131
  temp_file = tempfile.NamedTemporaryFile(suffix=".webp", delete=False)
132
  img.save(temp_file.name, "WEBP")
133
  return temp_file.name
134
 
135
  def process_example_dims(
136
+ self, input_data: tuple[str | Path | None] | None, max_dim: Optional[int] = None, square: bool = False
137
  ) -> image_tuple:
138
  if input_data is None:
139
  return None
 
142
  return input_data[0]
143
  if max_dim is not None:
144
  input_data = (
145
+ self.resize_and_save(input_data[0], max_dim, square),
146
+ self.resize_and_save(input_data[1], max_dim, square),
147
  )
148
  return (
149
  self.move_resource_to_block_cache(input_data[0]),