|
import streamlit as st |
|
import requests |
|
import base64 |
|
from io import BytesIO |
|
from PIL import Image |
|
import time |
|
|
|
|
|
# Introductory text rendered under the page title.
_APP_DESCRIPTION = """

This Streamlit app wraps the Google SDXL Gradio app to generate images from text prompts.

Enter a prompt, select a style, and adjust settings to create high-quality images.

"""

st.title("SDXL Image Generation Wrapper")
st.markdown(_APP_DESCRIPTION)
|
|
|
|
|
# Endpoint of the hosted SDXL Gradio app; generate_images() POSTs its JSON payload here.
BACKEND_URL = "https://google-sdxl.hf.space/api/predict/"

# Style names offered in the UI selectbox. Each should match a key of the style table
# used by apply_style(), which falls back to "(No style)" for names it does not know.
STYLE_OPTIONS = [
    "(No style)",
    "Cinematic",
    "Photographic",
    "Anime",
    "Manga",
    "Digital Art",
    "Pixel art",
    "Fantasy art",
    "Neonpunk",
    "3D Model",
]
|
|
|
|
|
# Style templates as (positive_template, negative_terms) pairs. "{prompt}" in the
# positive template is replaced with the user's prompt. Built once at import time
# rather than on every call.
_STYLE_TEMPLATES = {
    "(No style)": ("{prompt}", ""),
    "Cinematic": (
        "cinematic still {prompt} . emotional, harmonious, vignette, highly detailed, high budget, bokeh, cinemascope, moody, epic, gorgeous, film grain, grainy",
        "anime, cartoon, graphic, text, painting, crayon, graphite, abstract, glitch, deformed, mutated, ugly, disfigured",
    ),
    "Photographic": (
        "cinematic photo {prompt} . 35mm photograph, film, bokeh, professional, 4k, highly detailed",
        "drawing, painting, crayon, sketch, graphite, impressionist, noisy, blurry, soft, deformed, ugly",
    ),
    "Anime": (
        "anime artwork {prompt} . anime style, key visual, vibrant, studio anime, highly detailed",
        "photo, deformed, black and white, realism, disfigured, low contrast",
    ),
    "Manga": (
        "manga style {prompt} . vibrant, high-energy, detailed, iconic, Japanese comic style",
        "ugly, deformed, noisy, blurry, low contrast, realism, photorealistic, Western comic style",
    ),
    "Digital Art": (
        "concept art {prompt} . digital artwork, illustrative, painterly, matte painting, highly detailed",
        "photo, photorealistic, realism, ugly",
    ),
    "Pixel art": (
        "pixel-art {prompt} . low-res, blocky, pixel art style, 8-bit graphics",
        "sloppy, messy, blurry, noisy, highly detailed, ultra textured, photo, realistic",
    ),
    "Fantasy art": (
        "ethereal fantasy concept art of {prompt} . magnificent, celestial, ethereal, painterly, epic, majestic, magical, fantasy art, cover art, dreamy",
        "photographic, realistic, realism, 35mm film, dslr, cropped, frame, text, deformed, glitch, noise, noisy, off-center, deformed, cross-eyed, closed eyes, bad anatomy, ugly, disfigured, sloppy, duplicate, mutated, black and white",
    ),
    "Neonpunk": (
        "neonpunk style {prompt} . cyberpunk, vaporwave, neon, vibes, vibrant, stunningly beautiful, crisp, detailed, sleek, ultramodern, magenta highlights, dark purple shadows, high contrast, cinematic, ultra detailed, intricate, professional",
        "painting, drawing, illustration, glitch, deformed, mutated, cross-eyed, ugly, disfigured",
    ),
    "3D Model": (
        "professional 3d model {prompt} . octane render, highly detailed, volumetric, dramatic lighting",
        "ugly, deformed, noisy, low poly, blurry, painting",
    ),
}


def apply_style(style_name, prompt, negative):
    """Wrap *prompt* in the template for *style_name* and merge the negative prompts.

    Args:
        style_name: Key of ``_STYLE_TEMPLATES``. Unknown names fall back to
            "(No style)", which leaves the prompt unchanged.
        prompt: User prompt substituted for "{prompt}" in the positive template.
        negative: Optional user negative prompt. May be None or empty.

    Returns:
        A ``(positive_prompt, negative_prompt)`` tuple. The style's negative terms
        and the user's negative prompt are joined with ", ", so the last style term
        and the first user term do not run together. Empty parts are omitted.
    """
    positive_template, style_negative = _STYLE_TEMPLATES.get(
        style_name, _STYLE_TEMPLATES["(No style)"]
    )
    user_negative = (negative or "").strip()
    merged_negative = ", ".join(part for part in (style_negative, user_negative) if part)
    return positive_template.replace("{prompt}", prompt), merged_negative
|
|
|
|
|
def _decode_images(items):
    """Decode ``data:image`` base64 URLs in *items* into PIL images, skipping bad entries.

    Entries that are not ``data:image`` strings, or that fail to decode, produce a
    Streamlit warning and are skipped. One bad entry does not discard the rest.
    """
    images = []
    for item in items:
        if not (isinstance(item, str) and item.startswith("data:image")):
            st.warning("Unexpected image data format.")
            continue
        # A data URL looks like "data:image/<fmt>;base64,<payload>"; keep the payload.
        _header, _sep, encoded = item.partition(",")
        try:
            img = Image.open(BytesIO(base64.b64decode(encoded)))
            # Image.open is lazy; force a full decode so corrupt data fails here.
            img.load()
        except (ValueError, OSError) as e:  # binascii.Error / PIL decode errors
            st.warning(f"Could not decode an image: {e}")
            continue
        images.append(img)
    return images


def generate_images(prompt, negative_prompt, guidance_scale, style_name):
    """Request images for *prompt* from the SDXL backend and decode them.

    Args:
        prompt: User text prompt. It is styled with ``apply_style`` before sending.
        negative_prompt: Optional user negative prompt.
        guidance_scale: Guidance scale value forwarded to the backend.
        style_name: Name of the style template. It is also forwarded to the backend.

    Returns:
        A non-empty list of PIL images, or None on any failure. Failures are reported
        to the user via ``st.error`` instead of being raised.
    """
    try:
        styled_prompt, styled_negative = apply_style(style_name, prompt, negative_prompt)

        # NOTE(review): the prompt is styled locally, and style_name is also sent.
        # If the backend applies the style template itself, it is applied twice.
        # Confirm against the backend's API.
        payload = {"data": [styled_prompt, styled_negative, guidance_scale, style_name]}

        start_time = time.perf_counter()
        response = requests.post(BACKEND_URL, json=payload, timeout=60)
        elapsed_time = time.perf_counter() - start_time

        if response.status_code != 200:
            st.error(f"Error: Received status code {response.status_code}")
            return None

        try:
            json_data = response.json()
        except ValueError as e:
            # Handled here rather than by the outer handlers. In requests >= 2.27 the
            # JSON decode error also subclasses RequestException, so the outer
            # handler would mislabel it as a network error.
            st.error(f"Error decoding response: {str(e)}")
            return None

        data = json_data.get("data") if isinstance(json_data, dict) else None
        if not isinstance(data, list) or not data:
            st.error("Error: No images returned from the backend.")
            return None

        first_output = data[0]
        # Treat a lone value (e.g. a single data-URL string) as a one-item list.
        # Otherwise a string would be iterated character by character.
        items = first_output if isinstance(first_output, (list, tuple)) else [first_output]

        images = _decode_images(items)
        if not images:
            st.error("Error: None of the returned items could be decoded as images.")
            return None

        st.success(f"Images generated in {elapsed_time:.2f} seconds!")
        return images

    except requests.exceptions.RequestException as e:
        st.error(f"Network error: {str(e)}")
        return None

    except ValueError as e:
        st.error(f"Error decoding response: {str(e)}")
        return None

    except Exception as e:
        # Top-level UI boundary: surface any remaining failure to the user.
        st.error(f"Unexpected error: {str(e)}")
        return None
|
|
|
|
|
# Input form: widget values are only submitted when the form's submit button is pressed.
with st.form(key="input_form"):
    prompt = st.text_input("Enter your prompt", placeholder="A serious capybara at work, wearing a suit")
    negative_prompt = st.text_input("Negative prompt (optional)", placeholder="low_quality")
    guidance_scale = st.slider("Guidance Scale", min_value=0.0, max_value=50.0, value=7.5, step=0.1)
    style_name = st.selectbox("Image Style", options=STYLE_OPTIONS, index=0)
    submit_button = st.form_submit_button("Generate Images")

if submit_button:
    # Reject whitespace-only prompts as well as empty ones.
    clean_prompt = (prompt or "").strip()
    if not clean_prompt:
        st.error("Please enter a prompt.")
    else:
        with st.spinner("Generating images..."):
            images = generate_images(clean_prompt, negative_prompt, guidance_scale, style_name)

        if images:
            st.subheader("Generated Images")
            # Use at most three columns; images are placed round-robin across them.
            cols = st.columns(min(len(images), 3))
            for idx, img in enumerate(images):
                with cols[idx % len(cols)]:
                    # NOTE(review): newer Streamlit releases deprecate use_column_width
                    # in favour of use_container_width. Confirm against the pinned version.
                    st.image(img, use_column_width=True)