Spaces:
Running
on
Zero
Running
on
Zero
File size: 7,180 Bytes
780320d f6e8319 780320d f6e8319 780320d 87c1890 780320d d73c075 f6e8319 d73c075 f6e8319 d73c075 f6e8319 d73c075 f6e8319 d73c075 87c1890 f6e8319 780320d f6e8319 780320d 87c1890 f6e8319 780320d f6e8319 780320d f6e8319 87c1890 780320d f6e8319 780320d f6e8319 780320d 24ee135 780320d 87c1890 780320d f6e8319 24ee135 87c1890 780320d 87c1890 d73c075 87c1890 d73c075 f6e8319 87c1890 780320d 87c1890 780320d 24ee135 780320d 87c1890 780320d 87c1890 f6e8319 87c1890 f6e8319 d73c075 87c1890 f6e8319 d73c075 f6e8319 d73c075 f6e8319 d73c075 f6e8319 87c1890 f6e8319 87c1890 780320d 87c1890 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 |
import os
import shutil
import subprocess
import uuid
from pathlib import Path

import gradio as gr
import spaces
from PIL import Image
# Root folder under which run_with_upload creates one sub-folder per session
# and stores the uploaded image as "input.png".
INPUT_DIR = "samples"
# Root folder under which run_with_upload creates one output sub-folder per
# session; get_caption also looks for caption text files below this root.
OUTPUT_DIR = "inference_results/coz_vlmprompt"
def resize_and_center_crop(img: Image.Image, size: int) -> Image.Image:
    """Scale *img* so its shorter side equals *size*, then center-crop a square.

    Args:
        img: Source image.
        size: Edge length, in pixels, of the square result.

    Returns:
        A ``size`` x ``size`` image cut from the center of the resized input.
    """
    w, h = img.size
    scale = size / min(w, h)
    # int() truncates, and (size / m) * m can evaluate to slightly less than
    # `size` in floating point (e.g. a 49-pixel short side with size 512).
    # Without the clamp the resized short side would be size - 1 and the crop
    # window below would extend past the image, padding the result with black.
    new_w = max(size, int(w * scale))
    new_h = max(size, int(h * scale))
    img = img.resize((new_w, new_h), Image.LANCZOS)
    left = (new_w - size) // 2
    top = (new_h - size) // 2
    return img.crop((left, top, left + size, top + size))
def make_preview_with_boxes(image_path: str, scale_option: str) -> Image.Image:
    """Build a 512x512 preview of the image with four centered square outlines.

    Args:
        image_path: Path of the image to preview.
        scale_option: Scale label such as ``"2x"``; the ``"x"`` is removed and
            the remainder parsed as an integer.

    Returns:
        The center-cropped 512x512 preview with the outlines drawn on it. If
        the image cannot be opened, a grey 512x512 placeholder carrying the
        error text is returned instead.
    """
    # Only PIL.Image is imported at file level, so ImageDraw is pulled in here
    # once for both the fallback and the normal path.
    from PIL import ImageDraw

    try:
        # The context manager closes the underlying file once the converted
        # copy has been made.
        with Image.open(image_path) as src:
            orig = src.convert("RGB")
    except Exception as e:
        # Best effort: show the failure inside the preview instead of raising.
        fallback = Image.new("RGB", (512, 512), (200, 200, 200))
        ImageDraw.Draw(fallback).text((20, 20), f"Error:\n{e}", fill="red")
        return fallback

    base = resize_and_center_crop(orig, 512)

    scale_int = int(scale_option.replace("x", ""))
    if scale_int == 1:
        sizes = [512] * 4
    else:
        sizes = [512 // (scale_int * (2 ** i)) for i in range(4)]

    draw = ImageDraw.Draw(base)
    colors = ["red", "lime", "cyan", "yellow"]
    width = 3
    for color, s in zip(colors, sizes):
        s = max(s, 1)  # keep the box valid even if the division reached 0
        x0 = (512 - s) // 2
        y0 = (512 - s) // 2
        # rectangle() treats both corners as inclusive, so the far corner is
        # s - 1 pixels away; this also keeps the 512-pixel box on the canvas.
        x1 = x0 + s - 1
        y1 = y0 + s - 1
        draw.rectangle([(x0, y0), (x1, y1)], outline=color, width=width)
    return base
def _clear_directory(folder):
    """Delete every file, symlink and sub-directory directly inside *folder*."""
    for entry in os.listdir(folder):
        full_path = os.path.join(folder, entry)
        if os.path.isfile(full_path) or os.path.islink(full_path):
            os.remove(full_path)
        elif os.path.isdir(full_path):
            shutil.rmtree(full_path)


@spaces.GPU(duration=120)
def run_with_upload(uploaded_image_path, upscale_option, session_id=None):
    """Run the inference script on an uploaded image and return the outputs.

    Each invocation creates/uses:
      - samples/<session_id>/input.png: the uploaded image, re-saved as PNG
      - inference_results/coz_vlmprompt/<session_id>/per-sample/input/*.png:
        the files expected from the inference script

    Args:
        uploaded_image_path: Path of the uploaded image, or None.
        upscale_option: Scale label such as "2x"; the "x" is removed before
            the value is passed to the script.
        session_id: Value used (via str()) as the per-session folder name.

    Returns:
        The paths 1.png through 4.png in the session output folder, or an
        empty list if there is no upload, the image cannot be saved, the
        inference command fails, or an expected output file is missing.
    """
    if uploaded_image_path is None:
        return []

    # 1) Prepare a per-session input directory
    print(session_id)
    session_folder = os.path.join(INPUT_DIR, str(session_id))
    os.makedirs(session_folder, exist_ok=True)

    # 2) Clear only this session's folder
    _clear_directory(session_folder)

    # 3) Save uploaded image to session_folder/input.png
    try:
        # Close the uploaded file as soon as the RGB copy exists.
        with Image.open(uploaded_image_path) as src:
            pil_img = src.convert("RGB")
        save_path = Path(session_folder) / "input.png"
        pil_img.save(save_path, format="PNG")
    except Exception as e:
        print(f"Error: could not save uploaded image: {e}")
        return []

    # 4) Define a per-session output directory
    session_output_dir = os.path.join(OUTPUT_DIR, str(session_id))
    os.makedirs(session_output_dir, exist_ok=True)

    # 5) Build and run the inference command
    upscale_value = upscale_option.replace("x", "")
    cmd = [
        "python", "inference_coz.py",
        "-i", session_folder,
        "-o", session_output_dir,
        "--rec_type", "recursive_multiscale",
        "--prompt_type", "vlm",
        "--upscale", upscale_value,
        "--lora_path", "ckpt/SR_LoRA/model_20001.pkl",
        "--vae_path", "ckpt/SR_VAE/vae_encoder_20001.pt",
        "--pretrained_model_name_or_path", "stabilityai/stable-diffusion-3-medium-diffusers",
        "--ram_ft_path", "ckpt/DAPE/DAPE.pth",
        "--ram_path", "ckpt/RAM/ram_swin_large_14m.pth",
    ]
    try:
        subprocess.run(cmd, check=True)
    except (subprocess.CalledProcessError, OSError) as err:
        # OSError covers a command that cannot be started at all (for example
        # the interpreter is not found), so that case also yields an empty
        # result instead of an unhandled exception.
        print("Inference failed:", err)
        return []

    # 6) Gather output file paths (1.png through 4.png)
    per_sample_dir = os.path.join(session_output_dir, "per-sample", "input")
    expected_files = [os.path.join(per_sample_dir, f"{i}.png") for i in range(1, 5)]
    for fp in expected_files:
        if not os.path.isfile(fp):
            print(f"Warning: expected file not found: {fp}")
            return []
    return expected_files
def get_caption(src_gallery, evt: gr.SelectData, session_id=None):
    """Return the caption text stored for the gallery image that was selected.

    Args:
        src_gallery: Current gallery value; item ``evt.index`` is indexed with
            ``[0]`` to obtain the image path.
        evt: Selection event; ``evt.index`` is the position of the clicked
            image in the gallery.
        session_id: Value identifying the per-session output folder. It must
            be the same value that was given to ``run_with_upload``, which
            names its output folder ``str(session_id)``.

    Returns:
        The caption text, or a short message when no caption can be read.
    """
    if not src_gallery or not os.path.isfile(src_gallery[evt.index][0]):
        return "No caption available."

    selected_image_path = src_gallery[evt.index][0]
    base = os.path.basename(selected_image_path)  # e.g. "2.png"
    stem = os.path.splitext(base)[0]  # e.g. "2"
    try:
        caption_index = int(stem) - 1  # image N.png pairs with (N-1).txt
    except ValueError:
        # A non-numeric file name cannot be mapped to a caption file.
        return "No caption available."

    # The folder is keyed by the session, not by the clicked gallery position:
    # run_with_upload writes its outputs below OUTPUT_DIR/<session_id>.
    txt_folder = os.path.join(OUTPUT_DIR, str(session_id), "per-sample", "input", "txt")
    txt_path = os.path.join(txt_folder, f"{caption_index}.txt")

    if not os.path.isfile(txt_path):
        return f"Caption file not found: {caption_index}.txt"
    try:
        with open(txt_path, "r", encoding="utf-8") as f:
            caption = f.read().strip()
        return caption if caption else "(Caption file is empty.)"
    except Exception as e:
        return f"Error reading caption: {e}"


def _new_session_id():
    """Return a fresh random identifier used as the per-session folder name."""
    return uuid.uuid4().hex


def _update_preview(img_path, scale_opt):
    """Return the boxed preview for the current upload, or None without one."""
    if img_path is None:
        return None
    return make_preview_with_boxes(img_path, scale_opt)


css = """
#col-container {
margin: 0 auto;
max-width: 1024px;
}
"""

with gr.Blocks(css=css) as demo:
    # Holds this browser session's identifier. A bare gr.State() starts as
    # None for everyone, which would make all visitors share one folder, so
    # it is filled with a unique value when the page loads (see demo.load).
    session_state = gr.State()

    gr.HTML(
        """
<div style="text-align: center;">
<h1>Chain-of-Zoom</h1>
<p style="font-size:16px;">Extreme Super-Resolution via Scale Autoregression and Preference Alignment</p>
</div>
<br>
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
<a href="https://github.com/bryanswkim/Chain-of-Zoom">
<img src='https://img.shields.io/badge/GitHub-Repo-blue'>
</a>
</div>
"""
    )

    with gr.Column(elem_id="col-container"):
        with gr.Row():
            with gr.Column():
                upload_image = gr.Image(label="Upload your input image", type="filepath")
                upscale_radio = gr.Radio(choices=["1x", "2x", "4x"], value="2x", show_label=False)
                run_button = gr.Button("Chain-of-Zoom it")
                preview_with_box = gr.Image(label="Preview (512×512 with centered boxes)", type="pil", interactive=False)
            with gr.Column():
                output_gallery = gr.Gallery(label="Inference Results", show_label=True, columns=[2], rows=[2])
                caption_text = gr.Textbox(label="Caption", lines=4, placeholder="Click on any image above to see its caption here.")

    # Assign a unique session id each time the page is loaded.
    demo.load(fn=_new_session_id, inputs=None, outputs=[session_state])

    # Refresh the preview whenever the upload or the scale choice changes.
    upload_image.change(
        fn=_update_preview,
        inputs=[upload_image, upscale_radio],
        outputs=[preview_with_box],
    )
    upscale_radio.change(
        fn=_update_preview,
        inputs=[upload_image, upscale_radio],
        outputs=[preview_with_box],
    )

    # The same session state feeds both handlers so that captions are read
    # from the folder the inference run wrote to.
    run_button.click(
        fn=run_with_upload,
        inputs=[upload_image, upscale_radio, session_state],
        outputs=[output_gallery],
    )
    output_gallery.select(
        fn=get_caption,
        inputs=[output_gallery, session_state],
        outputs=[caption_text],
    )

demo.launch(share=True)
|