Spaces:
Running
Running
File size: 4,886 Bytes
d6614dc 4cac868 d6614dc 9ef0f25 29b5baf 433f4b7 4cac868 808cfce 4cac868 9ef0f25 4cac868 9ef0f25 4cac868 999b913 433f4b7 f6ee8cd 433f4b7 f6ee8cd 433f4b7 9ef0f25 808cfce 9ef0f25 29b5baf d6614dc 9ef0f25 f6ee8cd 9ef0f25 0dcdf8e 9ef0f25 433f4b7 808cfce f6ee8cd 808cfce 433f4b7 808cfce efdb63c 808cfce 9ef0f25 f6ee8cd efdb63c 9ef0f25 f6ee8cd efdb63c f6ee8cd efdb63c f6ee8cd 9ef0f25 808cfce 433f4b7 808cfce 0dcdf8e 4cac868 0dcdf8e 4cac868 0dcdf8e 999b913 9ef0f25 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 |
import gradio as gr
from PIL import Image, ImageDraw
from inference import generate_image
# Maps each radio-button label to the numeric index used in the
# imgs/pattern_<n>.png and imgs/heatmap_<n>.png file names.
TASK_TO_INDEX = {
    "Task 1": 0,
    "Task 2": 1,
    "Task 3": 2,
    "Task 4": 3,
}
def create_marker_overlay(
    image_path: str,
    x: int,
    y: int,
    *,
    marker_size: int = 10,
    marker_color: str = "red",
    marker_width: int = 2,
) -> Image.Image:
    """Return a copy of the image at ``image_path`` with a crosshair at (x, y).

    The file on disk is never modified; the marker is drawn on an in-memory
    copy.

    Args:
        image_path: Path of the image to load.
        x: Horizontal pixel position of the crosshair centre.
        y: Vertical pixel position of the crosshair centre.
        marker_size: Length, in pixels, of each crosshair arm measured from
            the centre.
        marker_color: Fill colour passed to ``ImageDraw.line``.
        marker_width: Stroke width, in pixels, of both crosshair lines.

    Returns:
        A new image with a horizontal and a vertical line crossing at (x, y).
    """
    # Open inside a context manager so the underlying file is released
    # deterministically (also on error); copy() yields an independent image
    # that stays usable after the file has been closed.
    with Image.open(image_path) as base_image:
        marked_image = base_image.copy()

    draw = ImageDraw.Draw(marked_image)
    # Horizontal arm, then vertical arm.
    draw.line(
        [x - marker_size, y, x + marker_size, y],
        fill=marker_color,
        width=marker_width,
    )
    draw.line(
        [x, y - marker_size, x, y + marker_size],
        fill=marker_color,
        width=marker_width,
    )
    return marked_image
def update_reference_image(choice: int) -> tuple[str, int, str]:
    """Build the asset paths that belong to the selected task index.

    Args:
        choice: Index of the selected task; it is interpolated into both
            file names.

    Returns:
        A ``(pattern_path, choice, heatmap_path)`` tuple, where the paths are
        ``imgs/pattern_<choice>.png`` and ``imgs/heatmap_<choice>.png``.
    """
    return (
        f"imgs/pattern_{choice}.png",
        choice,
        f"imgs/heatmap_{choice}.png",
    )
def process_coord_click(image_idx: int, evt: gr.SelectData) -> tuple[Image.Image, Image.Image]:
    """Handle a click on the coordinate selector.

    The clicked position is read from ``evt.index``, divided by 1155 on both
    axes and handed to ``generate_image`` together with ``image_idx``. The
    unscaled click position is then marked on ``imgs/heatmap_<image_idx>.png``.

    Args:
        image_idx: Index of the currently selected task.
        evt: Select event whose ``index`` holds the clicked (x, y) position.

    Returns:
        A pair of the value returned by ``generate_image`` and the marked
        heatmap returned by ``create_marker_overlay``.
    """
    click = evt.index
    px, py = click[0], click[1]

    # NOTE(review): 1155 is presumably the heatmap's side length in pixels,
    # mapping the click into [0, 1] — confirm against the image assets.
    generated = generate_image(image_idx, px / 1155, py / 1155)

    # The marker is always drawn on the pristine heatmap file, using the raw
    # (unscaled) click position.
    marked = create_marker_overlay(f"imgs/heatmap_{image_idx}.png", px, py)
    return generated, marked
# Stylesheet passed to gr.Blocks: centres the 450px radio group and pins the
# coordinate selector to a 600x600 box whose image is scaled to fit.
_APP_CSS = """
.radio-container {
width: 450px !important;
margin-left: auto !important;
margin-right: auto !important;
}
.coordinate-container {
width: 600px !important;
height: 600px !important;
}
.coordinate-container img {
width: 100% !important;
height: 100% !important;
object-fit: contain !important;
}
"""

# Heading and usage hint rendered at the top of the page.
_INTRO_MARKDOWN = """
# Interactive Image Generation
Select a task using the radio buttons, then click on the coordinate selector to generate a new image.
"""

# Options shared by the two read-only images in the left column.
_LEFT_IMAGE_OPTIONS = {
    "interactive": False,
    "height": 300,
    "width": 450,
    "show_download_button": False,
    "show_fullscreen_button": False,
}

# NOTE(review): the nesting below was reconstructed from the section comments
# (only the radio group sits inside "radio-container") — confirm.
with gr.Blocks(css=_APP_CSS) as demo:
    gr.Markdown(_INTRO_MARKDOWN)

    with gr.Row():
        # Left column: task picker, reference pattern, generated output.
        with gr.Column(scale=1):
            # Holds the index of the currently selected task between events.
            selected_idx = gr.State(value=0)

            with gr.Column(elem_classes="radio-container"):
                task_select = gr.Radio(
                    label="Select Task",
                    choices=["Task 1", "Task 2", "Task 3", "Task 4"],
                    value="Task 1",
                    interactive=True,
                )

            # Reference pattern for the selected task.
            reference_image = gr.Image(
                value="imgs/pattern_0.png",
                show_label=False,
                **_LEFT_IMAGE_OPTIONS,
            )

            # Filled in by process_coord_click after a click.
            output_image = gr.Image(
                label="Generated Output",
                **_LEFT_IMAGE_OPTIONS,
            )

        # Right column: the clickable heatmap.
        with gr.Column(scale=1):
            with gr.Column(elem_classes="coordinate-container"):
                coord_selector = gr.Image(
                    label="Click to select (x, y) coordinates in the latent space",
                    value="imgs/heatmap_0.png",
                    show_label=True,
                    container=True,
                    interactive=False,
                    sources=[],
                    show_download_button=False,
                    show_fullscreen_button=False,
                )

    # Changing the task swaps the reference image and heatmap and records the
    # new index in the state.
    task_select.change(
        lambda x: update_reference_image(TASK_TO_INDEX[x]),
        inputs=[task_select],
        outputs=[reference_image, selected_idx, coord_selector],
    )

    # Clicking the heatmap refreshes both the generated output and the
    # heatmap itself (now carrying the click marker).
    coord_selector.select(
        fn=process_coord_click,
        inputs=[selected_idx],
        outputs=[output_image, coord_selector],
        trigger_mode="multiple",
    )

demo.launch()
|