# Source: jax-diffusion/app.py by carrycooldude (commit 252d51a, "Update app.py", 1.05 kB).
import gradio as gr
import jax
import jax.numpy as jnp
import numpy as np
# Placeholder diffusion generator (replace with your real model).
def generate_diffusion(prompt, steps, *, seed=0, size=64):
    """Return a placeholder RGB image for the Gradio demo.

    This stand-in does not run a diffusion model: it ignores ``prompt`` and
    ``steps`` and samples uniform noise from a fixed PRNG seed, so the same
    ``seed``/``size`` always yields the same image.

    Args:
        prompt: Text prompt from the UI (currently unused).
        steps: Number of diffusion steps from the UI (currently unused).
        seed: PRNG seed for the noise; the default 0 matches the original demo.
        size: Height and width of the square output image, in pixels.

    Returns:
        A ``numpy.ndarray`` of shape ``(size, size, 3)`` and dtype ``uint8``.
    """
    key = jax.random.PRNGKey(seed)
    image = jax.random.uniform(key, (size, size, 3))
    # Keep values in [0, 1] even once a real model replaces the noise above.
    image = jnp.clip(image, 0.0, 1.0)
    # gr.Image outputs accept NumPy arrays, PIL images or file paths -- not
    # jax.Array -- so move the result to host memory as 8-bit RGB.
    return np.asarray(jnp.round(image * 255).astype(jnp.uint8))
# Gradio interface built with Blocks: an input column (prompt, steps, button)
# beside an output column (generated image), wired to generate_diffusion.
with gr.Blocks() as demo:
    gr.Markdown("# 🌟 JAX Diffusion Demo")
    gr.Markdown("Generate images using a simple diffusion model powered by **JAX**!")
    with gr.Row():
        with gr.Column():
            prompt_input = gr.Textbox(label="Prompt", placeholder="Describe your image...")
            steps_input = gr.Slider(10, 100, value=50, step=5, label="Diffusion Steps")
            generate_button = gr.Button("Generate")
        with gr.Column():
            output_image = gr.Image(label="Generated Image")
    # Event listeners are registered inside the Blocks context: clicking
    # "Generate" passes (prompt, steps) positionally to generate_diffusion and
    # displays its return value in output_image.
    generate_button.click(
        fn=generate_diffusion,
        inputs=[prompt_input, steps_input],
        outputs=output_image,
    )

# Start the server only when run as a script, so importing this module
# (e.g. from tests or tooling) does not launch it as a side effect.
if __name__ == "__main__":
    demo.launch()