lpn / app.py
clement-bonnet's picture
feat: resize page
e1b10c7
raw
history blame
8.53 kB
import gradio as gr
from PIL import Image, ImageDraw
from inference import generate_image
# Radio-button label -> zero-based task index (the index is used in asset filenames).
TASK_TO_INDEX = dict(zip(("Task 1", "Task 2", "Task 3", "Task 4"), range(4)))
# Task index -> pre-determined (x, y) pixel coordinates on that task's heatmap,
# used by the "Find Optimal Latent" button.
TASK_OPTIMAL_COORDS = {
    0: (325, 326),
    1: (59, 1126),
    2: (47, 102),
    3: (497, 933),
}
def create_marker_overlay(
    image_path: str,
    x: int,
    y: int,
    *,
    marker_size: int = 10,
    marker_color: str = "red",
    line_width: int = 2,
) -> Image.Image:
    """Return a copy of the image at *image_path* with a cross marker at (x, y).

    The file on disk is not modified; the marker is drawn on an in-memory copy.

    Args:
        image_path: Path of the image to annotate.
        x: Horizontal pixel coordinate of the marker centre.
        y: Vertical pixel coordinate of the marker centre.
        marker_size: Half-length, in pixels, of each arm of the cross.
        marker_color: Any colour specification accepted by ``ImageDraw``.
        line_width: Stroke width, in pixels, of the cross lines.

    Returns:
        A new ``PIL.Image.Image`` containing the original pixels plus the marker.
    """
    # Use a context manager so the file handle opened by Image.open is released;
    # Image.copy() loads the pixel data, so the copy stays valid after closing.
    with Image.open(image_path) as base_image:
        marked_image = base_image.copy()
    draw = ImageDraw.Draw(marked_image)
    # Horizontal arm of the cross.
    draw.line(
        [x - marker_size, y, x + marker_size, y], fill=marker_color, width=line_width
    )
    # Vertical arm of the cross.
    draw.line(
        [x, y - marker_size, x, y + marker_size], fill=marker_color, width=line_width
    )
    return marked_image
def update_reference_image(choice: int) -> tuple[str, int, str]:
    """Resolve the asset paths displayed for task index *choice*.

    Returns:
        A 3-tuple ``(pattern_path, choice, heatmap_path)``: the reference
        pattern image path, the unchanged task index (to store in state), and
        the heatmap image path for the coordinate selector.
    """
    heatmap = f"imgs/heatmap_{choice}.png"
    pattern = f"imgs/pattern_{choice}.png"
    return (pattern, choice, heatmap)
def update_marker(image_idx: int, evt: gr.SelectData) -> tuple[Image.Image, tuple[int, int]]:
    """Mark the clicked point on the heatmap of task *image_idx*.

    Args:
        image_idx: Task index whose heatmap is being clicked.
        evt: Gradio select event; its ``index`` holds the clicked x and y.

    Returns:
        The heatmap with a marker drawn at the click, and the ``(x, y)``
        coordinates so they can be stored in state.
    """
    selected = evt.index
    point = (selected[0], selected[1])
    overlay = create_marker_overlay(f"imgs/heatmap_{image_idx}.png", *point)
    return overlay, point
def generate_output_image(image_idx: int, coords: tuple[int, int]) -> Image.Image:
    """Generate the output for task *image_idx* from heatmap pixel *coords*.

    The pixel coordinates are divided by 1155 before being passed to
    ``generate_image``.

    Args:
        image_idx: Task index forwarded to ``generate_image``.
        coords: ``(x, y)`` pixel coordinates selected on the heatmap.

    Returns:
        Whatever ``generate_image`` returns (annotated as a PIL image).
    """
    # NOTE(review): 1155 looks like the heatmap's side length in pixels, making
    # the result roughly [0, 1) — confirm against the heatmap assets.
    scale = 1155
    px, py = coords
    return generate_image(image_idx, px / scale, py / scale)
def find_optimal_latent(image_idx: int) -> tuple[Image.Image, tuple[int, int], Image.Image]:
    """Jump to the pre-determined optimal coordinates for task *image_idx*.

    Returns:
        A 3-tuple: the heatmap with a marker at the optimal point, the optimal
        ``(x, y)`` coordinates, and the image generated from those coordinates.
    """
    px, py = TASK_OPTIMAL_COORDS[image_idx]
    point = (px, py)
    heatmap_with_marker = create_marker_overlay(f"imgs/heatmap_{image_idx}.png", px, py)
    generated = generate_output_image(image_idx, point)
    return heatmap_with_marker, point, generated
def _on_task_change(task_name: str) -> tuple[str, int, str, None, None]:
    """Switch all task-dependent components to *task_name*.

    Returns the new reference pattern path, the task index, and the new heatmap
    path (via ``update_reference_image``). It also returns ``None`` twice, which
    clears the generated output image and the stored coordinates.
    """
    pattern_path, task_idx, heatmap_path = update_reference_image(TASK_TO_INDEX[task_name])
    # Without this reset, the previous task's generated image would stay on
    # screen next to the new task's reference pattern.
    return pattern_path, task_idx, heatmap_path, None, None


with gr.Blocks(
    css="""
.container {
max-width: 1000px !important;
width: 100% !important;
margin-left: auto !important;
margin-right: auto !important;
padding: 0 1rem !important;
}
.diagram-container {
width: 80% !important;
max-width: 800px !important;
margin-left: auto !important;
margin-right: auto !important;
}
.diagram-container img {
width: 100% !important;
height: auto !important;
object-fit: contain !important;
}
.radio-container {
width: 100% !important;
max-width: 450px !important;
margin-left: auto !important;
margin-right: auto !important;
}
.image-container {
width: 100% !important;
aspect-ratio: 1 !important;
max-width: 450px !important;
margin-left: auto !important;
margin-right: auto !important;
}
.image-container img {
width: 100% !important;
height: 100% !important;
object-fit: contain !important;
}
.coordinate-container {
width: 100% !important;
aspect-ratio: 1 !important;
position: relative !important;
max-width: 600px !important;
margin-left: auto !important;
margin-right: auto !important;
}
.coordinate-container img {
width: 100% !important;
height: 100% !important;
object-fit: contain !important;
}
.button-container {
display: flex !important;
justify-content: center !important;
width: 100% !important;
max-width: 600px !important;
margin-left: auto !important;
margin-right: auto !important;
}
.documentation {
margin-top: 2rem !important;
padding: 1rem !important;
background-color: #f8f9fa !important;
border-radius: 8px !important;
}
.optimal-button {
margin-top: 1rem !important;
margin-bottom: 1rem !important;
width: 200px !important;
}
"""
) as demo:
    with gr.Column(elem_classes="container"):
        # Page header and method overview.
        gr.Markdown(
            """
# Interactive Image Generation
## Method Overview
This interactive demo showcases our novel image generation method that uses coordinate-based control.
The process allows precise control over generated patterns through a coordinate-conditioning mechanism.
"""
        )
        with gr.Column(elem_classes="diagram-container"):
            gr.Image(
                value="imgs/lpn_diagram.png",
                show_label=False,
                interactive=False,
                show_download_button=False,
                show_fullscreen_button=False,
            )
        gr.Markdown(
            """
### How to Use
1. Choose a pattern generation task using the radio buttons
2. View the target pattern for your selected task
3. Click anywhere in the heatmap to specify coordinates in the latent space
4. See the generated image based on your selection
Use the "Find Optimal Latent" button to automatically select pre-determined optimal coordinates.
"""
        )
        with gr.Row():
            # Left column: task selection, reference pattern and generated output.
            with gr.Column(scale=3):
                # Per-session state: current task index and last clicked (x, y).
                selected_idx = gr.State(value=0)
                coords = gr.State()
                with gr.Column(elem_classes="radio-container"):
                    task_select = gr.Radio(
                        # Derived from TASK_TO_INDEX so labels and lookup keys
                        # cannot drift apart.
                        choices=list(TASK_TO_INDEX),
                        value="Task 1",
                        label="Select Task",
                        interactive=True,
                    )
                gr.Markdown("### Reference Pattern")
                with gr.Column(elem_classes="image-container"):
                    reference_image = gr.Image(
                        value="imgs/pattern_0.png",
                        show_label=False,
                        interactive=False,
                        show_download_button=False,
                        show_fullscreen_button=False,
                    )
                gr.Markdown("### Generated Output")
                with gr.Column(elem_classes="image-container"):
                    output_image = gr.Image(
                        show_label=False,
                        interactive=False,
                        show_download_button=False,
                        show_fullscreen_button=False,
                    )
            # Right column: clickable heatmap and the "optimal latent" shortcut.
            with gr.Column(scale=4):
                gr.Markdown("### Coordinate Selector")
                gr.Markdown(
                    "Click anywhere in the image below to select (x, y) coordinates in the latent space"
                )
                with gr.Column(elem_classes="coordinate-container"):
                    coord_selector = gr.Image(
                        value="imgs/heatmap_0.png",
                        show_label=False,
                        interactive=False,
                        sources=[],
                        container=True,
                        show_download_button=False,
                        show_fullscreen_button=False,
                    )
                with gr.Column(elem_classes="button-container"):
                    optimal_button = gr.Button("Find Optimal Latent", elem_classes="optimal-button")
        with gr.Column(elem_classes="documentation"):
            gr.Markdown(
                """
### Technical Details
Our approach uses a novel coordinate-conditioning mechanism that allows precise control over the generated patterns.
The heatmap visualization shows the distribution of pattern characteristics across the latent space.
For more information, please refer to our [paper](https://arxiv.org/pdf/2411.08706) or GitHub [repository](https://github.com/clement-bonnet/lpn).
"""
            )

    # Event handlers (must be registered inside the Blocks context).
    task_select.change(
        fn=_on_task_change,
        inputs=[task_select],
        outputs=[reference_image, selected_idx, coord_selector, output_image, coords],
    )
    # Clicking the heatmap first draws the marker and stores the coordinates,
    # then generates the output from those stored coordinates.
    coord_selector.select(
        fn=update_marker,
        inputs=[selected_idx],
        outputs=[coord_selector, coords],
        trigger_mode="multiple",
    ).then(
        fn=generate_output_image,
        inputs=[selected_idx, coords],
        outputs=output_image,
    )
    optimal_button.click(
        fn=find_optimal_latent,
        inputs=[selected_idx],
        outputs=[coord_selector, coords, output_image],
    )

demo.launch()