File size: 7,180 Bytes
780320d
 
f6e8319
780320d
f6e8319
 
780320d
 
87c1890
 
780320d
d73c075
 
 
 
 
 
 
 
 
 
 
 
 
 
f6e8319
d73c075
 
 
f6e8319
 
 
 
 
d73c075
 
f6e8319
d73c075
 
 
 
 
f6e8319
d73c075
 
87c1890
f6e8319
780320d
f6e8319
 
 
780320d
 
87c1890
f6e8319
 
 
 
 
 
 
 
 
 
 
 
 
 
780320d
 
f6e8319
780320d
 
f6e8319
87c1890
780320d
f6e8319
 
 
 
 
 
780320d
 
f6e8319
 
780320d
 
24ee135
780320d
 
 
 
 
 
 
 
 
 
87c1890
780320d
f6e8319
 
 
24ee135
 
 
87c1890
 
780320d
87c1890
d73c075
87c1890
d73c075
 
 
f6e8319
87c1890
 
 
 
 
 
 
 
 
780320d
87c1890
780320d
 
24ee135
780320d
 
 
 
 
 
87c1890
 
 
 
 
 
 
 
 
 
 
780320d
 
 
87c1890
 
f6e8319
 
87c1890
f6e8319
d73c075
87c1890
f6e8319
 
d73c075
 
f6e8319
d73c075
 
 
 
f6e8319
d73c075
 
 
 
f6e8319
87c1890
 
f6e8319
87c1890
 
 
 
 
 
 
 
780320d
87c1890
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
import os
import shutil
import subprocess
import sys
from pathlib import Path

from PIL import Image
import gradio as gr
import spaces

# Per-session uploads are staged under samples/<session_id>/ (see run_with_upload).
INPUT_DIR   = "samples"
# inference_coz.py writes its results under <OUTPUT_DIR>/<session_id>/per-sample/input/
OUTPUT_DIR  = "inference_results/coz_vlmprompt"

def resize_and_center_crop(img: Image.Image, size: int) -> Image.Image:
    """Scale *img* so its shorter side equals *size*, then crop the centered size×size square."""
    orig_w, orig_h = img.size
    factor = size / min(orig_w, orig_h)
    scaled = img.resize((int(orig_w * factor), int(orig_h * factor)), Image.LANCZOS)
    # Center the crop window inside the scaled image.
    off_x = (scaled.width - size) // 2
    off_y = (scaled.height - size) // 2
    return scaled.crop((off_x, off_y, off_x + size, off_y + size))

def make_preview_with_boxes(image_path: str, scale_option: str) -> Image.Image:
    """Render a 512×512 preview with centered, color-coded boxes marking the
    region zoomed at each of the four recursive steps.

    Parameters
    ----------
    image_path : str
        Path to the uploaded image file.
    scale_option : str
        Radio value such as "1x", "2x", "4x" (the trailing "x" is stripped).

    Returns
    -------
    PIL.Image.Image
        The annotated preview; on load failure, a gray placeholder showing
        the error message instead of raising.
    """
    # Single local import (the original imported ImageDraw twice in this body).
    from PIL import ImageDraw

    try:
        orig = Image.open(image_path).convert("RGB")
    except Exception as e:
        # Upload could not be read: return a gray canvas describing the error.
        fallback = Image.new("RGB", (512, 512), (200, 200, 200))
        ImageDraw.Draw(fallback).text((20, 20), f"Error:\n{e}", fill="red")
        return fallback

    base = resize_and_center_crop(orig, 512)
    scale_int = int(scale_option.replace("x", ""))
    if scale_int == 1:
        # 1x never shrinks the crop: all four boxes cover the full frame.
        sizes = [512] * 4
    else:
        # Box i marks the region inspected at recursion depth i.
        sizes = [512 // (scale_int * (2 ** i)) for i in range(4)]

    draw = ImageDraw.Draw(base)
    colors = ["red", "lime", "cyan", "yellow"]
    for color, s in zip(colors, sizes):
        x0 = (512 - s) // 2
        y0 = (512 - s) // 2
        draw.rectangle([(x0, y0), (x0 + s, y0 + s)], outline=color, width=3)
    return base

@spaces.GPU(duration=120)
def run_with_upload(uploaded_image_path, upscale_option, session_id=None):
    """Run Chain-of-Zoom inference on one uploaded image.

    Per-invocation layout:
      - samples/<session_id>/input.png                                  ← user's uploaded image
      - inference_results/coz_vlmprompt/<session_id>/per-sample/input/  ← inference outputs

    Parameters
    ----------
    uploaded_image_path : str | None
        Filepath from the gr.Image component; None when nothing is uploaded.
    upscale_option : str
        "1x" | "2x" | "4x" radio value; the "x" suffix is stripped for the CLI.
    session_id :
        Key isolating each session's folders.
        NOTE(review): the UI wires this from a bare gr.State(), which holds
        None, so all visitors currently share the "None" folder — confirm.

    Returns
    -------
    list[str]
        Paths to the four output images (1.png … 4.png) for the gallery,
        or [] on any failure.
    """
    if uploaded_image_path is None:
        return []

    # 1) Prepare a per-session input directory
    print(f"[run_with_upload] session_id={session_id}")
    session_folder = os.path.join(INPUT_DIR, str(session_id))
    os.makedirs(session_folder, exist_ok=True)

    # 2) Clear only this session's folder so stale inputs never leak into a run
    for fn in os.listdir(session_folder):
        full_path = os.path.join(session_folder, fn)
        if os.path.isfile(full_path) or os.path.islink(full_path):
            os.remove(full_path)
        elif os.path.isdir(full_path):
            shutil.rmtree(full_path)

    # 3) Normalize the upload to PNG at session_folder/input.png
    try:
        pil_img = Image.open(uploaded_image_path).convert("RGB")
        save_path = Path(session_folder) / "input.png"
        pil_img.save(save_path, format="PNG")
    except Exception as e:
        print(f"Error: could not save uploaded image: {e}")
        return []

    # 4) Define a per-session output directory
    session_output_dir = os.path.join(OUTPUT_DIR, str(session_id))
    os.makedirs(session_output_dir, exist_ok=True)

    # 5) Build and run the inference command. sys.executable guarantees the
    #    same interpreter as this process, unlike a bare "python" on PATH.
    upscale_value = upscale_option.replace("x", "")
    cmd = [
        sys.executable, "inference_coz.py",
        "-i", session_folder,
        "-o", session_output_dir,
        "--rec_type", "recursive_multiscale",
        "--prompt_type", "vlm",
        "--upscale", upscale_value,
        "--lora_path", "ckpt/SR_LoRA/model_20001.pkl",
        "--vae_path", "ckpt/SR_VAE/vae_encoder_20001.pt",
        "--pretrained_model_name_or_path", "stabilityai/stable-diffusion-3-medium-diffusers",
        "--ram_ft_path", "ckpt/DAPE/DAPE.pth",
        "--ram_path", "ckpt/RAM/ram_swin_large_14m.pth"
    ]
    try:
        subprocess.run(cmd, check=True)
    except subprocess.CalledProcessError as err:
        print("Inference failed:", err)
        return []

    # 6) Gather output file paths (1.png through 4.png); all four must exist.
    per_sample_dir = os.path.join(session_output_dir, "per-sample", "input")
    expected_files = [os.path.join(per_sample_dir, f"{i}.png") for i in range(1, 5)]
    for fp in expected_files:
        if not os.path.isfile(fp):
            print(f"Warning: expected file not found: {fp}")
            return []
    return expected_files

def get_caption(src_gallery, evt: gr.SelectData):
    """Return the caption text for the gallery image the user clicked.

    Fix: the caption folder was previously built as
    OUTPUT_DIR/<evt.index>/per-sample/input/txt — but evt.index is the
    gallery position (0–3), not the session id the outputs were written
    under. The txt folder is now derived from the selected image's own
    directory, which is correct for any session.
    NOTE(review): if Gradio serves cached copies of gallery files from a
    temp dir, the sibling "txt" folder won't be there — confirm paths
    survive into the gallery unchanged.
    """
    if not src_gallery or not os.path.isfile(src_gallery[evt.index][0]):
        return "No caption available."
    selected_image_path = src_gallery[evt.index][0]
    base = os.path.basename(selected_image_path)  # e.g. "2.png"
    stem = os.path.splitext(base)[0]              # e.g. "2"
    # Captions live in a "txt" subfolder next to the output images; output
    # i.png is captioned by txt/(i-1).txt (images 1-based, captions 0-based).
    txt_folder = os.path.join(os.path.dirname(selected_image_path), "txt")
    txt_path = os.path.join(txt_folder, f"{int(stem) - 1}.txt")
    if not os.path.isfile(txt_path):
        return f"Caption file not found: {int(stem) - 1}.txt"
    try:
        with open(txt_path, "r", encoding="utf-8") as f:
            caption = f.read().strip()
        return caption if caption else "(Caption file is empty.)"
    except Exception as e:
        return f"Error reading caption: {e}"

# Gradio CSS: center the main column and cap its width.
css = """
#col-container {
    margin: 0 auto;
    max-width: 1024px;
}
"""

# ---------------------------------------------------------------------------
# UI layout and event wiring.
# ---------------------------------------------------------------------------
with gr.Blocks(css=css) as demo:
    # Header: title, tagline, and GitHub badge.
    gr.HTML(
        """
        <div style="text-align: center;">
            <h1>Chain-of-Zoom</h1>
            <p style="font-size:16px;">Extreme Super-Resolution via Scale Autoregression and Preference Alignment</p>
        </div>
        <br>
        <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
            <a href="https://github.com/bryanswkim/Chain-of-Zoom">
                <img src='https://img.shields.io/badge/GitHub-Repo-blue'>
            </a>
        </div>
        """
    )

    with gr.Column(elem_id="col-container"):
        with gr.Row():
            # Left column: input controls and the zoom-box preview.
            with gr.Column():
                upload_image = gr.Image(label="Upload your input image", type="filepath")
                upscale_radio = gr.Radio(choices=["1x", "2x", "4x"], value="2x", show_label=False)
                run_button = gr.Button("Chain-of-Zoom it")
                preview_with_box = gr.Image(label="Preview (512×512 with centered boxes)", type="pil", interactive=False)

            # Right column: inference results and the per-image caption.
            with gr.Column():
                output_gallery = gr.Gallery(label="Inference Results", show_label=True, columns=[2], rows=[2])
                caption_text = gr.Textbox(label="Caption", lines=4, placeholder="Click on any image above to see its caption here.")

        # Redraw the preview whenever the uploaded image or scale changes;
        # the lambda guards against a cleared (None) image path.
        upload_image.change(
            fn=lambda img_path, scale_opt: make_preview_with_boxes(img_path, scale_opt) if img_path is not None else None,
            inputs=[upload_image, upscale_radio],
            outputs=[preview_with_box]
        )
        upscale_radio.change(
            fn=lambda img_path, scale_opt: make_preview_with_boxes(img_path, scale_opt) if img_path is not None else None,
            inputs=[upload_image, upscale_radio],
            outputs=[preview_with_box]
        )

        # Note: gr.State() will pass session_id automatically
        # NOTE(review): a bare gr.State() holds None and nothing here assigns
        # it, so run_with_upload appears to always receive session_id=None —
        # confirm whether Gradio seeds this per session.
        run_button.click(
            fn=run_with_upload,
            inputs=[upload_image, upscale_radio, gr.State()],
            outputs=[output_gallery]
        )

        # Clicking a gallery image looks up and shows its caption.
        output_gallery.select(
            fn=get_caption,
            inputs=[output_gallery],
            outputs=[caption_text]
        )

# share=True exposes a public Gradio link in addition to the local server.
demo.launch(share=True)