# app.py
"""Gradio demo that renders a seeded random RGB image with JAX.

The generator below is a placeholder: it draws uniform noise rather than
running a trained diffusion model. Swap ``generate_image`` for real model
inference when one is available.
"""

import gradio as gr
import jax
import jax.numpy as jnp
import numpy as np
from PIL import Image

# Height, width and channel count (RGB) of the placeholder image.
IMAGE_SHAPE = (64, 64, 3)


# Dummy generator function — replace this with your real model inference!
def generate_image(seed: float) -> Image.Image:
    """Return a 64x64 RGB noise image derived deterministically from *seed*.

    Args:
        seed: Seed for the JAX PRNG. The value comes from the UI slider and
            may arrive as a float (e.g. ``42.0``), so it is truncated to an
            integer before use.

    Returns:
        A PIL image whose pixels are uniform random noise.
    """
    # PRNGKey only accepts integer seeds; a float such as 42.0 raises TypeError.
    key = jax.random.PRNGKey(int(seed))
    # Draw values in [0, 1) and scale them to the 8-bit pixel range.
    img = jax.random.uniform(key, IMAGE_SHAPE, minval=0, maxval=1.0)
    img_np = np.array(img * 255, dtype=np.uint8)
    return Image.fromarray(img_np)


# Define the Gradio interface: one seed slider in, one image out.
iface = gr.Interface(
    fn=generate_image,
    inputs=gr.Slider(0, 10000, value=42, step=1, label="Random Seed"),
    outputs=gr.Image(type="pil", label="Generated Image"),
    title="JAX Diffusion Demo",
    description="🎨 Generate random diffusion samples using JAX! \n\n(Replace dummy function with your trained model.)",
    theme="default",
    live=False,
)

if __name__ == "__main__":
    iface.launch()