import gradio as gr
import spaces
from PIL import Image
import numpy as np
import torch
import torch.nn.functional as F
from utils.pipeline_utils import load_pipeline
from utils import get_args
from main import run
# Load the editing pipeline once at startup; it is shared across requests.
pipeline = load_pipeline(fp16=False, cache_dir=None)
def process_masks(masks):
    """Merge painted mask layers (given as file paths) into one binary 64x64 mask, or None if empty."""
    mask_composite = torch.zeros((512, 512), dtype=torch.float32, device='cuda')
    for mask_path in masks:
        mask = Image.open(mask_path).convert("L").resize((512, 512))
        mask = torch.tensor(np.array(mask), dtype=torch.float32, device='cuda')
        mask[mask > 0] = 255.0  # binarize: any painted pixel counts as masked
        mask = mask / 255.0
        mask_composite += mask
    mask_composite = torch.clamp(mask_composite, 0, 1)
    # Downsample to the latent resolution with nearest-neighbor so the mask stays binary.
    mask_composite = F.interpolate(mask_composite[None, None, :, :], size=(64, 64), mode="nearest")[0, 0]
    if mask_composite.sum() == 0:
        return None  # nothing painted: edit the whole image
    return mask_composite
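# Illustrative usage (the file names here are hypothetical): two painted layers
# are merged into a single latent-resolution mask.
#   mask = process_masks(["layer_0.png", "layer_1.png"])
#   # -> 64x64 float tensor on the GPU, or None if nothing was painted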
@spaces.GPU
def main_pipeline(
    input_image: dict,
    src_prompt: str,
    tgt_prompt: str,
    alpha: float,
    beta: float,
    w1: float,
    seed: int,
    dift_correction: bool = True,
):
    args = get_args()
    args.alpha = alpha
    args.beta = beta
    args.w1 = w1
    args.seed = seed
    args.structural_alignment = True
    args.support_new_object = True
    args.apply_dift_correction = dift_correction
    torch.cuda.empty_cache()
    # gr.ImageMask yields a dict: the base image under "background",
    # the painted mask layers under "layers".
    res_image = run(
        input_image['background'],
        src_prompt,
        tgt_prompt,
        masks=process_masks(input_image['layers']),
        pipeline=pipeline,
        args=args,
    )[2]  # run() returns several images; index 2 holds the edited result
    return res_image
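# Sketch of a headless call that bypasses the UI (assumes an ImageMask-style
# dict with a "background" path and a list of painted "layers"; the values
# mirror the bundled cat-to-eagle example below):
#   edited = main_pipeline(
#       {"background": "assets/cat.jpg", "layers": []},
#       "a cat", "an eagle",
#       alpha=0.7, beta=0.3, w1=1.9, seed=7, dift_correction=True,
#   )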
DESCRIPTION = """# Cora ๐ผ๏ธ๐ฑ๐ฆ
## Fast & Controllable Image Editing
### ๐ ๏ธ Quick start
1. **Upload** or drag-and-drop the image youโd like to edit.
2. **Source prompt** โ describe whatโs in the original image.
3. **Target prompt** โ describe the result you want.
4. Adjust the parameters as needed.
5. *(Optional)* Paint a mask to specify the area to edit.
6. Click **Edit** and wait a few seconds for the output.
### โ๏ธ Parameter cheat-sheet
| Parameter | What it does | `0` (minimum) | `1` (maximum) |
|-----------|--------------|---------------|---------------|
| **alpha** | Appearance transfer control | preserve source appearance | target prompt affects appearance |
| **beta** | Structural change control | preserve original structure | full layout change |
| **w** | Prompt strength | subtle tweaks | strong changes |
| **Seed** | Fixes randomness for reproducibility | โ | โ |
| **Apply correspondence correction** | Uses correspondence-aware latent fix | โ | โ |
### ๐ Tips
- To replicate **TurboEdit**, set **alpha = 1**, **beta = 1**, and turn **off** *Apply correspondence correction*.
- To test reconstruction quality of the inversion, use identical source & target prompts with **alpha = 1**, **beta = 1**, and **w = 1**.
#### ๐ Acknowledgements
The demo template is largely adapted from **[TurboEdit on Hugging Face Spaces](https://huggingface.co/spaces/turboedit/turbo_edit)**.
"""
with gr.Blocks(css="app/style.css") as demo:
    gr.HTML(
        """<a href="https://huggingface.co/spaces/armikaeili/cora?duplicate=true">
        <img src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a>Duplicate the Space to run privately without waiting in queue"""
    )
    gr.Markdown(DESCRIPTION)
    with gr.Row():
        with gr.Column():
            input_image = gr.ImageMask(
                label="Input image", type="filepath", height=512, width=512, brush=gr.Brush(color_mode='defaults')
            )
            result = gr.Image(label="Result", type="pil", height=512, width=512)
        with gr.Column():
            src_prompt = gr.Text(
                label="Source Prompt",
                max_lines=1,
                placeholder="Source Prompt",
            )
            tgt_prompt = gr.Text(
                label="Target Prompt",
                max_lines=1,
                placeholder="Target Prompt",
            )
            with gr.Accordion("Advanced Options", open=False):
                seed = gr.Slider(
                    label="seed", minimum=0, maximum=16 * 1024, value=200, step=1
                )
                w1 = gr.Slider(
                    label="w", minimum=1.0, maximum=3.0, value=1.9, step=0.05
                )
                alpha = gr.Slider(
                    label="alpha", minimum=0, maximum=1, value=0, step=0.01
                )
                beta = gr.Slider(
                    label="beta", minimum=0, maximum=1, value=0.04, step=0.01
                )
                with gr.Row():
                    dift_correction = gr.Checkbox(
                        label="Apply correspondence correction",
                        value=True,
                    )
            run_button = gr.Button("Edit")
    examples = [
        [
            "assets/white_cat.png",    # input_image
            "a cat",                   # src_prompt
            "a cat wearing a suit",    # tgt_prompt
            0.1,   # alpha
            0.1,   # beta
            1.9,   # w1
            7,     # seed
            True,  # dift_correction
        ],
        [
            "assets/bear.png",         # input_image
            "a sitting brown bear",    # src_prompt
            "a roaring blue bear",     # tgt_prompt
            0.7,   # alpha
            0.1,   # beta
            1.9,   # w1
            7,     # seed
            True,  # dift_correction
        ],
        [
            "assets/cat.jpg",          # input_image
            "a cat",                   # src_prompt
            "an eagle",                # tgt_prompt
            0.7,   # alpha
            0.3,   # beta
            1.9,   # w1
            7,     # seed
            True,  # dift_correction
        ],
        [
            "assets/dog.png",          # input_image
            "a photo of a dog",        # src_prompt
            "a photo of a dog lying",  # tgt_prompt
            0.0,   # alpha
            1.0,   # beta
            1.9,   # w1
            7,     # seed
            True,  # dift_correction
        ],
    ]
    inputs = [
        input_image,
        src_prompt,
        tgt_prompt,
        alpha,
        beta,
        w1,
        seed,
        dift_correction,
    ]
    outputs = [result]
    gr.Examples(
        examples=examples,
        inputs=inputs,
        outputs=outputs,
        fn=main_pipeline,
        cache_examples=False,
    )
    run_button.click(fn=main_pipeline, inputs=inputs, outputs=outputs)

# queue(max_size=50) bounds the number of pending requests; share=False keeps
# the app served only from this Space.
demo.queue(max_size=50).launch(share=False)