# app.py — uploaded by awacke1, commit 9c99365 ("Create app.py"), 6.69 kB
import streamlit as st
import requests
import base64
from io import BytesIO
from PIL import Image
import time
# Page header: the app title followed by a short description of its purpose.
st.title("SDXL Image Generation Wrapper")
st.markdown(
    "\n"
    "This Streamlit app wraps the Google SDXL Gradio app to generate images from text prompts.\n"
    "Enter a prompt, select a style, and adjust settings to create high-quality images.\n"
)

# Endpoint of the Gradio backend that receives the generation request
# (replace with the actual endpoint if it differs).
BACKEND_URL = "https://google-sdxl.hf.space/api/predict/"

# Names offered in the style selector (mirroring the Gradio app's style_list).
STYLE_OPTIONS = [
    "(No style)",
    "Cinematic",
    "Photographic",
    "Anime",
    "Manga",
    "Digital Art",
    "Pixel art",
    "Fantasy art",
    "Neonpunk",
    "3D Model",
]
# Style templates (replicating the Gradio app's apply_style logic).
# Each entry maps a style name to a pair:
#   (positive prompt template containing "{prompt}", style-specific negative prompt)
_STYLE_TEMPLATES = {
    "(No style)": ("{prompt}", ""),
    "Cinematic": (
        "cinematic still {prompt} . emotional, harmonious, vignette, highly detailed, high budget, bokeh, cinemascope, moody, epic, gorgeous, film grain, grainy",
        "anime, cartoon, graphic, text, painting, crayon, graphite, abstract, glitch, deformed, mutated, ugly, disfigured",
    ),
    "Photographic": (
        "cinematic photo {prompt} . 35mm photograph, film, bokeh, professional, 4k, highly detailed",
        "drawing, painting, crayon, sketch, graphite, impressionist, noisy, blurry, soft, deformed, ugly",
    ),
    "Anime": (
        "anime artwork {prompt} . anime style, key visual, vibrant, studio anime, highly detailed",
        "photo, deformed, black and white, realism, disfigured, low contrast",
    ),
    "Manga": (
        "manga style {prompt} . vibrant, high-energy, detailed, iconic, Japanese comic style",
        "ugly, deformed, noisy, blurry, low contrast, realism, photorealistic, Western comic style",
    ),
    "Digital Art": (
        "concept art {prompt} . digital artwork, illustrative, painterly, matte painting, highly detailed",
        "photo, photorealistic, realism, ugly",
    ),
    "Pixel art": (
        "pixel-art {prompt} . low-res, blocky, pixel art style, 8-bit graphics",
        "sloppy, messy, blurry, noisy, highly detailed, ultra textured, photo, realistic",
    ),
    "Fantasy art": (
        "ethereal fantasy concept art of {prompt} . magnificent, celestial, ethereal, painterly, epic, majestic, magical, fantasy art, cover art, dreamy",
        "photographic, realistic, realism, 35mm film, dslr, cropped, frame, text, deformed, glitch, noise, noisy, off-center, deformed, cross-eyed, closed eyes, bad anatomy, ugly, disfigured, sloppy, duplicate, mutated, black and white",
    ),
    "Neonpunk": (
        "neonpunk style {prompt} . cyberpunk, vaporwave, neon, vibes, vibrant, stunningly beautiful, crisp, detailed, sleek, ultramodern, magenta highlights, dark purple shadows, high contrast, cinematic, ultra detailed, intricate, professional",
        "painting, drawing, illustration, glitch, deformed, mutated, cross-eyed, ugly, disfigured",
    ),
    "3D Model": (
        "professional 3d model {prompt} . octane render, highly detailed, volumetric, dramatic lighting",
        "ugly, deformed, noisy, low poly, blurry, painting",
    ),
}


def apply_style(style_name, prompt, negative):
    """Expand a prompt and negative prompt with the template of a named style.

    Args:
        style_name: Key into the style table; an unknown name falls back to
            the "(No style)" entry, which leaves the prompt unchanged.
        prompt: User prompt substituted for "{prompt}" in the style template.
        negative: Optional user negative prompt; ``None`` or an empty /
            whitespace-only string means "no extra negative terms".

    Returns:
        A ``(styled_prompt, styled_negative)`` tuple of strings. The negative
        prompt is the style's own negative terms followed by the user's terms,
        separated by ", ".
    """
    positive_template, style_negative = _STYLE_TEMPLATES.get(
        style_name, _STYLE_TEMPLATES["(No style)"]
    )
    user_negative = (negative or "").strip()
    # Join with an explicit separator: plain concatenation would fuse the last
    # style term with the first user term (e.g. "disfiguredlow_quality").
    negative_parts = [part for part in (style_negative, user_negative) if part]
    return positive_template.replace("{prompt}", prompt), ", ".join(negative_parts)
# Helper: decode a single base64 data URI returned by the backend.
def _decode_data_uri_image(data_uri):
    """Turn one ``data:image/...;base64,<payload>`` string into a PIL image.

    Raises:
        ValueError: If the payload is not valid base64.
        OSError: If the decoded bytes are not a recognizable image.
    """
    # Everything after the first comma is the base64 payload.
    _, _, encoded = data_uri.partition(",")
    return Image.open(BytesIO(base64.b64decode(encoded)))


# Function to call the Gradio backend
def generate_images(prompt, negative_prompt, guidance_scale, style_name):
    """Request images from the Gradio backend and decode them.

    Progress and failures are reported to the user through Streamlit
    messages (``st.error`` / ``st.warning`` / ``st.success``); no exception
    escapes this function.

    Args:
        prompt: User prompt text.
        negative_prompt: Optional negative prompt text.
        guidance_scale: Guidance scale value forwarded to the backend.
        style_name: Style name used to expand the prompts and also forwarded
            to the backend.

    Returns:
        A non-empty list of PIL images on success, or ``None`` when the
        request failed or no image could be decoded.
    """
    try:
        # Apply style to prompt and negative prompt
        styled_prompt, styled_negative = apply_style(style_name, prompt, negative_prompt)

        # Prepare payload (mimicking Gradio's infer function).
        # NOTE(review): the payload carries the already-styled prompts *and*
        # style_name; if the backend applies the style itself the template
        # would be applied twice — confirm against the backend's signature.
        payload = {
            "data": [
                styled_prompt,
                styled_negative,
                guidance_scale,
                style_name,
            ]
        }

        # Send request to Gradio backend
        start_time = time.time()
        response = requests.post(BACKEND_URL, json=payload, timeout=60)
        elapsed_time = time.time() - start_time

        # Check response
        if response.status_code != 200:
            st.error(f"Error: Received status code {response.status_code}")
            return None

        # Parse response. Handled here rather than by the outer handlers
        # because requests' JSON decode error is also a RequestException and
        # would otherwise be reported as a network error.
        try:
            json_data = response.json()
        except ValueError as e:
            st.error(f"Error decoding response: {str(e)}")
            return None

        if not isinstance(json_data, dict) or not json_data.get("data"):
            st.error("Error: No images returned from the backend.")
            return None

        # Extract images (assuming base64 strings); adjust based on the
        # actual response structure.
        gallery = json_data["data"][0]
        if isinstance(gallery, str):
            # A bare string is a single image, not a sequence of characters.
            gallery = [gallery]
        elif not isinstance(gallery, (list, tuple)):
            st.error("Error: No images returned from the backend.")
            return None

        images = []
        for img_data in gallery:
            if not (isinstance(img_data, str) and img_data.startswith("data:image")):
                st.warning("Unexpected image data format.")
                continue
            try:
                images.append(_decode_data_uri_image(img_data))
            except (ValueError, OSError) as e:
                # One corrupt entry should not discard the rest of the batch.
                st.warning(f"Skipping an image that could not be decoded: {str(e)}")

        # Do not claim success when nothing usable came back.
        if not images:
            st.error("Error: No images returned from the backend.")
            return None

        st.success(f"Images generated in {elapsed_time:.2f} seconds!")
        return images
    except requests.exceptions.RequestException as e:
        st.error(f"Network error: {str(e)}")
        return None
    except ValueError as e:
        st.error(f"Error decoding response: {str(e)}")
        return None
    except Exception as e:
        # Last-resort boundary: surface the problem in the UI instead of
        # crashing the Streamlit script run.
        st.error(f"Unexpected error: {str(e)}")
        return None
# Input form: the widgets are grouped so the script only reacts once the
# submit button is pressed.
with st.form(key="input_form"):
    prompt = st.text_input(
        "Enter your prompt",
        placeholder="A serious capybara at work, wearing a suit",
    )
    negative_prompt = st.text_input(
        "Negative prompt (optional)",
        placeholder="low_quality",
    )
    guidance_scale = st.slider(
        "Guidance Scale",
        min_value=0.0,
        max_value=50.0,
        value=7.5,
        step=0.1,
    )
    style_name = st.selectbox("Image Style", options=STYLE_OPTIONS, index=0)
    submit_button = st.form_submit_button("Generate Images")

# React to a submission: reject an empty prompt, otherwise generate and show
# the resulting images.
if submit_button and not prompt:
    st.error("Please enter a prompt.")
elif submit_button:
    with st.spinner("Generating images..."):
        images = generate_images(prompt, negative_prompt, guidance_scale, style_name)
    if images:
        # Lay the images out in a gallery of at most three columns, wrapping
        # additional images back to the first column.
        st.subheader("Generated Images")
        column_count = min(len(images), 3)
        gallery_columns = st.columns(column_count)
        for position, picture in enumerate(images):
            with gallery_columns[position % column_count]:
                # NOTE(review): newer Streamlit releases appear to deprecate
                # use_column_width — confirm against the installed version.
                st.image(picture, use_column_width=True)