Spaces:
Running
on
Zero
Running
on
Zero
import flair fix
Browse files- app.py +60 -43
- demo_images/demo_2_meta.json +1 -1
- requirements.txt +0 -1
- src/flair/degradations.py +1 -1
- src/flair/functions/degradation.py +4 -4
- src/flair/functions/measurements.py +5 -5
- src/flair/pipelines/model_loader.py +1 -1
- src/flair/pipelines/sd3.py +1 -1
- src/flair/utils/blur_util.py +1 -1
- src/flair/var_post_samp.py +1 -1
app.py
CHANGED
@@ -14,8 +14,8 @@ project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "."))
|
|
14 |
if project_root not in sys.path:
|
15 |
sys.path.insert(0, project_root)
|
16 |
|
17 |
-
from flair.pipelines import model_loader
|
18 |
-
from flair import var_post_samp, degradations
|
19 |
|
20 |
|
21 |
|
@@ -283,7 +283,7 @@ def inpaint_image(image_editor_output, prompt_text, fixed_seed_value, use_random
|
|
283 |
|
284 |
current_seed = None
|
285 |
if use_random_seed:
|
286 |
-
current_seed =
|
287 |
else:
|
288 |
try:
|
289 |
current_seed = int(fixed_seed_value)
|
@@ -514,15 +514,16 @@ if os.path.exists(demo_images_dir):
|
|
514 |
metadata = json.load(f)
|
515 |
task = metadata.get("task_type")
|
516 |
prompt = metadata.get("prompt", "")
|
|
|
517 |
if task == "Super Resolution":
|
518 |
-
example_list_sr.append([image_path, prompt, task])
|
519 |
else:
|
520 |
image_editor_input = {
|
521 |
"background": image_path,
|
522 |
"layers": [mask_path],
|
523 |
"composite": None # Add this key to satisfy ImageEditor's as_example processing
|
524 |
}
|
525 |
-
example_list_inp.append([image_editor_input, prompt, task])
|
526 |
|
527 |
# Structure for ImageEditor: { "background": filepath, "layers": [filepath], "composite": None }
|
528 |
|
@@ -553,10 +554,14 @@ if __name__ == "__main__":
|
|
553 |
sys.exit(1)
|
554 |
|
555 |
# --- Define Gradio UI using gr.Blocks after globals are initialized ---
|
556 |
-
title_str = "Solving Inverse Problems with FLAIR
|
557 |
description_str = """
|
558 |
Select a task (Inpainting or Super Resolution) and upload an image.
|
559 |
-
|
|
|
|
|
|
|
|
|
560 |
For Super Resolution, upload a low-resolution image and select the upscaling factor. Images are always upscaled to 768x768 pixels. Therefore, for x12 superresolution, the input image must be 64x64 pixels. You can also upload a high resolution image which will be downscaled to the correct input size.
|
561 |
Use the slider to compare the low resolution input image with the super-resolved output.
|
562 |
|
@@ -729,43 +734,55 @@ Use the slider to compare the low resolution input image with the super-resolved
|
|
729 |
|
730 |
gr.Markdown("---") # Separator
|
731 |
gr.Markdown("### Click an example to load:")
|
732 |
-
def load_example(example_data, prompt, task):
|
733 |
-
image_editor_input = example_data[0]
|
734 |
-
prompt_value = example_data[1]
|
735 |
-
if task == "Inpainting":
|
736 |
-
image_editor.clear() # Clear current image and mask
|
737 |
-
if image_editor_input and image_editor_input.get("background"):
|
738 |
-
image_editor.upload_image(image_editor_input["background"])
|
739 |
-
if image_editor_input and image_editor_input.get("layers"):
|
740 |
-
for layer in image_editor_input["layers"]:
|
741 |
-
image_editor.upload_mask(layer)
|
742 |
-
elif task == "Super Resolution":
|
743 |
-
image_input.clear()
|
744 |
-
image_input.upload_image(image_editor_input)
|
745 |
-
|
746 |
-
# Set the prompt
|
747 |
-
prompt_text.value = prompt_value
|
748 |
-
# Optionally, set a random seed and guidance scale
|
749 |
-
seed_slider.value = random.randint(0, 2**32 - 1)
|
750 |
-
guidance_scale_slider.value = default_guidance_scale
|
751 |
-
# Set the task selector from the example
|
752 |
-
task_selector.set_value(task)
|
753 |
-
update_visibility(task) # Update visibility based on task
|
754 |
|
755 |
-
|
756 |
-
|
757 |
-
|
758 |
-
|
759 |
-
|
760 |
-
|
761 |
-
|
762 |
-
|
763 |
-
|
764 |
-
|
765 |
-
|
766 |
-
|
767 |
-
|
768 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
769 |
|
770 |
# --- End of Gradio UI definition ---
|
771 |
|
|
|
14 |
if project_root not in sys.path:
|
15 |
sys.path.insert(0, project_root)
|
16 |
|
17 |
+
from src.flair.pipelines import model_loader
|
18 |
+
from src.flair import var_post_samp, degradations
|
19 |
|
20 |
|
21 |
|
|
|
283 |
|
284 |
current_seed = None
|
285 |
if use_random_seed:
|
286 |
+
current_seed = random.randint(0, 2**32 - 1)
|
287 |
else:
|
288 |
try:
|
289 |
current_seed = int(fixed_seed_value)
|
|
|
514 |
metadata = json.load(f)
|
515 |
task = metadata.get("task_type")
|
516 |
prompt = metadata.get("prompt", "")
|
517 |
+
n_steps = metadata.get("num_steps", 50)
|
518 |
if task == "Super Resolution":
|
519 |
+
example_list_sr.append([image_path, prompt, task, n_steps])
|
520 |
else:
|
521 |
image_editor_input = {
|
522 |
"background": image_path,
|
523 |
"layers": [mask_path],
|
524 |
"composite": None # Add this key to satisfy ImageEditor's as_example processing
|
525 |
}
|
526 |
+
example_list_inp.append([image_editor_input, prompt, task, n_steps])
|
527 |
|
528 |
# Structure for ImageEditor: { "background": filepath, "layers": [filepath], "composite": None }
|
529 |
|
|
|
554 |
sys.exit(1)
|
555 |
|
556 |
# --- Define Gradio UI using gr.Blocks after globals are initialized ---
|
557 |
+
title_str = "Solving Inverse Problems with FLAIR"
|
558 |
description_str = """
|
559 |
Select a task (Inpainting or Super Resolution) and upload an image.
|
560 |
+
|
561 |
+
For Inpainting, draw a mask on the image to specify the area to be filled.
|
562 |
+
|
563 |
+
We observed that our model can event solve simple editing task, if provided with an appropriate prompt. For large masks the step size might need to be adjusted to e.g. 80.
|
564 |
+
|
565 |
For Super Resolution, upload a low-resolution image and select the upscaling factor. Images are always upscaled to 768x768 pixels. Therefore, for x12 superresolution, the input image must be 64x64 pixels. You can also upload a high resolution image which will be downscaled to the correct input size.
|
566 |
Use the slider to compare the low resolution input image with the super-resolved output.
|
567 |
|
|
|
734 |
|
735 |
gr.Markdown("---") # Separator
|
736 |
gr.Markdown("### Click an example to load:")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
737 |
|
738 |
+
# --- GALLERY FOR SUPER RESOLUTION EXAMPLES ---
|
739 |
+
sr_gallery_items = [[ex[0], f"Prompt: {ex[1]} Steps: {ex[3]}"] for ex in example_list_sr]
|
740 |
+
sr_gallery = gr.Gallery(
|
741 |
+
value=sr_gallery_items,
|
742 |
+
label="Super Resolution Examples",
|
743 |
+
columns=4,
|
744 |
+
height="auto",
|
745 |
+
visible=True
|
746 |
+
)
|
747 |
+
|
748 |
+
# --- GALLERY FOR INPAINTING EXAMPLES ---
|
749 |
+
inp_gallery_items = [[ex[0]["background"], f"Prompt: {ex[1]} Steps: {ex[3]}"] for ex in example_list_inp]
|
750 |
+
inp_gallery = gr.Gallery(
|
751 |
+
value=inp_gallery_items,
|
752 |
+
label="Inpainting Examples",
|
753 |
+
columns=4,
|
754 |
+
height="auto",
|
755 |
+
visible=True
|
756 |
+
)
|
757 |
+
|
758 |
+
def on_sr_gallery_select(evt: gr.SelectData):
|
759 |
+
idx = evt.index
|
760 |
+
ex = example_list_sr[idx]
|
761 |
+
image_input.value = ex[0]
|
762 |
+
prompt_text.value = ex[1]
|
763 |
+
task_selector.value = ex[2]
|
764 |
+
num_steps_slider.value = ex[3]
|
765 |
+
update_visibility(ex[2])
|
766 |
+
return [image_input, prompt_text, task_selector, num_steps_slider]
|
767 |
+
|
768 |
+
def on_inp_gallery_select(evt: gr.SelectData):
|
769 |
+
idx = evt.index
|
770 |
+
ex = example_list_inp[idx]
|
771 |
+
image_editor.value = ex[0]
|
772 |
+
prompt_text.value = ex[1]
|
773 |
+
task_selector.value = ex[2]
|
774 |
+
num_steps_slider.value = ex[3]
|
775 |
+
update_visibility(ex[2])
|
776 |
+
return [image_editor, prompt_text, task_selector, num_steps_slider]
|
777 |
+
|
778 |
+
sr_gallery.select(
|
779 |
+
fn=on_sr_gallery_select,
|
780 |
+
outputs=[image_input, prompt_text, task_selector, num_steps_slider]
|
781 |
+
)
|
782 |
+
inp_gallery.select(
|
783 |
+
fn=on_inp_gallery_select,
|
784 |
+
outputs=[image_editor, prompt_text, task_selector, num_steps_slider]
|
785 |
+
)
|
786 |
|
787 |
# --- End of Gradio UI definition ---
|
788 |
|
demo_images/demo_2_meta.json
CHANGED
@@ -2,6 +2,6 @@
|
|
2 |
"prompt": "a high quality image of a face.",
|
3 |
"seed_on_slider": 3211750901,
|
4 |
"use_random_seed_checkbox": false,
|
5 |
-
"num_steps":
|
6 |
"task_type": "Inpainting"
|
7 |
}
|
|
|
2 |
"prompt": "a high quality image of a face.",
|
3 |
"seed_on_slider": 3211750901,
|
4 |
"use_random_seed_checkbox": false,
|
5 |
+
"num_steps": 75,
|
6 |
"task_type": "Inpainting"
|
7 |
}
|
requirements.txt
CHANGED
@@ -16,4 +16,3 @@ sentencepiece
|
|
16 |
protobuf
|
17 |
accelerate
|
18 |
gradio
|
19 |
-
gradio_imageslider
|
|
|
16 |
protobuf
|
17 |
accelerate
|
18 |
gradio
|
|
src/flair/degradations.py
CHANGED
@@ -2,7 +2,7 @@ import torch
|
|
2 |
import numpy as np
|
3 |
from munch import munchify
|
4 |
from scipy.ndimage import distance_transform_edt
|
5 |
-
from flair.functions.degradation import get_degradation
|
6 |
import torchvision
|
7 |
|
8 |
class BaseDegradation(torch.nn.Module):
|
|
|
2 |
import numpy as np
|
3 |
from munch import munchify
|
4 |
from scipy.ndimage import distance_transform_edt
|
5 |
+
from src.flair.functions.degradation import get_degradation
|
6 |
import torchvision
|
7 |
|
8 |
class BaseDegradation(torch.nn.Module):
|
src/flair/functions/degradation.py
CHANGED
@@ -2,9 +2,9 @@ import numpy as np
|
|
2 |
import torch
|
3 |
from munch import Munch
|
4 |
|
5 |
-
from flair.functions import svd_operators as svd_op
|
6 |
-
from flair.functions import measurements
|
7 |
-
from flair.utils.inpaint_util import MaskGenerator
|
8 |
|
9 |
__DEGRADATION__ = {}
|
10 |
|
@@ -187,7 +187,7 @@ def deg_deblur_guass_general(deg_config, device):
|
|
187 |
return A_funcs
|
188 |
|
189 |
|
190 |
-
from flair.functions.jpeg import jpeg_encode, jpeg_decode
|
191 |
|
192 |
class JPEGOperator():
|
193 |
def __init__(self, qf: int, device):
|
|
|
2 |
import torch
|
3 |
from munch import Munch
|
4 |
|
5 |
+
from src.flair.functions import svd_operators as svd_op
|
6 |
+
from src.flair.functions import measurements
|
7 |
+
from src.flair.utils.inpaint_util import MaskGenerator
|
8 |
|
9 |
__DEGRADATION__ = {}
|
10 |
|
|
|
187 |
return A_funcs
|
188 |
|
189 |
|
190 |
+
from src.flair.functions.jpeg import jpeg_encode, jpeg_decode
|
191 |
|
192 |
class JPEGOperator():
|
193 |
def __init__(self, qf: int, device):
|
src/flair/functions/measurements.py
CHANGED
@@ -6,15 +6,15 @@ from functools import partial
|
|
6 |
from torch.nn import functional as F
|
7 |
from torchvision import torch
|
8 |
|
9 |
-
from flair.utils.blur_util import Blurkernel
|
10 |
-
from flair.utils.img_util import fft2d
|
11 |
import numpy as np
|
12 |
-
from flair.utils.resizer import Resizer
|
13 |
-
from flair.utils.utils_sisr import pre_calculate_FK, pre_calculate_nonuniform
|
14 |
from torch.fft import fft2, ifft2
|
15 |
|
16 |
|
17 |
-
from flair.motionblur.motionblur import Kernel
|
18 |
|
19 |
# =================
|
20 |
# Operation classes
|
|
|
6 |
from torch.nn import functional as F
|
7 |
from torchvision import torch
|
8 |
|
9 |
+
from src.flair.utils.blur_util import Blurkernel
|
10 |
+
from src.flair.utils.img_util import fft2d
|
11 |
import numpy as np
|
12 |
+
from src.flair.utils.resizer import Resizer
|
13 |
+
from src.flair.utils.utils_sisr import pre_calculate_FK, pre_calculate_nonuniform
|
14 |
from torch.fft import fft2, ifft2
|
15 |
|
16 |
|
17 |
+
from src.flair.motionblur.motionblur import Kernel
|
18 |
|
19 |
# =================
|
20 |
# Operation classes
|
src/flair/pipelines/model_loader.py
CHANGED
@@ -10,7 +10,7 @@ from diffusers import (
|
|
10 |
from diffusers import AutoencoderTiny
|
11 |
|
12 |
|
13 |
-
from flair.pipelines import sd3
|
14 |
|
15 |
|
16 |
|
|
|
10 |
from diffusers import AutoencoderTiny
|
11 |
|
12 |
|
13 |
+
from src.flair.pipelines import sd3
|
14 |
|
15 |
|
16 |
|
src/flair/pipelines/sd3.py
CHANGED
@@ -1,7 +1,7 @@
|
|
1 |
import torch
|
2 |
from typing import Dict, Any
|
3 |
from diffusers.pipelines.stable_diffusion_3 import pipeline_stable_diffusion_3
|
4 |
-
from flair.pipelines import utils
|
5 |
import tqdm
|
6 |
|
7 |
class SD3Wrapper(pipeline_stable_diffusion_3.StableDiffusion3Pipeline):
|
|
|
1 |
import torch
|
2 |
from typing import Dict, Any
|
3 |
from diffusers.pipelines.stable_diffusion_3 import pipeline_stable_diffusion_3
|
4 |
+
from src.flair.pipelines import utils
|
5 |
import tqdm
|
6 |
|
7 |
class SD3Wrapper(pipeline_stable_diffusion_3.StableDiffusion3Pipeline):
|
src/flair/utils/blur_util.py
CHANGED
@@ -3,7 +3,7 @@ from torch import nn
|
|
3 |
import numpy as np
|
4 |
import scipy
|
5 |
|
6 |
-
from flair.utils.motionblur import Kernel as MotionKernel
|
7 |
|
8 |
class Blurkernel(nn.Module):
|
9 |
def __init__(self, blur_type='gaussian', kernel_size=31, std=3.0, device=None):
|
|
|
3 |
import numpy as np
|
4 |
import scipy
|
5 |
|
6 |
+
from src.flair.utils.motionblur import Kernel as MotionKernel
|
7 |
|
8 |
class Blurkernel(nn.Module):
|
9 |
def __init__(self, blur_type='gaussian', kernel_size=31, std=3.0, device=None):
|
src/flair/var_post_samp.py
CHANGED
@@ -3,7 +3,7 @@ import numpy as np
|
|
3 |
import sys
|
4 |
import os
|
5 |
import tqdm
|
6 |
-
from flair import degradations
|
7 |
import torchvision
|
8 |
|
9 |
def total_variation_loss(x):
|
|
|
3 |
import sys
|
4 |
import os
|
5 |
import tqdm
|
6 |
+
from src.flair import degradations
|
7 |
import torchvision
|
8 |
|
9 |
def total_variation_loss(x):
|