import gradio as gr
import subprocess
import os
import shutil
from pathlib import Path
from PIL import Image
import spaces
# -----------------------------------------------------------------------------
# CONFIGURE THESE PATHS TO MATCH YOUR PROJECT STRUCTURE
# -----------------------------------------------------------------------------
INPUT_DIR = "samples"
OUTPUT_DIR = "inference_results/coz_vlmprompt"
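
# Where inference_coz.py drops its per-step outputs. This layout
# (per-sample/input/1.png … 4.png) is an assumption based on the output
# structure the code below expects; adjust it if your checkout of the
# repo writes results elsewhere.
PER_SAMPLE_DIR = os.path.join(OUTPUT_DIR, "per-sample", "input")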
# -----------------------------------------------------------------------------
# HELPER FUNCTION TO RUN INFERENCE AND RETURN THE OUTPUT IMAGES
# -----------------------------------------------------------------------------
@spaces.GPU  # allocate ZeroGPU hardware for the duration of this call
def run_with_upload(uploaded_image_path, upscale_option):
""" | |
1) Clear INPUT_DIR | |
2) Save the uploaded file as input.png in INPUT_DIR | |
3) Read `upscale_option` (e.g. "1x", "2x", "4x") → turn it into "1", "2", or "4" | |
4) Call inference_coz.py with `--upscale <that_value>` | |
5) (Here we assume you still stitch together 1.png–4.png, or however you want.) | |
""" | |
    # Nothing to do if the user has not uploaded anything yet.
    if uploaded_image_path is None:
        return None

    # 1) Make sure INPUT_DIR exists; if it does, delete everything inside.
    os.makedirs(INPUT_DIR, exist_ok=True)
    for fn in os.listdir(INPUT_DIR):
        full_path = os.path.join(INPUT_DIR, fn)
        try:
            if os.path.isfile(full_path) or os.path.islink(full_path):
                os.remove(full_path)
            elif os.path.isdir(full_path):
                shutil.rmtree(full_path)
        except Exception as e:
            print(f"Warning: could not delete {full_path}: {e}")

    # 2) Copy the uploaded image into INPUT_DIR.
    #    Gradio hands us a temp-file path like "/tmp/gradio_xyz.png".
    try:
        # Open with PIL (this handles JPEG, BMP, TIFF, etc.)
        pil_img = Image.open(uploaded_image_path).convert("RGB")
    except Exception as e:
        print(f"Error: could not open uploaded image: {e}")
        return None

    # Save it as "input.png" in our INPUT_DIR.
    save_path = Path(INPUT_DIR) / "input.png"
    try:
        pil_img.save(save_path, format="PNG")
    except Exception as e:
        print(f"Error: could not save as PNG: {e}")
        return None
    # 3) Build and run the inference_coz.py command.
    #    subprocess.run(...) blocks until it completes.
    upscale_value = upscale_option.replace("x", "")  # e.g. "2x" → "2"
    cmd = [
        "python", "inference_coz.py",
        "-i", INPUT_DIR,
        "-o", OUTPUT_DIR,
        "--rec_type", "recursive_multiscale",
        "--prompt_type", "vlm",
        "--upscale", upscale_value,
        "--lora_path", "ckpt/SR_LoRA/model_20001.pkl",
        "--vae_path", "ckpt/SR_VAE/vae_encoder_20001.pt",
        "--pretrained_model_name_or_path", "stabilityai/stable-diffusion-3-medium-diffusers",
        "--ram_ft_path", "ckpt/DAPE/DAPE.pth",
        "--ram_path", "ckpt/RAM/ram_swin_large_14m.pth",
    ]
    try:
        subprocess.run(cmd, check=True)
    except subprocess.CalledProcessError as err:
        # If inference_coz.py crashes, log the error and bail out.
        print("Inference failed:", err)
        return None
    # -------------------------------------------------------------------------
    # 4) After inference, collect the four numbered PNGs and return them
    # -------------------------------------------------------------------------
    expected_files = [os.path.join(PER_SAMPLE_DIR, f"{i}.png") for i in range(1, 5)]
    pil_images = []
    for fp in expected_files:
        if not os.path.isfile(fp):
            print(f"Warning: expected file not found: {fp}")
            return None
        try:
            img = Image.open(fp).convert("RGB")
            pil_images.append(img)
        except Exception as e:
            print(f"Error opening {fp}: {e}")
            return None
    if len(pil_images) != 4:
        print(f"Error: found {len(pil_images)} images, but need 4.")
        return None
    # Return the four zoom steps as a list for the 2x2 Gallery below.
    # (See stitch_2x2() below if you would rather return one composite image.)
    return pil_images
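
# -----------------------------------------------------------------------------
# OPTIONAL: STITCH THE FOUR OUTPUTS INTO A SINGLE 2x2 COMPOSITE
# -----------------------------------------------------------------------------
# A minimal sketch in case you would rather show one composite image than a
# gallery. It assumes all four outputs share the same size. Not wired into the
# UI: to use it, return stitch_2x2(pil_images) from run_with_upload and swap
# the Gallery below for a gr.Image.
def stitch_2x2(images):
    """Paste four equally sized PIL images into one 2x2 grid."""
    w, h = images[0].size
    composite = Image.new("RGB", (w * 2, h * 2))
    composite.paste(images[0], (0, 0))  # top-left
    composite.paste(images[1], (w, 0))  # top-right
    composite.paste(images[2], (0, h))  # bottom-left
    composite.paste(images[3], (w, h))  # bottom-right
    return composite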
# -----------------------------------------------------------------------------
# BUILD THE GRADIO INTERFACE
# -----------------------------------------------------------------------------
css = """
#col-container {
    margin: 0 auto;
    max-width: 1024px;
}
"""
with gr.Blocks(css=css) as demo:
    gr.HTML(
        """
        <div style="text-align: center;">
            <h1>Chain-of-Zoom</h1>
            <p style="font-size:16px;">Extreme Super-Resolution via Scale Autoregression and Preference Alignment</p>
        </div>
        <br>
        <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
            <a href="https://github.com/bryanswkim/Chain-of-Zoom">
                <img src='https://img.shields.io/badge/GitHub-Repo-blue'>
            </a>
        </div>
        """
    )
    with gr.Column(elem_id="col-container"):
        with gr.Row():
            with gr.Column():
                # 1) Image upload component. type="filepath" means the callback
                #    (run_with_upload) receives a local path to the uploaded file.
                upload_image = gr.Image(
                    label="Upload your input image",
                    type="filepath"
                )

                # 2) Radio for choosing 1× / 2× / 4× upscaling.
                upscale_radio = gr.Radio(
                    choices=["1x", "2x", "4x"],
                    value="2x",
                    show_label=False
                )
                # 3) A button that the user clicks to launch inference.
                run_button = gr.Button("Chain-of-Zoom it")

            # 4) Gallery to display the four output images in a 2x2 grid.
            output_gallery = gr.Gallery(
                label="Inference Results",
                show_label=True,
                elem_id="gallery",
                columns=[2], rows=[2]
            )
    # Wire the button: when clicked, call run_with_upload(upload_image,
    # upscale_radio) and send its return value to output_gallery.
    run_button.click(
        fn=run_with_upload,
        inputs=[upload_image, upscale_radio],
        outputs=output_gallery
    )
# -----------------------------------------------------------------------------
# START THE GRADIO SERVER
# -----------------------------------------------------------------------------
demo.launch(share=True)