# Winner_Take_All / app.py
# Taha Razzaq
# minor update
# 0f66727
import gradio as gr
from PIL import Image, ImageOps
import numpy as np
from tibblingai import wta
import matplotlib.pyplot as plt # needed for plt.imread
def load_sample_images():
    """Run WTA on the bundled sample images and show them in both galleries.

    The sample files are resolved relative to this script's directory (not the
    current working directory). They are then processed exactly like uploaded
    files, so the "Original" gallery shows grayscale inputs and the "Processed"
    gallery shows the inferno-coloured WTA output.

    Returns:
        Two ``gr.update`` objects for (original_gallery, processed_gallery).
    """
    from pathlib import Path

    script_dir = Path(__file__).resolve().parent
    sample_paths = [str(script_dir / name) for name in ("sample1.jpg", "sample2.jpg")]
    # process_image returns a (grayscale, inferno) pair per image. Reusing the
    # upload callback splits those pairs across the two galleries; putting the
    # raw tuples in one gallery would make Gradio read them as (image, caption).
    return process_images(sample_paths)
def _to_uint8(arr: np.ndarray) -> np.ndarray:
    """Convert an image array returned by ``plt.imread`` to ``uint8`` for display.

    ``plt.imread`` returns float arrays in [0, 1] for PNG files, but integer
    arrays (usually ``uint8`` in [0, 255]) for other formats such as JPEG.
    Multiplying a ``uint8`` array by 255 overflows and wraps around, so only
    floating-point input is rescaled here.
    """
    if np.issubdtype(arr.dtype, np.floating):
        return (np.clip(arr, 0.0, 1.0) * 255).astype(np.uint8)
    if arr.dtype == np.uint8:
        return arr
    # Wider integer types (e.g. 16-bit images): rescale by the dtype's maximum.
    scaled = np.clip(arr.astype(np.float64) / np.iinfo(arr.dtype).max, 0.0, 1.0)
    return (scaled * 255).astype(np.uint8)


# Core per-image pipeline: read, grayscale for display, run WTA, colourize.
def process_image(img) -> "tuple[Image.Image, Image.Image]":
    """Run WTA on one image file and return the two images to display.

    Args:
        img: Path (or file-like object) of the image. It is read with
            ``plt.imread``; the callers in this file pass file paths.

    Returns:
        ``(gray_img, inferno_img)``: the input converted to grayscale, and the
        first element of the WTA output rendered with the inferno colormap.
    """
    img_np = plt.imread(img)
    # Drop an alpha channel (RGBA -> RGB) if present.
    if img_np.ndim == 3 and img_np.shape[2] > 3:
        img_np = img_np[:, :, :3]

    gray_img = ImageOps.grayscale(Image.fromarray(_to_uint8(img_np)))

    # NOTE(review): img_np is float in [0, 1] for PNG but uint8 for JPEG and
    # other formats; confirm which input range wta.wta expects.
    img_tensor = wta.wta(img_np).numpy()  # assumed shape (1, H, W) — TODO confirm

    # Float inputs to the colormap are expected in [0, 1]; values outside that
    # range get the colormap's under/over colours (its end colours by default).
    inferno_colored = plt.cm.inferno(img_tensor[0])
    inferno_img = Image.fromarray((inferno_colored[:, :, :3] * 255).astype(np.uint8))
    return gray_img, inferno_img
# Callback for the file-upload component (also reused for the sample images).
def process_images(images):
    """Run WTA on every given image and fill both galleries.

    Args:
        images: The value of the ``gr.File`` component, which is a list of
            uploaded files or a single file. It is ``None`` after the upload is
            cleared. A list of file paths is also accepted.

    Returns:
        Two ``gr.update`` objects for (original_gallery, processed_gallery).
        Both galleries are hidden and emptied when there is nothing to process.
    """
    # Clearing the upload triggers .change with an empty value. Without this
    # guard, None would be passed to plt.imread and raise.
    if not images:
        return gr.update(visible=False, value=None), gr.update(visible=False, value=None)
    if not isinstance(images, list):
        images = [images]

    results = [process_image(img) for img in images]
    original_imgs = [gray for gray, _ in results]
    processed_imgs = [colored for _, colored in results]
    return gr.update(visible=True, value=original_imgs), gr.update(visible=True, value=processed_imgs)
# Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("## Upload Image(s) for Processing")

    # File picker only (no preview); the images appear in the galleries below.
    file_input = gr.File(file_types=["image"], file_count="multiple", label="Upload Images")

    with gr.Row():
        # Hidden until a callback returns gr.update(visible=True, ...).
        original_gallery = gr.Gallery(label="Original Images", visible=False)
        processed_gallery = gr.Gallery(label="Processed Images", visible=False)

    file_input.change(fn=process_images, inputs=file_input, outputs=[original_gallery, processed_gallery])

    with gr.Row():
        with gr.Column(scale=0, min_width=300):
            load_examples_btn = gr.Button("Run WTA on sample images")
            load_examples_btn.click(fn=load_sample_images, outputs=[original_gallery, processed_gallery])

# Launch only when run as a script. debug=True blocks the calling thread, so
# an unconditional launch would hang any module that imports this file.
if __name__ == "__main__":
    demo.launch(debug=True)