# Hugging Face Spaces page header captured with this file — Space status: Runtime error
import gradio as gr | |
import torch | |
from diffusers import StableDiffusionPipeline | |
from huggingface_hub import hf_hub_download | |
import numpy as np | |
from PIL import Image | |
import os | |
import gc | |
# Silence huggingface_hub's warning about symlinks not being supported in the
# cache directory.
# NOTE(review): huggingface_hub (and diffusers, which uses it) is imported
# above; if the library reads this variable when it is first imported, setting
# it here may have no effect — consider moving this assignment above the imports.
os.environ.update(HF_HUB_DISABLE_SYMLINKS_WARNING="1")
# Style presets, keyed by the name shown in the UI. Each preset holds:
#   concept_url - Hugging Face repo id of the textual-inversion embedding
#   seed        - a seed value associated with the style
#   token       - placeholder token appended to the user's prompt
styles = dict(
    glitch=dict(
        concept_url="sd-concepts-library/001glitch-core",
        seed=42,
        token="<glitch-core>",
    ),
    roth=dict(
        concept_url="sd-concepts-library/2814-roth",
        seed=123,
        token="<2814-roth>",
    ),
    night=dict(
        concept_url="sd-concepts-library/4tnght",
        seed=456,
        token="<4tnght>",
    ),
    anime80s=dict(
        concept_url="sd-concepts-library/80s-anime-ai",
        seed=789,
        token="<80s-anime>",
    ),
    animeai=dict(
        concept_url="sd-concepts-library/80s-anime-ai-being",
        seed=1024,
        token="<80s-anime-being>",
    ),
)
# On-disk locations (relative to the working directory) of the example images
# that generate_image serves directly instead of re-running the pipeline.
example_images = dict(
    glitch="examples/glitch_example.jpg",
    anime80s="examples/anime80s_example.jpg",
    night="examples/night_example.jpg",
)
def load_pipeline():
    """Load Stable Diffusion and register every style's textual-inversion embedding.

    Uses CUDA with float16 weights when a GPU is available, otherwise the CPU
    with float32 weights. For each entry in ``styles`` the
    ``learned_embeds.bin`` file is downloaded from its ``concept_url`` repo and
    registered under that entry's ``token``.

    Returns:
        The ready-to-use ``StableDiffusionPipeline``, already moved to the
        selected device.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Half precision only on GPU; the CPU path keeps full float32 weights.
    dtype = torch.float16 if device == "cuda" else torch.float32
    # Both checkpoints are SD 1.x models of the same size; the CPU branch just
    # uses a different checkpoint, not a smaller one.
    model_id = (
        "runwayml/stable-diffusion-v1-5" if device == "cuda"
        else "CompVis/stable-diffusion-v1-4"
    )
    pipe = StableDiffusionPipeline.from_pretrained(
        model_id,
        torch_dtype=dtype,
        low_cpu_mem_usage=True,
    ).to(device)

    for style_info in styles.values():
        embedding_path = hf_hub_download(
            repo_id=style_info["concept_url"],
            filename="learned_embeds.bin",
            repo_type="model",
        )
        # Register the embedding under the exact token generate_image appends
        # to prompts. Without ``token=`` the token stored inside the file is
        # used; it need not match styles[...]["token"] (so the style would be
        # silently ignored) and two concepts could collide on the same token.
        pipe.load_textual_inversion(embedding_path, token=style_info["token"])
    return pipe
def apply_purple_guidance(image, strength=0.5, target_color=(128, 0, 128), threshold=100):
    """Blend strongly red-and-blue pixels of ``image`` toward ``target_color``.

    This is a post-processing colour filter applied to a finished image. A
    pixel is selected when both its red and its blue channel exceed
    ``threshold``. Each selected pixel is linearly interpolated toward
    ``target_color`` by ``strength``, and all other pixels are left unchanged.

    Args:
        image: A PIL image in any mode (it is converted to RGB first, so
            grayscale, RGBA and palette images work), or an (H, W, 3)
            array-like.
        strength: Blend weight; 0 leaves selected pixels unchanged, 1 replaces
            them with ``target_color``.
        target_color: RGB triple that selected pixels are pulled toward
            (default purple, ``(128, 0, 128)``).
        threshold: Red and blue must both be strictly greater than this for a
            pixel to be selected.

    Returns:
        A new ``PIL.Image.Image`` with channel values clipped to [0, 255].
    """
    if isinstance(image, Image.Image):
        # Grayscale has no channel axis and RGBA has four channels; both broke
        # the channel indexing / 3-channel blend below.
        image = image.convert("RGB")
    pixels = np.array(image, dtype=float)
    selected = (pixels[:, :, 0] > threshold) & (pixels[:, :, 2] > threshold)
    target = np.asarray(target_color, dtype=float)
    pixels[selected] = pixels[selected] * (1 - strength) + target * strength
    # Clip before the uint8 cast so out-of-range values saturate, not wrap.
    return Image.fromarray(np.uint8(pixels.clip(0, 255)))
# (prompt, style, seed) of each bundled example -> its key in ``example_images``.
_EXAMPLE_KEYS = {
    ("A serene mountain landscape with a lake at sunset", "glitch", 42): "glitch",
    ("A magical forest at twilight", "anime80s", 789): "anime80s",
    ("A cyberpunk city at night", "night", 456): "night",
}


def _load_example_image(prompt, style, seed):
    """Return the pre-generated image for an exact example match, or None.

    The file is fully read and closed before returning, so no file handle
    outlives the call.
    """
    key = _EXAMPLE_KEYS.get((prompt, style, seed))
    if key is None:
        return None
    path = example_images[key]
    if not os.path.exists(path):
        return None
    with Image.open(path) as img:
        # convert() loads the pixel data and returns an independent copy.
        return img.convert("RGB")


def _run_pipeline(styled_prompt, generator, device):
    """Run the global ``pipe`` once with settings chosen for ``device``."""
    settings = {"guidance_scale": 7.5, "num_inference_steps": 20}
    if device == "cpu":
        # Fewer steps and a smaller canvas to cut CPU generation time.
        settings.update(num_inference_steps=10, height=256, width=256)
    return pipe(styled_prompt, generator=generator, **settings).images[0]


def generate_image(prompt, style, seed, apply_guidance, guidance_strength=0.5):
    """Generate (or load) an image in the selected style, optionally purple-tinted.

    A request that exactly matches one of the bundled examples (same prompt,
    style and seed) is served from its pre-generated file in
    ``example_images`` when that file exists. Every other request is rendered
    with the global ``pipe``.

    Args:
        prompt: Free-text prompt; the style's token is appended to it.
        style: Key into ``styles``.
        seed: Seed for the torch generator (converted with ``int``). ``None``
            falls back to the style's own ``seed`` entry.
        apply_guidance: If true, pass the result through
            ``apply_purple_guidance``.
        guidance_strength: Blend weight forwarded to ``apply_purple_guidance``.

    Returns:
        The image, or ``None`` when ``style`` is not a key of ``styles``.
    """
    if style not in styles:
        return None
    style_info = styles[style]
    if seed is None:
        # No seed supplied: use the style's own seed rather than crash in int().
        seed = style_info["seed"]

    image = _load_example_image(prompt, style, seed)
    if image is None:
        # Same device choice as load_pipeline, so the generator lives where
        # the pipeline runs.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        generator = torch.Generator(device).manual_seed(int(seed))
        styled_prompt = f"{prompt} {style_info['token']}"
        try:
            image = _run_pipeline(styled_prompt, generator, device)
        finally:
            # Release memory even when generation fails.
            gc.collect()
            if device == "cuda":
                torch.cuda.empty_cache()

    if apply_guidance:
        image = apply_purple_guidance(image, guidance_strength)
    return image
# Directory holding the pre-rendered example images.
os.makedirs("examples", exist_ok=True)

# Build the pipeline once at import time; generate_image() reads this global.
print("Loading pipeline and embeddings...")
pipe = load_pipeline()

# Gradio UI. Input components are listed in the same order as
# generate_image's parameters: (prompt, style, seed, apply_guidance,
# guidance_strength).
demo = gr.Interface(
    generate_image,
    [
        gr.Textbox(value="A serene mountain landscape with a lake at sunset", label="Prompt"),
        gr.Radio(label="Style", choices=list(styles), value="glitch"),
        gr.Number(value=42, label="Seed"),
        gr.Checkbox(value=False, label="Apply Purple Guidance"),
        gr.Slider(label="Purple Guidance Strength", minimum=0.0, maximum=1.0, value=0.5),
    ],
    gr.Image(label="Generated Image"),
    title="Style-Guided Image Generation with Purple Enhancement",
    description=(
        "Generate images in different styles with optional purple color guidance.\n"
        "Choose a style, enter a prompt, and optionally apply purple color enhancement.\n"
        "Note: Generation may take a few minutes on CPU."
    ),
    examples=[
        ["A serene mountain landscape with a lake at sunset", "glitch", 42, True, 0.5],
        ["A magical forest at twilight", "anime80s", 789, True, 0.7],
        ["A cyberpunk city at night", "night", 456, False, 0.5],
    ],
    # Gradio precomputes each example's output by calling generate_image.
    cache_examples=True,
    # Flagging disabled to reduce overhead.
    # NOTE(review): newer Gradio releases rename this parameter to
    # ``flagging_mode``; confirm against the installed Gradio version.
    allow_flagging="never",
)
if __name__ == "__main__":
    # (example key / style, prompt, seed, guidance strength) for each example
    # whose base image is rendered once without purple guidance and then
    # served from disk by generate_image.
    _EXAMPLE_SPECS = (
        ("glitch", "A serene mountain landscape with a lake at sunset", 42, 0.5),
        ("anime80s", "A magical forest at twilight", 789, 0.7),
        ("night", "A cyberpunk city at night", 456, 0.5),
    )
    if any(not os.path.exists(path) for path in example_images.values()):
        print("Pre-generating example images...")
        for example_key, example_prompt, example_seed, example_strength in _EXAMPLE_SPECS:
            destination = example_images[example_key]
            if os.path.exists(destination):
                continue
            rendered = generate_image(
                example_prompt, example_key, example_seed, False, example_strength
            )
            rendered.save(destination)

    demo.launch(share=False, show_error=True)