File size: 1,045 Bytes
d58b45f
 
 
 
252d51a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d58b45f
252d51a
 
 
 
 
d58b45f
252d51a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import zlib

import gradio as gr
import jax
import jax.numpy as jnp
import numpy as np

# Dummy diffusion generation function (replace with your real one).
def generate_diffusion(prompt, steps):
    """Return a placeholder image for *prompt* in place of a real diffusion model.

    The dummy draws uniform noise instead of running a sampler. The noise is
    seeded from ``prompt`` and ``steps``. Different inputs therefore give
    different images, and the same inputs reproduce the same image.

    Args:
        prompt: Text prompt from the UI. ``None`` is treated as ``""``.
        steps: Requested number of diffusion steps. It is coerced to ``int``
            because a slider may deliver a float. The dummy only uses it for
            seeding.

    Returns:
        A ``numpy.ndarray`` of shape ``(64, 64, 3)``, dtype float32, with
        values in [0, 1]. It is converted from a JAX array because Gradio's
        ``gr.Image`` output expects a NumPy array (or PIL image / file path).
    """
    prompt = "" if prompt is None else str(prompt)
    steps = int(steps)

    # Python's built-in hash() of str is salted per process, so use CRC32 for
    # a seed that is stable across runs. Mask to 31 bits so the seed fits a
    # signed 32-bit integer when JAX's x64 mode is disabled.
    seed = zlib.crc32(f"{prompt}\x00{steps}".encode("utf-8")) & 0x7FFFFFFF
    key = jax.random.PRNGKey(seed)

    # For demo: random noise stands in for a denoised sample.
    image = jax.random.uniform(key, (64, 64, 3))
    # A real model's output can overshoot [0, 1], so clip before display.
    image = jnp.clip(image, 0.0, 1.0)
    # Move the result to the host as NumPy so Gradio can render it.
    return np.asarray(image)

# Gradio interface built with Blocks: prompt and step slider on the left,
# generated image on the right.
with gr.Blocks() as demo:
    gr.Markdown("# 🌟 JAX Diffusion Demo")
    gr.Markdown("Generate images using a simple diffusion model powered by **JAX**!")

    with gr.Row():
        with gr.Column():
            prompt_input = gr.Textbox(label="Prompt", placeholder="Describe your image...")
            steps_input = gr.Slider(10, 100, value=50, step=5, label="Diffusion Steps")
            generate_button = gr.Button("Generate")

        with gr.Column():
            output_image = gr.Image(label="Generated Image")

    # Input component values are passed positionally, in list order, to
    # generate_diffusion(prompt, steps); its return value fills output_image.
    generate_button.click(
        fn=generate_diffusion,
        inputs=[prompt_input, steps_input],
        outputs=output_image,
    )

# Launch only when run as a script, so that importing this module (e.g. from
# tests, or from another app that wants to mount `demo`) does not start a server.
if __name__ == "__main__":
    demo.launch()