"""Gradio demo that runs ``tibblingai.wta.wta`` on uploaded or sample images.

Each input image is shown as a grayscale "original" next to the WTA output
rendered with matplotlib's ``inferno`` colormap.
"""

from pathlib import Path

import gradio as gr
import matplotlib.pyplot as plt  # plt.imread for decoding, plt.cm.inferno for colouring
import numpy as np
from PIL import Image, ImageOps
from tibblingai import wta

# Sample images bundled with the demo. They are resolved relative to this
# script (not the current working directory), so the sample button works
# wherever the app is launched from.
SAMPLE_PATHS = [
    str(Path(__file__).resolve().parent / name)
    for name in ("sample1.jpg", "sample2.jpg")
]


def _to_uint8(arr):
    """Convert an image array to uint8 so PIL can build an image from it.

    Float arrays are treated as [0, 1] intensities (what ``plt.imread`` returns
    for PNGs and what matplotlib colormaps return): they are clipped and scaled
    to [0, 255]. Integer arrays (``plt.imread`` returns these for JPEGs and
    other non-PNG formats) are already in display range and are only cast.
    Scaling them by 255 would overflow uint8.
    """
    if np.issubdtype(arr.dtype, np.floating):
        return (np.clip(arr, 0.0, 1.0) * 255).astype(np.uint8)
    return arr.astype(np.uint8)


def process_image(img) -> "tuple[Image.Image, Image.Image]":
    """Run WTA on one image and build the two images shown in the galleries.

    Args:
        img: Anything ``plt.imread`` accepts: a file path, or a binary
            file-like object such as a Gradio upload.

    Returns:
        ``(gray_img, inferno_img)``: the input converted to grayscale, and the
        first slice of the WTA output coloured with the ``inferno`` colormap.
    """
    img_np = plt.imread(img)
    # Drop an alpha (or any extra) channel so the array is plain RGB.
    if img_np.ndim == 3 and img_np.shape[2] > 3:
        img_np = img_np[:, :, :3]

    gray_img = ImageOps.grayscale(Image.fromarray(_to_uint8(img_np)))

    # NOTE(review): assumes wta.wta returns a tensor shaped (1, H, W) with values
    # roughly in [0, 1]. The colormap saturates anything outside that range.
    # Confirm this against the tibblingai docs.
    wta_out = wta.wta(img_np).numpy()
    inferno_rgba = plt.cm.inferno(wta_out[0])
    inferno_img = Image.fromarray(_to_uint8(inferno_rgba[:, :, :3]))

    return gray_img, inferno_img


def _hidden_galleries():
    """Return Gradio updates that clear and hide both galleries."""
    return gr.update(visible=False, value=None), gr.update(visible=False, value=None)


def process_images(images):
    """Gradio callback: process uploaded file(s) and fill both galleries.

    Args:
        images: A single upload or a list of uploads from ``gr.File``. It is
            ``None`` or an empty list after the user clears the component.

    Returns:
        Two ``gr.update`` objects, one for the original gallery and one for
        the processed gallery.
    """
    # gr.File also fires ``change`` when cleared; never pass None to plt.imread.
    if images is None or (isinstance(images, list) and not images):
        return _hidden_galleries()
    if not isinstance(images, list):
        images = [images]

    original_imgs = []
    processed_imgs = []
    for img in images:
        gray_img, inferno_img = process_image(img)
        original_imgs.append(gray_img)
        processed_imgs.append(inferno_img)

    return (
        gr.update(visible=True, value=original_imgs),
        gr.update(visible=True, value=processed_imgs),
    )


def load_sample_images():
    """Gradio callback: run WTA on the bundled sample images.

    Delegates to ``process_images``, so samples are displayed exactly like
    uploads: grayscale originals next to inferno-coloured WTA output.
    """
    return process_images(SAMPLE_PATHS)


# Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("## Upload Image(s) for Processing")

    # File component (no preview shown here); results appear in the galleries.
    file_input = gr.File(file_types=["image"], file_count="multiple", label="Upload Images")

    with gr.Row():
        original_gallery = gr.Gallery(label="Original Images", visible=False)
        processed_gallery = gr.Gallery(label="Processed Images", visible=False)

    file_input.change(
        fn=process_images,
        inputs=file_input,
        outputs=[original_gallery, processed_gallery],
    )

    with gr.Row():
        with gr.Column(scale=0, min_width=300):
            load_examples_btn = gr.Button("Run WTA on sample images")

    load_examples_btn.click(
        fn=load_sample_images,
        outputs=[original_gallery, processed_gallery],
    )

demo.launch(debug=True)