File size: 4,886 Bytes
d6614dc
4cac868
d6614dc
9ef0f25
29b5baf
433f4b7
 
 
4cac868
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
808cfce
 
 
 
 
 
 
 
 
 
4cac868
9ef0f25
 
4cac868
9ef0f25
 
4cac868
 
 
 
 
 
 
 
 
 
 
999b913
 
433f4b7
 
f6ee8cd
433f4b7
 
 
 
f6ee8cd
 
 
 
 
 
 
 
 
433f4b7
 
9ef0f25
 
 
808cfce
9ef0f25
29b5baf
d6614dc
9ef0f25
f6ee8cd
9ef0f25
0dcdf8e
 
9ef0f25
433f4b7
 
 
 
 
 
 
 
808cfce
f6ee8cd
808cfce
 
 
433f4b7
808cfce
 
efdb63c
 
808cfce
9ef0f25
f6ee8cd
efdb63c
 
 
 
 
 
 
 
9ef0f25
f6ee8cd
 
 
 
 
 
 
 
efdb63c
 
f6ee8cd
efdb63c
 
f6ee8cd
9ef0f25
808cfce
 
433f4b7
808cfce
 
 
0dcdf8e
4cac868
0dcdf8e
4cac868
 
 
 
0dcdf8e
999b913
9ef0f25
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
import gradio as gr
from PIL import Image, ImageDraw

from inference import generate_image

# Maps each radio-button label to the zero-based index used in image file names.
TASK_TO_INDEX = {
    label: position
    for position, label in enumerate(("Task 1", "Task 2", "Task 3", "Task 4"))
}


def create_marker_overlay(image_path: str, x: int, y: int) -> Image.Image:
    """
    Return a copy of an image with a crosshair marker drawn at (x, y).

    Args:
        image_path: Path of the image file to load.
        x: Horizontal coordinate of the marker centre, in image pixels.
        y: Vertical coordinate of the marker centre, in image pixels.

    Returns:
        A new in-memory image carrying the marker; the file on disk is
        not modified.
    """
    # Open inside a context manager so the underlying file handle is released
    # deterministically once the pixel data has been copied.
    with Image.open(image_path) as base_image:
        marked_image = base_image.copy()

    draw = ImageDraw.Draw(marked_image)

    # Marker properties: half-length of each arm and the fill colour.
    marker_size = 10
    marker_color = "red"

    # Crosshair: one horizontal and one vertical stroke through (x, y).
    draw.line([x - marker_size, y, x + marker_size, y], fill=marker_color, width=2)
    draw.line([x, y - marker_size, x, y + marker_size], fill=marker_color, width=2)

    return marked_image


def update_reference_image(choice: int) -> tuple[str, int, str]:
    """
    Resolve the files that belong to the selected task index.

    Returns a ``(pattern image path, index, heatmap image path)`` tuple;
    ``choice`` is passed through unchanged as the middle element.
    """
    return (
        f"imgs/pattern_{choice}.png",
        choice,
        f"imgs/heatmap_{choice}.png",
    )


def process_coord_click(image_idx: int, evt: gr.SelectData) -> tuple[Image.Image, Image.Image]:
    """
    Handle a click on the coordinate selector.

    Scales the clicked position, requests an output from ``generate_image``
    for the given task index, and redraws that task's heatmap with a marker
    at the clicked position.

    Returns a ``(generated image, marked selector image)`` tuple.
    """
    click_x = evt.index[0]
    click_y = evt.index[1]

    # Both axes are divided by the same constant before generation.
    # NOTE(review): 1155 is presumably the heatmap's pixel size - confirm.
    generated_image = generate_image(image_idx, click_x / 1155, click_y / 1155)

    # The marker is drawn at the raw (unscaled) click position on the
    # heatmap belonging to the current task.
    marked_selector = create_marker_overlay(
        f"imgs/heatmap_{image_idx}.png", click_x, click_y
    )

    return generated_image, marked_selector


# Build the UI: a two-column layout with task selection, reference image and
# generated output on the left, and a clickable coordinate selector on the right.
# The CSS below centres the radio group and fixes the selector at 600x600.
with gr.Blocks(
    css="""
    .radio-container {
        width: 450px !important;
        margin-left: auto !important;
        margin-right: auto !important;
    }
    .coordinate-container {
        width: 600px !important;
        height: 600px !important;
    }
    .coordinate-container img {
        width: 100% !important;
        height: 100% !important;
        object-fit: contain !important;
    }
"""
) as demo:
    gr.Markdown(
        """
    # Interactive Image Generation
    Select a task using the radio buttons, then click on the coordinate selector to generate a new image.
    """
    )

    with gr.Row():
        # Left column: Radio selection, reference image, and output
        with gr.Column(scale=1):
            # State variable to track selected image index
            # (starts at 0, matching the default "Task 1" radio value)
            selected_idx = gr.State(value=0)

            # Radio buttons with container class
            with gr.Column(elem_classes="radio-container"):
                task_select = gr.Radio(
                    choices=["Task 1", "Task 2", "Task 3", "Task 4"],
                    value="Task 1",
                    label="Select Task",
                    interactive=True,
                )

            # Reference image component that updates based on selection
            reference_image = gr.Image(
                value="imgs/pattern_0.png",
                show_label=False,
                interactive=False,
                height=300,
                width=450,
                show_download_button=False,
                show_fullscreen_button=False,
            )

            # Generated image output moved below reference image
            # (no initial value; filled in by the coordinate-click handler)
            output_image = gr.Image(
                label="Generated Output",
                height=300,
                width=450,
                show_download_button=False,
                show_fullscreen_button=False,
                interactive=False,
            )

        # Right column: Larger coordinate selector
        with gr.Column(scale=1):
            # Coordinate selector with container class for proper scaling
            with gr.Column(elem_classes="coordinate-container"):
                # Not an upload widget (interactive=False, no sources); it only
                # displays the heatmap and reports click positions.
                coord_selector = gr.Image(
                    value="imgs/heatmap_0.png",
                    label="Click to select (x, y) coordinates in the latent space",
                    show_label=True,
                    interactive=False,
                    sources=[],
                    container=True,
                    show_download_button=False,
                    show_fullscreen_button=False,
                )

    # Handle radio button selection
    # The radio label is translated to its index via TASK_TO_INDEX; the handler
    # then swaps the reference image, stores the index, and resets the selector
    # to that task's unmarked heatmap.
    task_select.change(
        fn=lambda x: update_reference_image(TASK_TO_INDEX[x]),
        inputs=[task_select],
        outputs=[reference_image, selected_idx, coord_selector],
    )

    # Handle coordinate selection - now updates both output image and coord_selector
    # Only the stored task index is passed as an input; the click event data is
    # received through the handler's gr.SelectData-annotated parameter.
    coord_selector.select(
        process_coord_click,
        inputs=[selected_idx],
        outputs=[output_image, coord_selector],
        trigger_mode="multiple",
    )

# Start the app.
# NOTE(review): there is no `if __name__ == "__main__":` guard, so this also
# runs when the module is imported - confirm that is intended.
demo.launch()