Spaces:
Running
on
Zero
Running
on
Zero
import flair fix
Browse files- app.py +60 -43
- demo_images/demo_2_meta.json +1 -1
- requirements.txt +0 -1
- src/flair/degradations.py +1 -1
- src/flair/functions/degradation.py +4 -4
- src/flair/functions/measurements.py +5 -5
- src/flair/pipelines/model_loader.py +1 -1
- src/flair/pipelines/sd3.py +1 -1
- src/flair/utils/blur_util.py +1 -1
- src/flair/var_post_samp.py +1 -1
app.py
CHANGED
|
@@ -14,8 +14,8 @@ project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "."))
|
|
| 14 |
if project_root not in sys.path:
|
| 15 |
sys.path.insert(0, project_root)
|
| 16 |
|
| 17 |
-
from flair.pipelines import model_loader
|
| 18 |
-
from flair import var_post_samp, degradations
|
| 19 |
|
| 20 |
|
| 21 |
|
|
@@ -283,7 +283,7 @@ def inpaint_image(image_editor_output, prompt_text, fixed_seed_value, use_random
|
|
| 283 |
|
| 284 |
current_seed = None
|
| 285 |
if use_random_seed:
|
| 286 |
-
current_seed =
|
| 287 |
else:
|
| 288 |
try:
|
| 289 |
current_seed = int(fixed_seed_value)
|
|
@@ -514,15 +514,16 @@ if os.path.exists(demo_images_dir):
|
|
| 514 |
metadata = json.load(f)
|
| 515 |
task = metadata.get("task_type")
|
| 516 |
prompt = metadata.get("prompt", "")
|
|
|
|
| 517 |
if task == "Super Resolution":
|
| 518 |
-
example_list_sr.append([image_path, prompt, task])
|
| 519 |
else:
|
| 520 |
image_editor_input = {
|
| 521 |
"background": image_path,
|
| 522 |
"layers": [mask_path],
|
| 523 |
"composite": None # Add this key to satisfy ImageEditor's as_example processing
|
| 524 |
}
|
| 525 |
-
example_list_inp.append([image_editor_input, prompt, task])
|
| 526 |
|
| 527 |
# Structure for ImageEditor: { "background": filepath, "layers": [filepath], "composite": None }
|
| 528 |
|
|
@@ -553,10 +554,14 @@ if __name__ == "__main__":
|
|
| 553 |
sys.exit(1)
|
| 554 |
|
| 555 |
# --- Define Gradio UI using gr.Blocks after globals are initialized ---
|
| 556 |
-
title_str = "Solving Inverse Problems with FLAIR
|
| 557 |
description_str = """
|
| 558 |
Select a task (Inpainting or Super Resolution) and upload an image.
|
| 559 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 560 |
For Super Resolution, upload a low-resolution image and select the upscaling factor. Images are always upscaled to 768x768 pixels. Therefore, for x12 superresolution, the input image must be 64x64 pixels. You can also upload a high resolution image which will be downscaled to the correct input size.
|
| 561 |
Use the slider to compare the low resolution input image with the super-resolved output.
|
| 562 |
|
|
@@ -729,43 +734,55 @@ Use the slider to compare the low resolution input image with the super-resolved
|
|
| 729 |
|
| 730 |
gr.Markdown("---") # Separator
|
| 731 |
gr.Markdown("### Click an example to load:")
|
| 732 |
-
def load_example(example_data, prompt, task):
|
| 733 |
-
image_editor_input = example_data[0]
|
| 734 |
-
prompt_value = example_data[1]
|
| 735 |
-
if task == "Inpainting":
|
| 736 |
-
image_editor.clear() # Clear current image and mask
|
| 737 |
-
if image_editor_input and image_editor_input.get("background"):
|
| 738 |
-
image_editor.upload_image(image_editor_input["background"])
|
| 739 |
-
if image_editor_input and image_editor_input.get("layers"):
|
| 740 |
-
for layer in image_editor_input["layers"]:
|
| 741 |
-
image_editor.upload_mask(layer)
|
| 742 |
-
elif task == "Super Resolution":
|
| 743 |
-
image_input.clear()
|
| 744 |
-
image_input.upload_image(image_editor_input)
|
| 745 |
-
|
| 746 |
-
# Set the prompt
|
| 747 |
-
prompt_text.value = prompt_value
|
| 748 |
-
# Optionally, set a random seed and guidance scale
|
| 749 |
-
seed_slider.value = random.randint(0, 2**32 - 1)
|
| 750 |
-
guidance_scale_slider.value = default_guidance_scale
|
| 751 |
-
# Set the task selector from the example
|
| 752 |
-
task_selector.set_value(task)
|
| 753 |
-
update_visibility(task) # Update visibility based on task
|
| 754 |
|
| 755 |
-
|
| 756 |
-
|
| 757 |
-
|
| 758 |
-
|
| 759 |
-
|
| 760 |
-
|
| 761 |
-
|
| 762 |
-
|
| 763 |
-
|
| 764 |
-
|
| 765 |
-
|
| 766 |
-
|
| 767 |
-
|
| 768 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 769 |
|
| 770 |
# --- End of Gradio UI definition ---
|
| 771 |
|
|
|
|
| 14 |
if project_root not in sys.path:
|
| 15 |
sys.path.insert(0, project_root)
|
| 16 |
|
| 17 |
+
from src.flair.pipelines import model_loader
|
| 18 |
+
from src.flair import var_post_samp, degradations
|
| 19 |
|
| 20 |
|
| 21 |
|
|
|
|
| 283 |
|
| 284 |
current_seed = None
|
| 285 |
if use_random_seed:
|
| 286 |
+
current_seed = random.randint(0, 2**32 - 1)
|
| 287 |
else:
|
| 288 |
try:
|
| 289 |
current_seed = int(fixed_seed_value)
|
|
|
|
| 514 |
metadata = json.load(f)
|
| 515 |
task = metadata.get("task_type")
|
| 516 |
prompt = metadata.get("prompt", "")
|
| 517 |
+
n_steps = metadata.get("num_steps", 50)
|
| 518 |
if task == "Super Resolution":
|
| 519 |
+
example_list_sr.append([image_path, prompt, task, n_steps])
|
| 520 |
else:
|
| 521 |
image_editor_input = {
|
| 522 |
"background": image_path,
|
| 523 |
"layers": [mask_path],
|
| 524 |
"composite": None # Add this key to satisfy ImageEditor's as_example processing
|
| 525 |
}
|
| 526 |
+
example_list_inp.append([image_editor_input, prompt, task, n_steps])
|
| 527 |
|
| 528 |
# Structure for ImageEditor: { "background": filepath, "layers": [filepath], "composite": None }
|
| 529 |
|
|
|
|
| 554 |
sys.exit(1)
|
| 555 |
|
| 556 |
# --- Define Gradio UI using gr.Blocks after globals are initialized ---
|
| 557 |
+
title_str = "Solving Inverse Problems with FLAIR"
|
| 558 |
description_str = """
|
| 559 |
Select a task (Inpainting or Super Resolution) and upload an image.
|
| 560 |
+
|
| 561 |
+
For Inpainting, draw a mask on the image to specify the area to be filled.
|
| 562 |
+
|
| 563 |
+
We observed that our model can event solve simple editing task, if provided with an appropriate prompt. For large masks the step size might need to be adjusted to e.g. 80.
|
| 564 |
+
|
| 565 |
For Super Resolution, upload a low-resolution image and select the upscaling factor. Images are always upscaled to 768x768 pixels. Therefore, for x12 superresolution, the input image must be 64x64 pixels. You can also upload a high resolution image which will be downscaled to the correct input size.
|
| 566 |
Use the slider to compare the low resolution input image with the super-resolved output.
|
| 567 |
|
|
|
|
| 734 |
|
| 735 |
gr.Markdown("---") # Separator
|
| 736 |
gr.Markdown("### Click an example to load:")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 737 |
|
| 738 |
+
# --- GALLERY FOR SUPER RESOLUTION EXAMPLES ---
|
| 739 |
+
sr_gallery_items = [[ex[0], f"Prompt: {ex[1]} Steps: {ex[3]}"] for ex in example_list_sr]
|
| 740 |
+
sr_gallery = gr.Gallery(
|
| 741 |
+
value=sr_gallery_items,
|
| 742 |
+
label="Super Resolution Examples",
|
| 743 |
+
columns=4,
|
| 744 |
+
height="auto",
|
| 745 |
+
visible=True
|
| 746 |
+
)
|
| 747 |
+
|
| 748 |
+
# --- GALLERY FOR INPAINTING EXAMPLES ---
|
| 749 |
+
inp_gallery_items = [[ex[0]["background"], f"Prompt: {ex[1]} Steps: {ex[3]}"] for ex in example_list_inp]
|
| 750 |
+
inp_gallery = gr.Gallery(
|
| 751 |
+
value=inp_gallery_items,
|
| 752 |
+
label="Inpainting Examples",
|
| 753 |
+
columns=4,
|
| 754 |
+
height="auto",
|
| 755 |
+
visible=True
|
| 756 |
+
)
|
| 757 |
+
|
| 758 |
+
def on_sr_gallery_select(evt: gr.SelectData):
|
| 759 |
+
idx = evt.index
|
| 760 |
+
ex = example_list_sr[idx]
|
| 761 |
+
image_input.value = ex[0]
|
| 762 |
+
prompt_text.value = ex[1]
|
| 763 |
+
task_selector.value = ex[2]
|
| 764 |
+
num_steps_slider.value = ex[3]
|
| 765 |
+
update_visibility(ex[2])
|
| 766 |
+
return [image_input, prompt_text, task_selector, num_steps_slider]
|
| 767 |
+
|
| 768 |
+
def on_inp_gallery_select(evt: gr.SelectData):
|
| 769 |
+
idx = evt.index
|
| 770 |
+
ex = example_list_inp[idx]
|
| 771 |
+
image_editor.value = ex[0]
|
| 772 |
+
prompt_text.value = ex[1]
|
| 773 |
+
task_selector.value = ex[2]
|
| 774 |
+
num_steps_slider.value = ex[3]
|
| 775 |
+
update_visibility(ex[2])
|
| 776 |
+
return [image_editor, prompt_text, task_selector, num_steps_slider]
|
| 777 |
+
|
| 778 |
+
sr_gallery.select(
|
| 779 |
+
fn=on_sr_gallery_select,
|
| 780 |
+
outputs=[image_input, prompt_text, task_selector, num_steps_slider]
|
| 781 |
+
)
|
| 782 |
+
inp_gallery.select(
|
| 783 |
+
fn=on_inp_gallery_select,
|
| 784 |
+
outputs=[image_editor, prompt_text, task_selector, num_steps_slider]
|
| 785 |
+
)
|
| 786 |
|
| 787 |
# --- End of Gradio UI definition ---
|
| 788 |
|
demo_images/demo_2_meta.json
CHANGED
|
@@ -2,6 +2,6 @@
|
|
| 2 |
"prompt": "a high quality image of a face.",
|
| 3 |
"seed_on_slider": 3211750901,
|
| 4 |
"use_random_seed_checkbox": false,
|
| 5 |
-
"num_steps":
|
| 6 |
"task_type": "Inpainting"
|
| 7 |
}
|
|
|
|
| 2 |
"prompt": "a high quality image of a face.",
|
| 3 |
"seed_on_slider": 3211750901,
|
| 4 |
"use_random_seed_checkbox": false,
|
| 5 |
+
"num_steps": 75,
|
| 6 |
"task_type": "Inpainting"
|
| 7 |
}
|
requirements.txt
CHANGED
|
@@ -16,4 +16,3 @@ sentencepiece
|
|
| 16 |
protobuf
|
| 17 |
accelerate
|
| 18 |
gradio
|
| 19 |
-
gradio_imageslider
|
|
|
|
| 16 |
protobuf
|
| 17 |
accelerate
|
| 18 |
gradio
|
|
|
src/flair/degradations.py
CHANGED
|
@@ -2,7 +2,7 @@ import torch
|
|
| 2 |
import numpy as np
|
| 3 |
from munch import munchify
|
| 4 |
from scipy.ndimage import distance_transform_edt
|
| 5 |
-
from flair.functions.degradation import get_degradation
|
| 6 |
import torchvision
|
| 7 |
|
| 8 |
class BaseDegradation(torch.nn.Module):
|
|
|
|
| 2 |
import numpy as np
|
| 3 |
from munch import munchify
|
| 4 |
from scipy.ndimage import distance_transform_edt
|
| 5 |
+
from src.flair.functions.degradation import get_degradation
|
| 6 |
import torchvision
|
| 7 |
|
| 8 |
class BaseDegradation(torch.nn.Module):
|
src/flair/functions/degradation.py
CHANGED
|
@@ -2,9 +2,9 @@ import numpy as np
|
|
| 2 |
import torch
|
| 3 |
from munch import Munch
|
| 4 |
|
| 5 |
-
from flair.functions import svd_operators as svd_op
|
| 6 |
-
from flair.functions import measurements
|
| 7 |
-
from flair.utils.inpaint_util import MaskGenerator
|
| 8 |
|
| 9 |
__DEGRADATION__ = {}
|
| 10 |
|
|
@@ -187,7 +187,7 @@ def deg_deblur_guass_general(deg_config, device):
|
|
| 187 |
return A_funcs
|
| 188 |
|
| 189 |
|
| 190 |
-
from flair.functions.jpeg import jpeg_encode, jpeg_decode
|
| 191 |
|
| 192 |
class JPEGOperator():
|
| 193 |
def __init__(self, qf: int, device):
|
|
|
|
| 2 |
import torch
|
| 3 |
from munch import Munch
|
| 4 |
|
| 5 |
+
from src.flair.functions import svd_operators as svd_op
|
| 6 |
+
from src.flair.functions import measurements
|
| 7 |
+
from src.flair.utils.inpaint_util import MaskGenerator
|
| 8 |
|
| 9 |
__DEGRADATION__ = {}
|
| 10 |
|
|
|
|
| 187 |
return A_funcs
|
| 188 |
|
| 189 |
|
| 190 |
+
from src.flair.functions.jpeg import jpeg_encode, jpeg_decode
|
| 191 |
|
| 192 |
class JPEGOperator():
|
| 193 |
def __init__(self, qf: int, device):
|
src/flair/functions/measurements.py
CHANGED
|
@@ -6,15 +6,15 @@ from functools import partial
|
|
| 6 |
from torch.nn import functional as F
|
| 7 |
from torchvision import torch
|
| 8 |
|
| 9 |
-
from flair.utils.blur_util import Blurkernel
|
| 10 |
-
from flair.utils.img_util import fft2d
|
| 11 |
import numpy as np
|
| 12 |
-
from flair.utils.resizer import Resizer
|
| 13 |
-
from flair.utils.utils_sisr import pre_calculate_FK, pre_calculate_nonuniform
|
| 14 |
from torch.fft import fft2, ifft2
|
| 15 |
|
| 16 |
|
| 17 |
-
from flair.motionblur.motionblur import Kernel
|
| 18 |
|
| 19 |
# =================
|
| 20 |
# Operation classes
|
|
|
|
| 6 |
from torch.nn import functional as F
|
| 7 |
from torchvision import torch
|
| 8 |
|
| 9 |
+
from src.flair.utils.blur_util import Blurkernel
|
| 10 |
+
from src.flair.utils.img_util import fft2d
|
| 11 |
import numpy as np
|
| 12 |
+
from src.flair.utils.resizer import Resizer
|
| 13 |
+
from src.flair.utils.utils_sisr import pre_calculate_FK, pre_calculate_nonuniform
|
| 14 |
from torch.fft import fft2, ifft2
|
| 15 |
|
| 16 |
|
| 17 |
+
from src.flair.motionblur.motionblur import Kernel
|
| 18 |
|
| 19 |
# =================
|
| 20 |
# Operation classes
|
src/flair/pipelines/model_loader.py
CHANGED
|
@@ -10,7 +10,7 @@ from diffusers import (
|
|
| 10 |
from diffusers import AutoencoderTiny
|
| 11 |
|
| 12 |
|
| 13 |
-
from flair.pipelines import sd3
|
| 14 |
|
| 15 |
|
| 16 |
|
|
|
|
| 10 |
from diffusers import AutoencoderTiny
|
| 11 |
|
| 12 |
|
| 13 |
+
from src.flair.pipelines import sd3
|
| 14 |
|
| 15 |
|
| 16 |
|
src/flair/pipelines/sd3.py
CHANGED
|
@@ -1,7 +1,7 @@
|
|
| 1 |
import torch
|
| 2 |
from typing import Dict, Any
|
| 3 |
from diffusers.pipelines.stable_diffusion_3 import pipeline_stable_diffusion_3
|
| 4 |
-
from flair.pipelines import utils
|
| 5 |
import tqdm
|
| 6 |
|
| 7 |
class SD3Wrapper(pipeline_stable_diffusion_3.StableDiffusion3Pipeline):
|
|
|
|
| 1 |
import torch
|
| 2 |
from typing import Dict, Any
|
| 3 |
from diffusers.pipelines.stable_diffusion_3 import pipeline_stable_diffusion_3
|
| 4 |
+
from src.flair.pipelines import utils
|
| 5 |
import tqdm
|
| 6 |
|
| 7 |
class SD3Wrapper(pipeline_stable_diffusion_3.StableDiffusion3Pipeline):
|
src/flair/utils/blur_util.py
CHANGED
|
@@ -3,7 +3,7 @@ from torch import nn
|
|
| 3 |
import numpy as np
|
| 4 |
import scipy
|
| 5 |
|
| 6 |
-
from flair.utils.motionblur import Kernel as MotionKernel
|
| 7 |
|
| 8 |
class Blurkernel(nn.Module):
|
| 9 |
def __init__(self, blur_type='gaussian', kernel_size=31, std=3.0, device=None):
|
|
|
|
| 3 |
import numpy as np
|
| 4 |
import scipy
|
| 5 |
|
| 6 |
+
from src.flair.utils.motionblur import Kernel as MotionKernel
|
| 7 |
|
| 8 |
class Blurkernel(nn.Module):
|
| 9 |
def __init__(self, blur_type='gaussian', kernel_size=31, std=3.0, device=None):
|
src/flair/var_post_samp.py
CHANGED
|
@@ -3,7 +3,7 @@ import numpy as np
|
|
| 3 |
import sys
|
| 4 |
import os
|
| 5 |
import tqdm
|
| 6 |
-
from flair import degradations
|
| 7 |
import torchvision
|
| 8 |
|
| 9 |
def total_variation_loss(x):
|
|
|
|
| 3 |
import sys
|
| 4 |
import os
|
| 5 |
import tqdm
|
| 6 |
+
from src.flair import degradations
|
| 7 |
import torchvision
|
| 8 |
|
| 9 |
def total_variation_loss(x):
|