# Author: pradeep6kumar2024
# Commit: be16ca0 ("updated")
import gradio as gr
import torch
from diffusers import StableDiffusionPipeline
from huggingface_hub import hf_hub_download
import numpy as np
from PIL import Image
import os
import gc
# Ask huggingface_hub not to print its symlink warning when populating its cache.
os.environ['HF_HUB_DISABLE_SYMLINKS_WARNING'] = "1"

# Style catalogue, one row per style: (name, concept repo id, seed, prompt token).
# load_pipeline downloads each repo's textual-inversion embedding, and
# generate_image appends the token to the user's prompt to invoke the style.
_STYLE_ROWS = (
    ("glitch", "sd-concepts-library/001glitch-core", 42, "<glitch-core>"),
    ("roth", "sd-concepts-library/2814-roth", 123, "<2814-roth>"),
    ("night", "sd-concepts-library/4tnght", 456, "<4tnght>"),
    ("anime80s", "sd-concepts-library/80s-anime-ai", 789, "<80s-anime>"),
    ("animeai", "sd-concepts-library/80s-anime-ai-being", 1024, "<80s-anime-being>"),
)

styles = {
    name: {"concept_url": repo_id, "seed": style_seed, "token": prompt_token}
    for name, repo_id, style_seed, prompt_token in _STYLE_ROWS
}

# On-disk locations of the pre-rendered example images, keyed by style name.
example_images = {
    "glitch": "examples/glitch_example.jpg",
    "anime80s": "examples/anime80s_example.jpg",
    "night": "examples/night_example.jpg",
}
def load_pipeline():
    """Load the Stable Diffusion pipeline and register every style's embedding.

    On CUDA the pipeline is loaded in float16; otherwise it runs on CPU in
    float32. For each entry in ``styles`` the concept repo's
    ``learned_embeds.bin`` is downloaded and loaded as a textual-inversion
    embedding under that entry's ``token``.

    Returns:
        The ready-to-use ``StableDiffusionPipeline``, already moved to the device.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Half precision is used only on GPU; the CPU path stays in float32.
    dtype = torch.float16 if device == "cuda" else torch.float32
    # NOTE(review): SD v1-4 and v1-5 share the same architecture and size, so the
    # CPU fallback is not a smaller model -- confirm this model choice is intended.
    model_id = "runwayml/stable-diffusion-v1-5" if device == "cuda" else "CompVis/stable-diffusion-v1-4"

    pipe = StableDiffusionPipeline.from_pretrained(
        model_id,
        torch_dtype=dtype,
        low_cpu_mem_usage=True,
    ).to(device)

    for style_info in styles.values():
        embedding_path = hf_hub_download(
            repo_id=style_info["concept_url"],
            filename="learned_embeds.bin",
            repo_type="model",
        )
        # Register the embedding under the exact token that generate_image
        # appends to prompts. Without this, the token stored inside the .bin file
        # is used, and it may not match styles[...]["token"]. In that case the
        # style would silently have no effect, or loading fails if two files share a token.
        pipe.load_textual_inversion(embedding_path, token=style_info["token"])
    return pipe
def apply_purple_guidance(image, strength=0.5, threshold=100, target_color=(128, 0, 128)):
    """Blend pixels with strong red and blue channels toward a purple target color.

    A pixel is selected when both its red and blue channels exceed ``threshold``.
    This also selects near-white pixels. Each selected pixel becomes
    ``pixel * (1 - strength) + target_color * strength``.

    Args:
        image: A PIL image of any mode (non-RGB images are converted to RGB
            first) or an array-like of shape (H, W, 3).
        strength: Blend weight toward ``target_color``; the UI slider supplies 0.0-1.0.
        threshold: Value that both the red and blue channels must exceed.
        target_color: RGB color that selected pixels are pulled toward.

    Returns:
        A new PIL image. Values are clipped to 0-255 and truncated to uint8.
    """
    # Modes such as L, P, RGBA or CMYK would break the channel indexing or the
    # 3-element color blend below, so normalize PIL input to plain RGB first.
    if isinstance(image, Image.Image) and image.mode != "RGB":
        image = image.convert("RGB")
    img_array = np.array(image, dtype=float)
    mask = (img_array[:, :, 0] > threshold) & (img_array[:, :, 2] > threshold)
    color = np.asarray(target_color, dtype=float)
    img_array[mask] = img_array[mask] * (1 - strength) + color * strength
    return Image.fromarray(np.uint8(img_array.clip(0, 255)))
# Requests that have a pre-rendered image on disk: (prompt, style, seed) -> key
# into example_images. Lookup uses ==/hash, so a float seed such as 42.0 also matches.
_CACHED_EXAMPLE_KEYS = {
    ("A serene mountain landscape with a lake at sunset", "glitch", 42): "glitch",
    ("A magical forest at twilight", "anime80s", 789): "anime80s",
    ("A cyberpunk city at night", "night", 456): "night",
}


def _load_cached_example(prompt, style, seed):
    """Return the pre-rendered image for this request, or None if none exists on disk."""
    key = _CACHED_EXAMPLE_KEYS.get((prompt, style, seed))
    if key is None or not os.path.exists(example_images[key]):
        return None
    # Image.open is lazy and keeps the file open. Copy the pixels inside a context
    # manager so the handle is released now, not when the image is collected.
    with Image.open(example_images[key]) as img:
        return img.copy()


def _run_pipeline(prompt, style_info, seed):
    """Run the global diffusion pipeline once for a styled prompt and return the first image."""
    device = "cuda" if torch.cuda.is_available() else "cpu"
    generator = torch.Generator(device).manual_seed(int(seed))
    styled_prompt = f"{prompt} {style_info['token']}"
    if device == "cpu":
        # Much smaller image size and fewer steps keep CPU generation tolerable.
        run_kwargs = {"num_inference_steps": 10, "height": 256, "width": 256}
    else:
        run_kwargs = {"num_inference_steps": 20}
    try:
        return pipe(
            styled_prompt,
            generator=generator,
            guidance_scale=7.5,
            **run_kwargs,
        ).images[0]
    finally:
        # Free memory even when generation fails (e.g. out of GPU memory).
        gc.collect()
        if device == "cuda":
            torch.cuda.empty_cache()


def generate_image(prompt, style, seed, apply_guidance, guidance_strength=0.5):
    """Generate an image with the selected style and optional purple guidance.

    Requests that match one of the bundled examples are served from the
    pre-rendered file when it exists. Otherwise the diffusion pipeline is run.

    Args:
        prompt: Free-text prompt; the style's token is appended to it.
        style: Key into ``styles``.
        seed: RNG seed, converted with ``int``. If None, the style's configured
            ``seed`` from ``styles`` is used instead.
        apply_guidance: Whether to post-process with ``apply_purple_guidance``.
        guidance_strength: Strength forwarded to ``apply_purple_guidance``.

    Returns:
        The resulting image, or None when ``style`` is not a known style.
    """
    if seed is None and style in styles:
        seed = styles[style]["seed"]

    image = _load_cached_example(prompt, style, seed)
    if image is None:
        if style not in styles:
            return None
        image = _run_pipeline(prompt, styles[style], seed)

    if apply_guidance:
        image = apply_purple_guidance(image, guidance_strength)
    return image
# Directory that holds the pre-rendered example images.
os.makedirs("examples", exist_ok=True)

# Build the diffusion pipeline once at import time; generate_image calls this global.
print("Loading pipeline and embeddings...")
pipe = load_pipeline()

# ---- Gradio UI ---------------------------------------------------------------
_DESCRIPTION = """Generate images in different styles with optional purple color guidance.
Choose a style, enter a prompt, and optionally apply purple color enhancement.
Note: Generation may take a few minutes on CPU."""

# Rows follow the input order: prompt, style, seed, apply guidance, strength.
_EXAMPLE_ROWS = [
    ["A serene mountain landscape with a lake at sunset", "glitch", 42, True, 0.5],
    ["A magical forest at twilight", "anime80s", 789, True, 0.7],
    ["A cyberpunk city at night", "night", 456, False, 0.5],
]

# One component per generate_image parameter, in signature order.
_INPUT_COMPONENTS = [
    gr.Textbox(label="Prompt", value="A serene mountain landscape with a lake at sunset"),
    gr.Radio(choices=list(styles), label="Style", value="glitch"),
    gr.Number(label="Seed", value=42),
    gr.Checkbox(label="Apply Purple Guidance", value=False),
    gr.Slider(minimum=0.0, maximum=1.0, value=0.5, label="Purple Guidance Strength"),
]

demo = gr.Interface(
    fn=generate_image,
    inputs=_INPUT_COMPONENTS,
    outputs=gr.Image(label="Generated Image"),
    title="Style-Guided Image Generation with Purple Enhancement",
    description=_DESCRIPTION,
    examples=_EXAMPLE_ROWS,
    cache_examples=True,  # ask Gradio to cache example outputs
    allow_flagging="never",  # flagging disabled to reduce overhead
)
if __name__ == "__main__":
    # Render any missing example image once. Purple guidance is off here because
    # generate_image applies it when an example is served from disk.
    # Columns: (style / example_images key, prompt, seed, guidance strength).
    _example_jobs = (
        ("glitch", "A serene mountain landscape with a lake at sunset", 42, 0.5),
        ("anime80s", "A magical forest at twilight", 789, 0.7),
        ("night", "A cyberpunk city at night", 456, 0.5),
    )
    if any(not os.path.exists(path) for path in example_images.values()):
        print("Pre-generating example images...")
        for _style, _prompt, _seed, _strength in _example_jobs:
            _target = example_images[_style]
            if os.path.exists(_target):
                continue
            generate_image(_prompt, _style, _seed, False, _strength).save(_target)

    demo.launch(share=False, show_error=True)