juliuse commited on
Commit
a7169e0
·
1 Parent(s): d92876a

import flair fix

Browse files
app.py CHANGED
@@ -14,8 +14,8 @@ project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "."))
14
  if project_root not in sys.path:
15
  sys.path.insert(0, project_root)
16
 
17
- from flair.pipelines import model_loader
18
- from flair import var_post_samp, degradations
19
 
20
 
21
 
@@ -283,7 +283,7 @@ def inpaint_image(image_editor_output, prompt_text, fixed_seed_value, use_random
283
 
284
  current_seed = None
285
  if use_random_seed:
286
- current_seed = None # load_config_for_inference will generate a random seed
287
  else:
288
  try:
289
  current_seed = int(fixed_seed_value)
@@ -514,15 +514,16 @@ if os.path.exists(demo_images_dir):
514
  metadata = json.load(f)
515
  task = metadata.get("task_type")
516
  prompt = metadata.get("prompt", "")
 
517
  if task == "Super Resolution":
518
- example_list_sr.append([image_path, prompt, task])
519
  else:
520
  image_editor_input = {
521
  "background": image_path,
522
  "layers": [mask_path],
523
  "composite": None # Add this key to satisfy ImageEditor's as_example processing
524
  }
525
- example_list_inp.append([image_editor_input, prompt, task])
526
 
527
  # Structure for ImageEditor: { "background": filepath, "layers": [filepath], "composite": None }
528
 
@@ -553,10 +554,14 @@ if __name__ == "__main__":
553
  sys.exit(1)
554
 
555
  # --- Define Gradio UI using gr.Blocks after globals are initialized ---
556
- title_str = "Solving Inverse Problems with FLAIR: Inpainting Demo"
557
  description_str = """
558
  Select a task (Inpainting or Super Resolution) and upload an image.
559
- For Inpainting, draw a mask on the image to specify the area to be filled. We observed that our model can event solve simple editing task, if provided with an appropriate prompt. For large masks the step size might need to be adjusted to e.g. 80.
 
 
 
 
560
  For Super Resolution, upload a low-resolution image and select the upscaling factor. Images are always upscaled to 768x768 pixels. Therefore, for x12 superresolution, the input image must be 64x64 pixels. You can also upload a high resolution image which will be downscaled to the correct input size.
561
  Use the slider to compare the low resolution input image with the super-resolved output.
562
 
@@ -729,43 +734,55 @@ Use the slider to compare the low resolution input image with the super-resolved
729
 
730
  gr.Markdown("---") # Separator
731
  gr.Markdown("### Click an example to load:")
732
- def load_example(example_data, prompt, task):
733
- image_editor_input = example_data[0]
734
- prompt_value = example_data[1]
735
- if task == "Inpainting":
736
- image_editor.clear() # Clear current image and mask
737
- if image_editor_input and image_editor_input.get("background"):
738
- image_editor.upload_image(image_editor_input["background"])
739
- if image_editor_input and image_editor_input.get("layers"):
740
- for layer in image_editor_input["layers"]:
741
- image_editor.upload_mask(layer)
742
- elif task == "Super Resolution":
743
- image_input.clear()
744
- image_input.upload_image(image_editor_input)
745
-
746
- # Set the prompt
747
- prompt_text.value = prompt_value
748
- # Optionally, set a random seed and guidance scale
749
- seed_slider.value = random.randint(0, 2**32 - 1)
750
- guidance_scale_slider.value = default_guidance_scale
751
- # Set the task selector from the example
752
- task_selector.set_value(task)
753
- update_visibility(task) # Update visibility based on task
754
 
755
- with gr.Row():
756
- gr.Examples(
757
- examples=example_list_sr,
758
- inputs=[image_input, prompt_text, task_selector],
759
- label="Super Resolution Examples",
760
- fn=load_example,
761
- )
762
- with gr.Row():
763
- gr.Examples(
764
- examples=example_list_inp,
765
- inputs=[image_editor, prompt_text, task_selector],
766
- label="Inpainting Examples",
767
- fn=load_example,
768
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
769
 
770
  # --- End of Gradio UI definition ---
771
 
 
14
  if project_root not in sys.path:
15
  sys.path.insert(0, project_root)
16
 
17
+ from src.flair.pipelines import model_loader
18
+ from src.flair import var_post_samp, degradations
19
 
20
 
21
 
 
283
 
284
  current_seed = None
285
  if use_random_seed:
286
+ current_seed = random.randint(0, 2**32 - 1)
287
  else:
288
  try:
289
  current_seed = int(fixed_seed_value)
 
514
  metadata = json.load(f)
515
  task = metadata.get("task_type")
516
  prompt = metadata.get("prompt", "")
517
+ n_steps = metadata.get("num_steps", 50)
518
  if task == "Super Resolution":
519
+ example_list_sr.append([image_path, prompt, task, n_steps])
520
  else:
521
  image_editor_input = {
522
  "background": image_path,
523
  "layers": [mask_path],
524
  "composite": None # Add this key to satisfy ImageEditor's as_example processing
525
  }
526
+ example_list_inp.append([image_editor_input, prompt, task, n_steps])
527
 
528
  # Structure for ImageEditor: { "background": filepath, "layers": [filepath], "composite": None }
529
 
 
554
  sys.exit(1)
555
 
556
  # --- Define Gradio UI using gr.Blocks after globals are initialized ---
557
+ title_str = "Solving Inverse Problems with FLAIR"
558
  description_str = """
559
  Select a task (Inpainting or Super Resolution) and upload an image.
560
+
561
+ For Inpainting, draw a mask on the image to specify the area to be filled.
562
+
563
+ We observed that our model can event solve simple editing task, if provided with an appropriate prompt. For large masks the step size might need to be adjusted to e.g. 80.
564
+
565
  For Super Resolution, upload a low-resolution image and select the upscaling factor. Images are always upscaled to 768x768 pixels. Therefore, for x12 superresolution, the input image must be 64x64 pixels. You can also upload a high resolution image which will be downscaled to the correct input size.
566
  Use the slider to compare the low resolution input image with the super-resolved output.
567
 
 
734
 
735
  gr.Markdown("---") # Separator
736
  gr.Markdown("### Click an example to load:")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
737
 
738
+ # --- GALLERY FOR SUPER RESOLUTION EXAMPLES ---
739
+ sr_gallery_items = [[ex[0], f"Prompt: {ex[1]} Steps: {ex[3]}"] for ex in example_list_sr]
740
+ sr_gallery = gr.Gallery(
741
+ value=sr_gallery_items,
742
+ label="Super Resolution Examples",
743
+ columns=4,
744
+ height="auto",
745
+ visible=True
746
+ )
747
+
748
+ # --- GALLERY FOR INPAINTING EXAMPLES ---
749
+ inp_gallery_items = [[ex[0]["background"], f"Prompt: {ex[1]} Steps: {ex[3]}"] for ex in example_list_inp]
750
+ inp_gallery = gr.Gallery(
751
+ value=inp_gallery_items,
752
+ label="Inpainting Examples",
753
+ columns=4,
754
+ height="auto",
755
+ visible=True
756
+ )
757
+
758
+ def on_sr_gallery_select(evt: gr.SelectData):
759
+ idx = evt.index
760
+ ex = example_list_sr[idx]
761
+ image_input.value = ex[0]
762
+ prompt_text.value = ex[1]
763
+ task_selector.value = ex[2]
764
+ num_steps_slider.value = ex[3]
765
+ update_visibility(ex[2])
766
+ return [image_input, prompt_text, task_selector, num_steps_slider]
767
+
768
+ def on_inp_gallery_select(evt: gr.SelectData):
769
+ idx = evt.index
770
+ ex = example_list_inp[idx]
771
+ image_editor.value = ex[0]
772
+ prompt_text.value = ex[1]
773
+ task_selector.value = ex[2]
774
+ num_steps_slider.value = ex[3]
775
+ update_visibility(ex[2])
776
+ return [image_editor, prompt_text, task_selector, num_steps_slider]
777
+
778
+ sr_gallery.select(
779
+ fn=on_sr_gallery_select,
780
+ outputs=[image_input, prompt_text, task_selector, num_steps_slider]
781
+ )
782
+ inp_gallery.select(
783
+ fn=on_inp_gallery_select,
784
+ outputs=[image_editor, prompt_text, task_selector, num_steps_slider]
785
+ )
786
 
787
  # --- End of Gradio UI definition ---
788
 
demo_images/demo_2_meta.json CHANGED
@@ -2,6 +2,6 @@
2
  "prompt": "a high quality image of a face.",
3
  "seed_on_slider": 3211750901,
4
  "use_random_seed_checkbox": false,
5
- "num_steps": 50,
6
  "task_type": "Inpainting"
7
  }
 
2
  "prompt": "a high quality image of a face.",
3
  "seed_on_slider": 3211750901,
4
  "use_random_seed_checkbox": false,
5
+ "num_steps": 75,
6
  "task_type": "Inpainting"
7
  }
requirements.txt CHANGED
@@ -16,4 +16,3 @@ sentencepiece
16
  protobuf
17
  accelerate
18
  gradio
19
- gradio_imageslider
 
16
  protobuf
17
  accelerate
18
  gradio
 
src/flair/degradations.py CHANGED
@@ -2,7 +2,7 @@ import torch
2
  import numpy as np
3
  from munch import munchify
4
  from scipy.ndimage import distance_transform_edt
5
- from flair.functions.degradation import get_degradation
6
  import torchvision
7
 
8
  class BaseDegradation(torch.nn.Module):
 
2
  import numpy as np
3
  from munch import munchify
4
  from scipy.ndimage import distance_transform_edt
5
+ from src.flair.functions.degradation import get_degradation
6
  import torchvision
7
 
8
  class BaseDegradation(torch.nn.Module):
src/flair/functions/degradation.py CHANGED
@@ -2,9 +2,9 @@ import numpy as np
2
  import torch
3
  from munch import Munch
4
 
5
- from flair.functions import svd_operators as svd_op
6
- from flair.functions import measurements
7
- from flair.utils.inpaint_util import MaskGenerator
8
 
9
  __DEGRADATION__ = {}
10
 
@@ -187,7 +187,7 @@ def deg_deblur_guass_general(deg_config, device):
187
  return A_funcs
188
 
189
 
190
- from flair.functions.jpeg import jpeg_encode, jpeg_decode
191
 
192
  class JPEGOperator():
193
  def __init__(self, qf: int, device):
 
2
  import torch
3
  from munch import Munch
4
 
5
+ from src.flair.functions import svd_operators as svd_op
6
+ from src.flair.functions import measurements
7
+ from src.flair.utils.inpaint_util import MaskGenerator
8
 
9
  __DEGRADATION__ = {}
10
 
 
187
  return A_funcs
188
 
189
 
190
+ from src.flair.functions.jpeg import jpeg_encode, jpeg_decode
191
 
192
  class JPEGOperator():
193
  def __init__(self, qf: int, device):
src/flair/functions/measurements.py CHANGED
@@ -6,15 +6,15 @@ from functools import partial
6
  from torch.nn import functional as F
7
  from torchvision import torch
8
 
9
- from flair.utils.blur_util import Blurkernel
10
- from flair.utils.img_util import fft2d
11
  import numpy as np
12
- from flair.utils.resizer import Resizer
13
- from flair.utils.utils_sisr import pre_calculate_FK, pre_calculate_nonuniform
14
  from torch.fft import fft2, ifft2
15
 
16
 
17
- from flair.motionblur.motionblur import Kernel
18
 
19
  # =================
20
  # Operation classes
 
6
  from torch.nn import functional as F
7
  from torchvision import torch
8
 
9
+ from src.flair.utils.blur_util import Blurkernel
10
+ from src.flair.utils.img_util import fft2d
11
  import numpy as np
12
+ from src.flair.utils.resizer import Resizer
13
+ from src.flair.utils.utils_sisr import pre_calculate_FK, pre_calculate_nonuniform
14
  from torch.fft import fft2, ifft2
15
 
16
 
17
+ from src.flair.motionblur.motionblur import Kernel
18
 
19
  # =================
20
  # Operation classes
src/flair/pipelines/model_loader.py CHANGED
@@ -10,7 +10,7 @@ from diffusers import (
10
  from diffusers import AutoencoderTiny
11
 
12
 
13
- from flair.pipelines import sd3
14
 
15
 
16
 
 
10
  from diffusers import AutoencoderTiny
11
 
12
 
13
+ from src.flair.pipelines import sd3
14
 
15
 
16
 
src/flair/pipelines/sd3.py CHANGED
@@ -1,7 +1,7 @@
1
  import torch
2
  from typing import Dict, Any
3
  from diffusers.pipelines.stable_diffusion_3 import pipeline_stable_diffusion_3
4
- from flair.pipelines import utils
5
  import tqdm
6
 
7
  class SD3Wrapper(pipeline_stable_diffusion_3.StableDiffusion3Pipeline):
 
1
  import torch
2
  from typing import Dict, Any
3
  from diffusers.pipelines.stable_diffusion_3 import pipeline_stable_diffusion_3
4
+ from src.flair.pipelines import utils
5
  import tqdm
6
 
7
  class SD3Wrapper(pipeline_stable_diffusion_3.StableDiffusion3Pipeline):
src/flair/utils/blur_util.py CHANGED
@@ -3,7 +3,7 @@ from torch import nn
3
  import numpy as np
4
  import scipy
5
 
6
- from flair.utils.motionblur import Kernel as MotionKernel
7
 
8
  class Blurkernel(nn.Module):
9
  def __init__(self, blur_type='gaussian', kernel_size=31, std=3.0, device=None):
 
3
  import numpy as np
4
  import scipy
5
 
6
+ from src.flair.utils.motionblur import Kernel as MotionKernel
7
 
8
  class Blurkernel(nn.Module):
9
  def __init__(self, blur_type='gaussian', kernel_size=31, std=3.0, device=None):
src/flair/var_post_samp.py CHANGED
@@ -3,7 +3,7 @@ import numpy as np
3
  import sys
4
  import os
5
  import tqdm
6
- from flair import degradations
7
  import torchvision
8
 
9
  def total_variation_loss(x):
 
3
  import sys
4
  import os
5
  import tqdm
6
+ from src.flair import degradations
7
  import torchvision
8
 
9
  def total_variation_loss(x):