Spaces:
Running
Running
import gradio as gr | |
from PIL import Image, ImageDraw | |
from inference import generate_image | |
# Radio-button label -> zero-based task index. The index selects the
# imgs/pattern_{i}.png / imgs/heatmap_{i}.png assets and is forwarded to
# generate_image.
TASK_TO_INDEX = {f"Task {number}": number - 1 for number in range(1, 5)}

# Task index -> pre-determined (x, y) pixel coordinates on that task's heatmap,
# used by the "Find Optimal Latent" button.
TASK_OPTIMAL_COORDS = {
    0: (325, 326),
    1: (59, 1126),
    2: (47, 102),
    3: (497, 933),
}
def create_marker_overlay(
    image_path: str,
    x: int,
    y: int,
    marker_size: int = 10,
    marker_color: str = "red",
    line_width: int = 2,
) -> Image.Image:
    """Return a copy of the image at ``image_path`` with a crosshair drawn at (x, y).

    The file on disk is never modified; drawing happens on an in-memory copy.

    Args:
        image_path: Path of the image to annotate.
        x: Horizontal pixel coordinate of the crosshair centre.
        y: Vertical pixel coordinate of the crosshair centre.
        marker_size: Half-length, in pixels, of each crosshair arm.
        marker_color: Colour passed as ``fill`` to ``ImageDraw.Draw.line``.
        line_width: Stroke width of both crosshair lines, in pixels.

    Returns:
        A new image containing the original pixels plus the crosshair.
    """
    # Use a context manager so the file handle is released as soon as the
    # pixels have been copied, instead of whenever the object is collected.
    with Image.open(image_path) as base_image:
        marked_image = base_image.copy()

    draw = ImageDraw.Draw(marked_image)
    # Horizontal arm, then vertical arm, both centred on (x, y).
    draw.line([x - marker_size, y, x + marker_size, y], fill=marker_color, width=line_width)
    draw.line([x, y - marker_size, x, y + marker_size], fill=marker_color, width=line_width)
    return marked_image
def update_reference_image(choice: int) -> tuple[str, int, str]:
    """Resolve the image assets belonging to task index ``choice``.

    Returns a triple of (reference pattern path, the index itself, heatmap path).
    """
    return (
        f"imgs/pattern_{choice}.png",
        choice,
        f"imgs/heatmap_{choice}.png",
    )
def update_marker(image_idx: int, evt: gr.SelectData) -> tuple[Image.Image, tuple[int, int]]:
    """Handle a click on the coordinate selector.

    Reads the selected (x, y) from ``evt.index`` and draws a crosshair there on
    the heatmap of task ``image_idx``. Returns the annotated heatmap together
    with the ``(x, y)`` pair so the caller can keep it in state.
    """
    click_x = evt.index[0]
    click_y = evt.index[1]
    annotated = create_marker_overlay(f"imgs/heatmap_{image_idx}.png", click_x, click_y)
    return annotated, (click_x, click_y)
def generate_output_image(
    image_idx: int,
    coords: tuple[int, int],
    image_size: float = 1155,
) -> Image.Image:
    """Generate the output image for a point picked on the heatmap.

    Args:
        image_idx: Zero-based task index, forwarded to ``generate_image``.
        coords: ``(x, y)`` pixel coordinates picked on the heatmap.
        image_size: Side length in pixels used to normalise both ``x`` and
            ``y`` (the same value is used for both axes). Defaults to 1155.

    Returns:
        Whatever ``generate_image`` returns for the normalised coordinates
        (presumably a PIL image -- confirm against inference.generate_image).

    Raises:
        ValueError: If ``image_size`` is not positive.
    """
    if image_size <= 0:
        raise ValueError(f"image_size must be positive, got {image_size!r}")
    x, y = coords
    # Pass coordinates normalised by the heatmap size rather than raw pixels.
    return generate_image(image_idx, x / image_size, y / image_size)
def find_optimal_latent(image_idx: int) -> tuple[Image.Image, tuple[int, int], Image.Image]:
    """Jump straight to the pre-determined optimal point for task ``image_idx``.

    Returns the heatmap with a crosshair at that point, the point itself (for
    the coordinate state) and the image generated from it.
    """
    best_x, best_y = TASK_OPTIMAL_COORDS[image_idx]
    point = (best_x, best_y)
    annotated = create_marker_overlay(f"imgs/heatmap_{image_idx}.png", best_x, best_y)
    generated = generate_output_image(image_idx, point)
    return annotated, point, generated
# Build the UI. The page has three parts:
# - an explanatory header with the method diagram;
# - a two-column workspace: task picker, reference pattern and generated output
#   on the left; clickable latent heatmap and "optimal" button on the right;
# - a technical-details footer.
with gr.Blocks(
    css="""
    .container {
        max-width: 1000px !important;
        width: 100% !important;
        margin-left: auto !important;
        margin-right: auto !important;
        padding: 0 1rem !important;
    }
    .diagram-container {
        width: 80% !important;
        max-width: 800px !important;
        margin-left: auto !important;
        margin-right: auto !important;
    }
    .diagram-container img {
        width: 100% !important;
        height: auto !important;
        object-fit: contain !important;
    }
    .radio-container {
        width: 100% !important;
        max-width: 450px !important;
        margin-left: auto !important;
        margin-right: auto !important;
    }
    .image-container {
        width: 100% !important;
        aspect-ratio: 1 !important;
        max-width: 450px !important;
        margin-left: auto !important;
        margin-right: auto !important;
    }
    .image-container img {
        width: 100% !important;
        height: 100% !important;
        object-fit: contain !important;
    }
    .coordinate-container {
        width: 100% !important;
        aspect-ratio: 1 !important;
        position: relative !important;
        max-width: 600px !important;
        margin-left: auto !important;
        margin-right: auto !important;
    }
    .coordinate-container img {
        width: 100% !important;
        height: 100% !important;
        object-fit: contain !important;
    }
    .button-container {
        display: flex !important;
        justify-content: center !important;
        width: 100% !important;
        max-width: 600px !important;
        margin-left: auto !important;
        margin-right: auto !important;
    }
    .documentation {
        margin-top: 2rem !important;
        padding: 1rem !important;
        background-color: #f8f9fa !important;
        border-radius: 8px !important;
    }
    .optimal-button {
        margin-top: 1rem !important;
        margin-bottom: 1rem !important;
        width: 200px !important;
    }
    """
) as demo:
    with gr.Column(elem_classes="container"):
        gr.Markdown(
            """
            # Interactive Image Generation
            ## Method Overview
            This interactive demo showcases our novel image generation method that uses coordinate-based control.
            The process allows precise control over generated patterns through a coordinate-conditioning mechanism.
            """
        )
        with gr.Column(elem_classes="diagram-container"):
            gr.Image(
                value="imgs/lpn_diagram.png",
                show_label=False,
                interactive=False,
                show_download_button=False,
                show_fullscreen_button=False,
            )
        # The blank line before the final sentence keeps Markdown from folding
        # it into list item 4.
        gr.Markdown(
            """
            ### How to Use
            1. Choose a pattern generation task using the radio buttons
            2. View the target pattern for your selected task
            3. Click anywhere in the heatmap to specify coordinates in the latent space
            4. See the generated image based on your selection

            Use the "Find Optimal Latent" button to automatically select pre-determined optimal coordinates.
            """
        )
        with gr.Row():
            with gr.Column(scale=3):
                # Per-session state: index of the active task and the last
                # (x, y) picked on its heatmap (None until something is picked).
                selected_idx = gr.State(value=0)
                coords = gr.State()
                with gr.Column(elem_classes="radio-container"):
                    task_select = gr.Radio(
                        choices=["Task 1", "Task 2", "Task 3", "Task 4"],
                        value="Task 1",
                        label="Select Task",
                        interactive=True,
                    )
                gr.Markdown("### Reference Pattern")
                with gr.Column(elem_classes="image-container"):
                    reference_image = gr.Image(
                        value="imgs/pattern_0.png",
                        show_label=False,
                        interactive=False,
                        show_download_button=False,
                        show_fullscreen_button=False,
                    )
                gr.Markdown("### Generated Output")
                with gr.Column(elem_classes="image-container"):
                    output_image = gr.Image(
                        show_label=False,
                        interactive=False,
                        show_download_button=False,
                        show_fullscreen_button=False,
                    )
            with gr.Column(scale=4):
                gr.Markdown("### Coordinate Selector")
                gr.Markdown(
                    "Click anywhere in the image below to select (x, y) coordinates in the latent space"
                )
                with gr.Column(elem_classes="coordinate-container"):
                    # Not user-editable (no upload sources); it only emits
                    # select events that report the clicked pixel.
                    coord_selector = gr.Image(
                        value="imgs/heatmap_0.png",
                        show_label=False,
                        interactive=False,
                        sources=[],
                        container=True,
                        show_download_button=False,
                        show_fullscreen_button=False,
                    )
                with gr.Column(elem_classes="button-container"):
                    optimal_button = gr.Button("Find Optimal Latent", elem_classes="optimal-button")
        with gr.Column(elem_classes="documentation"):
            gr.Markdown(
                """
                ### Technical Details
                Our approach uses a novel coordinate-conditioning mechanism that allows precise control over the generated patterns.
                The heatmap visualization shows the distribution of pattern characteristics across the latent space.
                For more information, please refer to our [paper](https://arxiv.org/pdf/2411.08706) or GitHub [repository](https://github.com/clement-bonnet/lpn).
                """
            )

    # Event handlers
    def _on_task_change(task_name: str) -> tuple[str, int, str, None, None]:
        """Swap in the new task's assets and clear the previous task's selection.

        Without the reset, the generated output (and stored coordinates) of the
        previously selected task would stay on screen next to the new task's
        reference pattern.
        """
        pattern_path, task_idx, heatmap_path = update_reference_image(TASK_TO_INDEX[task_name])
        return pattern_path, task_idx, heatmap_path, None, None

    task_select.change(
        fn=_on_task_change,
        inputs=[task_select],
        outputs=[reference_image, selected_idx, coord_selector, coords, output_image],
    )
    # Clicking the heatmap first redraws the marker and stores the point, then
    # generates the corresponding output image from the stored point.
    coord_selector.select(
        fn=update_marker,
        inputs=[selected_idx],
        outputs=[coord_selector, coords],
        trigger_mode="multiple",
    ).then(
        fn=generate_output_image,
        inputs=[selected_idx, coords],
        outputs=output_image,
    )
    optimal_button.click(
        fn=find_optimal_latent,
        inputs=[selected_idx],
        outputs=[coord_selector, coords, output_image],
    )
demo.launch()