"""Gradio demo: click a point on an image and add a prompted object there.

The module downloads the model checkpoint on first run (when executed as a
script), builds the Gradio ``demo`` Blocks UI at import time, and wires the
UI events to ``ClickDraw`` (marks the clicked point) and ``click2mask_app``
(generates the edited image).
"""

import os
import random
from pathlib import Path

import gradio as gr
import numpy as np
import requests
from PIL import Image

from scripts.constants import Const
from scripts.text_editing_click2mask import click2mask_app
from scripts.clicker import ClickDraw

# -----------------------------------------------------------------------------#
#                        Download model checkpoint once                        #
# -----------------------------------------------------------------------------#
CKPT_URL = (
    "https://download.openxlab.org.cn/models/SunzeY/AlphaCLIP/weight/"
    "clip_l14_336_grit1m_fultune_8xe.pth"
)
CKPT_PATH = Path("checkpoints") / "clip_l14_336_grit1m_fultune_8xe.pth"

# (connect, read) timeouts in seconds so a stalled server cannot hang startup.
_DOWNLOAD_TIMEOUT = (10, 60)


def download_checkpoint() -> None:
    """Download the model weights if they are not already on disk.

    The payload is streamed into a ``.part`` file next to ``CKPT_PATH`` and
    renamed into place only after the download finished. An interrupted
    download therefore never leaves a truncated file at ``CKPT_PATH`` that a
    later run would mistake for a complete checkpoint.

    Raises:
        requests.RequestException: on connection errors, timeouts, or a
            non-2xx HTTP status.
    """
    if CKPT_PATH.exists():
        print("Checkpoint already exists, skipping download.")
        return

    print("Downloading model checkpoint …")
    CKPT_PATH.parent.mkdir(parents=True, exist_ok=True)
    partial_path = CKPT_PATH.with_name(CKPT_PATH.name + ".part")
    try:
        with requests.get(
            CKPT_URL, stream=True, timeout=_DOWNLOAD_TIMEOUT
        ) as r:
            r.raise_for_status()
            with partial_path.open("wb") as f:
                for chunk in r.iter_content(chunk_size=8192):
                    if chunk:  # skip keep-alive chunks
                        f.write(chunk)
        # Publish the checkpoint only once it is complete.
        partial_path.replace(CKPT_PATH)
    finally:
        # After a successful rename this is a no-op; after a failure it
        # removes the incomplete download.
        partial_path.unlink(missing_ok=True)
    print("Download completed.")


# -----------------------------------------------------------------------------#
#                               Helper functions                               #
# -----------------------------------------------------------------------------#
example_prompts = ["A sea monster", "A big ship", "An iceberg"]
example_image_path = "examples/gradio/img3.jpg"
example_point = [320, 285]  # x, y


def resize_to_512(image: Image.Image) -> Image.Image:
    """Return *image* as RGB at ``(Const.W, Const.H)``.

    The image is converted only if its mode is not already ``"RGB"`` and
    resized (LANCZOS) only if its size differs, so an image that already
    matches is returned unchanged.
    """
    if image.mode != "RGB":
        image = image.convert("RGB")
    if image.size != (Const.W, Const.H):
        image = image.resize((Const.W, Const.H), Image.LANCZOS)
    return image


def extract_point512(xy, img_size):
    """Convert a click (x, y) in UI coords to (row, col) in model space.

    Args:
        xy: ``(x, y)`` click position in pixels of the displayed image.
        img_size: ``(width, height)`` of the displayed image.

    Returns:
        ``(row, col)`` as ints, scaled to a ``Const.W`` x ``Const.H`` image.
        Note the axis swap: x maps to the column, y to the row.
    """
    x, y = xy
    scale_x = Const.W / img_size[0]
    scale_y = Const.H / img_size[1]
    return int(y * scale_y), int(x * scale_x)  # (row, col)


# -----------------------------------------------------------------------------#
#                                  Gradio UI                                   #
# -----------------------------------------------------------------------------#
CSS = """
.btn-generate { background-color: #b2f2bb !important; color: black !important; }
.btn-clear { background-color: #ffc9c9 !important; color: black !important; }
.btn-example { background-color: #dee2e6 !important; color: black !important; }
.centered { text-align: center; }
"""

with gr.Blocks(css=CSS) as demo:
    # ---------- per-session state ----------
    orig_image_state = gr.State()  # PIL.Image or None
    point_state = gr.State()  # tuple(row, col) or None

    # ---------- layout ----------
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### Upload and Click on Image", elem_classes="centered")
            image = gr.Image(type="pil", height=512, width=512, label="Input")
        with gr.Column(scale=1):
            gr.Markdown("### Generated Image", elem_classes="centered")
            output = gr.Image(
                height=512, width=512, label="Output", show_download_button=True
            )

    prompt = gr.Textbox(
        label="What object to add?", placeholder="e.g. a sea monster"
    )

    with gr.Row():
        gen_btn = gr.Button("Generate", elem_classes="btn-generate")
        ex_btn = gr.Button("Load Example", elem_classes="btn-example")
        clr_btn = gr.Button("Reset", elem_classes="btn-clear")

    # ---------- callbacks ----------
    def on_upload(image):
        """Handle file drop / selection.

        Returns the resized image for display, the same image to store as
        the pristine original, and ``None`` to clear any previous point.
        """
        if image is None:
            # Nothing was provided; clear everything rather than crash in
            # resize_to_512.
            return None, None, None
        image = resize_to_512(image)
        return image, image, None  # show img, store orig_image_state, clear point

    def on_click(image, evt: gr.SelectData, orig_image):
        """Handle a click on the input image.

        The marker is always drawn on the stored original (not on the
        currently displayed, possibly already-marked image), so only the
        latest click is shown.

        Raises:
            gr.Error: if no original image has been stored yet.
        """
        if orig_image is None:
            raise gr.Error("Upload an image first.")
        point512 = extract_point512(evt.index, image.size)
        _, img_marked = ClickDraw()(orig_image, point512=point512)
        return img_marked, orig_image, point512

    def generate_image(prompt_txt, orig_image, point512):
        """Generate a new image from the prompt, original image and point.

        Raises:
            gr.Error: if the prompt is empty/whitespace, or if no image or
                no click point is available.
        """
        if not prompt_txt or prompt_txt.strip() == "":
            raise gr.Error("Please enter a prompt.")
        if orig_image is None or point512 is None:
            raise gr.Error("Upload and click on an image first.")
        return click2mask_app(prompt_txt, orig_image, point512)

    def clear_all():
        """Reset the input image, the prompt and both session states."""
        return None, "", None, None  # clear image, prompt, orig_state, point_state

    def load_example():
        """Load the predefined example image with a random example prompt."""
        img = resize_to_512(Image.open(example_image_path))
        # example_point is (x, y); the app stores points as a (row, col)
        # tuple, the same type on_click produces via extract_point512.
        point512 = tuple(example_point[::-1])
        _, img_marked = ClickDraw()(img, point512=point512)
        example_prompt = random.choice(example_prompts)
        return img_marked, example_prompt, img, point512

    # ---------- wiring ----------
    image.upload(
        fn=on_upload,
        inputs=image,
        outputs=[image, orig_image_state, point_state],
    )
    image.select(
        fn=on_click,
        inputs=[image, orig_image_state],
        outputs=[image, orig_image_state, point_state],
    )
    gen_btn.click(
        fn=generate_image,
        inputs=[prompt, orig_image_state, point_state],
        outputs=output,
    )
    ex_btn.click(
        fn=load_example,
        inputs=[],
        outputs=[image, prompt, orig_image_state, point_state],
    )
    clr_btn.click(
        fn=clear_all,
        inputs=[],
        outputs=[image, prompt, orig_image_state, point_state],
    )

# -----------------------------------------------------------------------------#
#                                     Run                                      #
# -----------------------------------------------------------------------------#
if __name__ == "__main__":
    download_checkpoint()
    demo.queue()
    demo.launch(share=True)