import os
import torch
import gradio as gr
from PIL import Image
from diffusers import StableDiffusionPipeline, DiffusionPipeline
from diffusers import AutoencoderKL, LMSDiscreteScheduler, UNet2DConditionModel
from tqdm.auto import tqdm
import torchvision.transforms as T
import torch.nn.functional as F
import gc
import signal
import time
import traceback
# Configure constants - optimized for CPU
HEIGHT, WIDTH = 384, 384  # Smaller images use less memory
GUIDANCE_SCALE = 7.5
LOSS_SCALE = 200
NUM_INFERENCE_STEPS = 30  # Reduced from 50
BATCH_SIZE = 1
DEFAULT_PROMPT = "A deadly witcher slinging a sword, with a lion medallion around his neck, casting a fire spell from his hand in a snowy forest"
# Define the device
TORCH_DEVICE = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
print(f"Using device: {TORCH_DEVICE}")

# Initialize the elastic transformer
elastic_transformer = T.ElasticTransform(alpha=550.0, sigma=5.0)
# Timeout handler for CPU processing
def timeout_handler(signum, frame):
    raise TimeoutError("Image generation took too long")
# Load the model
def load_model():
    try:
        # Install the SIGALRM handler only on Unix-like systems (Windows has no SIGALRM)
        if TORCH_DEVICE == "cpu" and hasattr(signal, 'SIGALRM'):
            signal.signal(signal.SIGALRM, timeout_handler)
            signal.alarm(2100)  # 35-minute timeout for model loading
        pipe = DiffusionPipeline.from_pretrained(
            "CompVis/stable-diffusion-v1-4",
            torch_dtype=torch.float16 if TORCH_DEVICE == "cuda" else torch.float32,
            safety_checker=None,  # Disable the safety checker to save memory
            low_cpu_mem_usage=True  # Enable memory optimization
        ).to(TORCH_DEVICE)
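        # Note: float16 halves the model's memory footprint but is only used on
        # CUDA here; half-precision inference on CPU is poorly supported, which
        # is why the CPU and MPS paths keep float32.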
        # Load textual inversion concepts for all devices, including CPU
        try:
            # Load one at a time with memory cleanup between each
            concepts = [
                "sd-concepts-library/rimworld-art-style",
                "sd-concepts-library/hk-goldenlantern",
                "sd-concepts-library/phoenix-01",
                "sd-concepts-library/fractal-flame",
                "sd-concepts-library/scarlet-witch"
            ]
            for concept in concepts:
                try:
                    print(f"Loading textual inversion concept: {concept}")
                    pipe.load_textual_inversion(concept, mean_resizing=False)
                    # Clear memory after loading each concept
                    if TORCH_DEVICE == "cpu":
                        gc.collect()
                except Exception as e:
                    print(f"Warning: Could not load textual inversion concept {concept}: {e}")
        except Exception as e:
            print(f"Warning: Could not load textual inversion concepts: {e}")
        # Clear the alarm if set
        if TORCH_DEVICE == "cpu" and hasattr(signal, 'SIGALRM'):
            signal.alarm(0)
        return pipe
    except Exception as e:
        # Clear the alarm if set
        if TORCH_DEVICE == "cpu" and hasattr(signal, 'SIGALRM'):
            signal.alarm(0)
        print(f"Error loading model: {e}")
        traceback.print_exc()
        raise
# Helper functions
def image_grid(imgs, rows, cols):
    assert len(imgs) == rows * cols
    w, h = imgs[0].size
    grid = Image.new('RGB', size=(cols * w, rows * h))
    for i, img in enumerate(imgs):
        grid.paste(img, box=(i % cols * w, i // cols * h))
    return grid
def image_loss(images, loss_type):
    if loss_type == 'blue':
        # Blue loss: push the blue channel towards 0.9
        error = torch.abs(images[:, 2] - 0.9).mean()
    elif loss_type == 'elastic':
        # Elastic loss: penalize difference from an elastically distorted copy
        transformed_imgs = elastic_transformer(images)
        error = torch.abs(transformed_imgs - images).mean()
    elif loss_type == 'symmetry':
        # Symmetry loss: penalize difference from the horizontal mirror image
        flipped_image = torch.flip(images, [3])
        error = F.mse_loss(images, flipped_image)
    elif loss_type == 'saturation':
        # Saturation loss: penalize difference from a heavily saturated copy
        transformed_imgs = T.functional.adjust_saturation(images, saturation_factor=10)
        error = torch.abs(transformed_imgs - images).mean()
    else:
        print("Error: loss not defined")
        error = torch.tensor(0.0)
    return error
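# Quick sanity check (hypothetical shapes, not part of the app flow):
#   imgs = torch.rand(1, 3, HEIGHT, WIDTH)   # batch of decoded images in [0, 1]
#   image_loss(imgs, 'symmetry')             # ~0 for a perfectly mirror-symmetric batch
# torch.flip(images, [3]) flips dim 3 (width) of an NCHW batch, i.e. a left-right mirror.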
def latents_to_pil(latents, pipe):
    # Batch of latents -> list of PIL images
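    # 0.18215 is the Stable Diffusion v1 VAE scaling factor: latents are
    # multiplied by it at encoding time, so it is divided back out before decoding.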
    latents = (1 / 0.18215) * latents
    with torch.no_grad():
        image = pipe.vae.decode(latents).sample
    image = (image / 2 + 0.5).clamp(0, 1)
    image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
    images = (image * 255).round().astype("uint8")
    pil_images = [Image.fromarray(image) for image in images]
    return pil_images
def generate_image(pipe, seed_no, prompts, loss_type, loss_apply=False, progress=gr.Progress()):
    try:
        # Set a timeout for CPU runs (SIGALRM is Unix-only)
        if TORCH_DEVICE == "cpu" and hasattr(signal, 'SIGALRM'):
            signal.signal(signal.SIGALRM, timeout_handler)
            signal.alarm(600)  # 10-minute timeout
        # Initialization and setup
        generator = torch.manual_seed(seed_no)
        scheduler = LMSDiscreteScheduler(
            beta_start=0.00085,
            beta_end=0.012,
            beta_schedule="scaled_linear",
            num_train_timesteps=1000
        )
        scheduler.set_timesteps(NUM_INFERENCE_STEPS)
        scheduler.timesteps = scheduler.timesteps.to(torch.float32)
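        # The float32 cast appears to be a precision workaround: set_timesteps
        # produces float64 timesteps, which some backends (notably MPS) do not
        # support, so they are downcast before being fed to the UNet.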
        # Text processing
        text_input = pipe.tokenizer(
            prompts,
            padding='max_length',
            max_length=pipe.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt"
        )
        input_ids = text_input.input_ids.to(TORCH_DEVICE)
        # Convert text inputs to embeddings
        with torch.no_grad():
            text_embeddings = pipe.text_encoder(input_ids)[0]
        # Build the unconditional (empty-prompt) embeddings at the same length
        max_length = text_input.input_ids.shape[-1]
        uncond_input = pipe.tokenizer(
            [""] * BATCH_SIZE,
            padding="max_length",
            max_length=max_length,
            return_tensors="pt"
        )
        with torch.no_grad():
            uncond_embeddings = pipe.text_encoder(uncond_input.input_ids.to(TORCH_DEVICE))[0]
        # Concatenate unconditional and text embeddings
        text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
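        # Stacking [uncond, cond] lets a single UNet forward pass score both
        # branches of classifier-free guidance; the halves are split back apart
        # with .chunk(2) inside the denoising loop below.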
        # Create random initial latents
        latents = torch.randn(
            (BATCH_SIZE, pipe.unet.config.in_channels, HEIGHT // 8, WIDTH // 8),
            generator=generator,
        )
        # Move latents to device and apply noise scaling
        if TORCH_DEVICE == "cuda":
            latents = latents.to(torch.float16)
        latents = latents.to(TORCH_DEVICE)
        latents = latents * scheduler.init_noise_sigma
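        # The LMS scheduler works in sigma space (x_t = x_0 + sigma_t * noise),
        # so the unit-variance Gaussian latents are scaled by init_noise_sigma
        # (the largest sigma) to match the noise level of the first timestep.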
        # Diffusion process
        timesteps = scheduler.timesteps
        progress(0, desc="Generating")
        # Index-based loop so progress tracking stays separate from the timestep value
        for i in range(len(timesteps)):
            progress((i + 1) / len(timesteps), desc=f"Diffusion step {i+1}/{len(timesteps)}")
            t = timesteps[i]
            # Duplicate latents for the unconditional and conditional branches
            latent_model_input = torch.cat([latents] * 2)
            sigma = scheduler.sigmas[i]
            latent_model_input = scheduler.scale_model_input(latent_model_input, t)
            with torch.no_grad():
                noise_pred = pipe.unet(
                    latent_model_input,
                    t,
                    encoder_hidden_states=text_embeddings
                )["sample"]
            # Apply classifier-free guidance to the noise prediction
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + GUIDANCE_SCALE * (noise_pred_text - noise_pred_uncond)
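            # Guidance formula: eps = eps_uncond + s * (eps_text - eps_uncond).
            # With s = GUIDANCE_SCALE = 7.5, the prediction is pushed away from
            # the unconditional branch and towards the prompt-conditioned one.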
            # Apply the auxiliary loss every 5th step if requested
            if loss_apply and i % 5 == 0 and loss_type != "N/A":
                latents = latents.detach().requires_grad_()
                # Predict the fully denoised latents from the current noise estimate
                latents_x0 = latents - sigma * noise_pred
                # Decode to image space with gradients enabled
                denoised_images = pipe.vae.decode((1 / 0.18215) * latents_x0).sample / 2 + 0.5
                # Score the decoded images with the selected loss
                loss = image_loss(denoised_images, loss_type) * LOSS_SCALE
                print(f"Step {i}, Loss: {loss.item()}")
                # Gradient of the loss with respect to the noisy latents
                cond_grad = torch.autograd.grad(loss, latents)[0]
                latents = latents.detach() - cond_grad * sigma**2
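            # This follows the same loss-guidance pattern as the "blue loss"
            # example in the Hugging Face diffusion course: approximate x_0 as
            # x_t - sigma_t * eps, decode and score it, then nudge the latents
            # downhill along the loss gradient. The sigma^2 factor scales the
            # update to the current noise level (larger early, smaller late).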
            # Update latents using the scheduler
            latents = scheduler.step(noise_pred, t, latents).prev_sample
            # Garbage collect every 5 steps if on CPU
            if TORCH_DEVICE == "cpu" and i % 5 == 0:
                gc.collect()
        # Clear the alarm if set
        if TORCH_DEVICE == "cpu" and hasattr(signal, 'SIGALRM'):
            signal.alarm(0)
        return latents
    except Exception as e:
        print(f"Error in generate_image: {e}")
        traceback.print_exc()
        # Return empty latents as a fallback, matching the pipeline's dtype
        return torch.zeros(
            (BATCH_SIZE, pipe.unet.config.in_channels, HEIGHT // 8, WIDTH // 8),
            device=TORCH_DEVICE,
            dtype=torch.float16 if TORCH_DEVICE == "cuda" else torch.float32
        )
def generate_images(prompt, loss_type, apply_loss, seeds, pipe, progress=gr.Progress()):
    try:
        images_list = []
        # Convert the comma-separated string to a list of ints
        seeds = [int(seed.strip()) for seed in seeds.split(',') if seed.strip()]
        if not seeds:
            seeds = [1000]  # Default seed if none provided
        # Process one seed at a time to save memory
        for i, seed_no in enumerate(seeds):
            progress((i / len(seeds)) * 0.1, desc=f"Starting seed {seed_no}")
            # Clear memory
            if TORCH_DEVICE == "cuda":
                torch.cuda.empty_cache()
            gc.collect()
            try:
                # Generate image
                prompts = [prompt]
                latents = generate_image(pipe, seed_no, prompts, loss_type, loss_apply=apply_loss, progress=progress)
                pil_images = latents_to_pil(latents, pipe)
                images_list.extend(pil_images)
            except Exception as e:
                print(f"Error generating image with seed {seed_no}: {e}")
                # Create a solid red error image (PIL sizes are (width, height))
                error_img = Image.new('RGB', (WIDTH, HEIGHT), color=(255, 0, 0))
                images_list.append(error_img)
            # Force garbage collection
            gc.collect()
        # Create image grid
        if len(images_list) > 1:
            result = image_grid(images_list, 1, len(images_list))
            return result
        else:
            return images_list[0]
    except Exception as e:
        print(f"Error in generate_images: {e}")
        traceback.print_exc()
        # Create an error image
        error_img = Image.new('RGB', (WIDTH, HEIGHT), color=(255, 0, 0))
        return error_img
# Gradio Interface
def create_interface():
    with gr.Blocks(title="Stable Diffusion Text Inversion with Loss Functions") as app:
        gr.Markdown("""
        # Stable Diffusion Text Inversion with Loss Functions
        Generate images using Stable Diffusion with various loss functions to guide the diffusion process.
        """)
        if TORCH_DEVICE == "cpu":
            gr.Markdown("""
            ⚠️ **Running on CPU**: Generation will be slow and memory-intensive.
            Each image may take several minutes to generate.
            """)
        pipe = None  # Deferred: the model is loaded via the button, not at interface creation
        with gr.Row():
            with gr.Column():
                prompt = gr.Textbox(
                    label="Prompt",
                    value=DEFAULT_PROMPT,
                    lines=3
                )
                loss_type = gr.Radio(
                    label="Loss Type",
                    choices=["N/A", "blue", "elastic", "symmetry", "saturation"],
                    value="N/A"
                )
                apply_loss = gr.Checkbox(
                    label="Apply Loss Function",
                    value=False
                )
                if TORCH_DEVICE == "cpu":
                    seeds = gr.Textbox(
                        label="Seeds (comma-separated) - use fewer seeds on CPU",
                        value="1000",
                        lines=1
                    )
                else:
                    seeds = gr.Textbox(
                        label="Seeds (comma-separated)",
                        value="3000,2000,1000",
                        lines=1
                    )
                # Load model button
                load_model_btn = gr.Button("Load Model")
                model_status = gr.Textbox(label="Model Status", value="Model not loaded", interactive=False)
                generate_btn = gr.Button("Generate Images", interactive=False)
            with gr.Column():
                output_image = gr.Image(label="Generated Image")
        def load_model_fn():
            nonlocal pipe
            try:
                pipe = load_model()
                # gr.update(...) is required to toggle the button; returning a
                # bare bool would not change its interactivity
                return "Model loaded successfully", gr.update(interactive=True)
            except Exception as e:
                return f"Error loading model: {str(e)}", gr.update(interactive=False)
        load_model_btn.click(
            fn=load_model_fn,
            inputs=[],
            outputs=[model_status, generate_btn]
        )
        generate_btn.click(
            # The gr.Progress() default lets Gradio inject a progress tracker;
            # it must be a default argument, not a fifth positional input
            fn=lambda p, lt, al, s, progress=gr.Progress(): generate_images(p, lt, al, s, pipe, progress),
            inputs=[prompt, loss_type, apply_loss, seeds],
            outputs=output_image
        )
        gr.Markdown("""
        ## About the Loss Functions
        - **Blue**: Encourages more blue tones in the image
        - **Elastic**: Creates distortion effects by minimizing differences with elastically transformed versions
        - **Symmetry**: Encourages symmetrical images by minimizing differences with horizontally flipped versions
        - **Saturation**: Increases color saturation in the image

        Set "N/A" and uncheck "Apply Loss Function" for normal image generation.
        """)
        if TORCH_DEVICE == "cpu":
            gr.Markdown("""
            ## CPU Mode Tips
            - Use shorter prompts
            - Process one seed at a time
            - Be patient: generation can take 5-10 minutes per image
            - If you encounter memory errors, try restarting the app and using even smaller dimensions
            """)
    return app
if __name__ == "__main__":
    # Create and launch the interface
    app = create_interface()
    app.launch()
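    # On Hugging Face Spaces a bare launch() is sufficient; when running in a
    # local container you could pass app.launch(server_name="0.0.0.0") (a
    # standard Gradio option) so the server is reachable from outside.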