# DoodlePix / app.py
# Author: Scaryplasmon96
# Commit c7e8bdd (verified): "Update app.py"
# --- Fix 1: Set Matplotlib backend ---
import matplotlib
matplotlib.use('Agg') # Set backend BEFORE importing pyplot or other conflicting libs
# --- End Fix 1 ---
import gradio as gr
import torch
from diffusers import EulerAncestralDiscreteScheduler
from DoodlePix_pipeline import StableDiffusionInstructPix2PixPipeline
from PIL import Image, ImageOps # Added ImageOps for inversion
import numpy as np
import os
import importlib
import traceback # For detailed error printing
# --- FidelityMLP Class (Ensure this is correct as provided by user) ---
class FidelityMLP(torch.nn.Module):
    """MLP that turns a single scalar feature into an embedding vector.

    The input (last dimension of size 1) goes through three
    Linear -> LayerNorm -> activation stages (1 -> 128 -> 256 -> hidden_size,
    with SiLU, SiLU and Tanh activations) and is then linearly projected to
    ``output_size``.
    """

    def __init__(self, hidden_size, output_size=None):
        """Build the network.

        Args:
            hidden_size: Width of the last stage before the output projection.
            output_size: Width of the projected output; a falsy value
                (``None`` or ``0``) falls back to ``hidden_size``.
        """
        super().__init__()
        self.hidden_size = hidden_size
        self.output_size = output_size or hidden_size

        nn = torch.nn
        stage_specs = (
            (1, 128, nn.SiLU),
            (128, 256, nn.SiLU),
            (256, hidden_size, nn.Tanh),
        )
        # Layers are created strictly in order so the Sequential indices
        # (and therefore the state-dict keys) stay net.0 ... net.8.
        layers = []
        for in_features, out_features, activation in stage_specs:
            layers.append(nn.Linear(in_features, out_features))
            layers.append(nn.LayerNorm(out_features))
            layers.append(activation())
        self.net = nn.Sequential(*layers)
        self.output_proj = nn.Linear(hidden_size, self.output_size)

        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Re-initialise Linear layers: N(0, 0.01) weights and zero biases."""
        if not isinstance(module, torch.nn.Linear):
            return
        module.weight.data.normal_(mean=0.0, std=0.01)
        if module.bias is not None:
            module.bias.data.zero_()

    def forward(self, x, target_dim=None):
        """Compute the embedding for ``x``.

        Args:
            x: Tensor whose last dimension has size 1.
            target_dim: Optional width to pad/truncate the result to.

        Returns:
            Tensor with last dimension ``output_size``, or ``target_dim``
            when that is given and differs from ``output_size``.
        """
        projected = self.output_proj(self.net(x))
        if target_dim is None or target_dim == self.output_size:
            return projected
        return self._adjust_dimension(projected, target_dim)

    def _adjust_dimension(self, embeddings, target_dim):
        """Zero-pad or truncate the last dimension of ``embeddings`` to ``target_dim``."""
        width = embeddings.shape[-1]
        if target_dim < width:
            return embeddings[..., :target_dim]
        if target_dim > width:
            # new_zeros inherits dtype and device from ``embeddings``.
            filler = embeddings.new_zeros((*embeddings.shape[:-1], target_dim - width))
            return torch.cat([embeddings, filler], dim=-1)
        return embeddings

    def save_pretrained(self, save_directory):
        """Write the config and weights into ``save_directory`` (created if missing).

        NOTE: ``config.json`` is written with ``torch.save``, so despite its
        name it is a torch-serialised dict rather than JSON text.
        """
        os.makedirs(save_directory, exist_ok=True)
        torch.save(
            {"hidden_size": self.hidden_size, "output_size": self.output_size},
            os.path.join(save_directory, "config.json"),
        )
        torch.save(self.state_dict(), os.path.join(save_directory, "pytorch_model.bin"))

    @classmethod
    def from_pretrained(cls, pretrained_model_path):
        """Rebuild a model from a directory produced by ``save_pretrained``.

        Raises:
            FileNotFoundError: If ``config.json`` or ``pytorch_model.bin`` is missing.
            TypeError: If the loaded config is not a dict.
        """
        config_file = os.path.join(pretrained_model_path, "config.json")
        model_file = os.path.join(pretrained_model_path, "pytorch_model.bin")
        if not os.path.exists(config_file):
            raise FileNotFoundError(f"Config file not found at {config_file}")
        if not os.path.exists(model_file):
            raise FileNotFoundError(f"Model file not found at {model_file}")

        cpu = torch.device('cpu')
        try:
            config = torch.load(config_file, map_location=cpu)
            if not isinstance(config, dict):
                raise TypeError(f"Expected config dict, got {type(config)}")
        except Exception as e:
            print(f"Error loading config {config_file}: {e}")
            raise

        hidden = config["hidden_size"]
        model = cls(hidden_size=hidden, output_size=config.get("output_size", hidden))

        try:
            weights = torch.load(model_file, map_location=cpu)
            model.load_state_dict(weights)
            print(f"Successfully loaded FidelityMLP state dict from {model_file}")
        except Exception as e:
            print(f"Error loading state dict {model_file}: {e}")
            raise
        return model
# --- Global Variables ---
# Module-wide pipeline instance; stays None until load_pipeline() fills it in.
pipeline = None

# Run on the GPU whenever CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Model repository identifier handed to the from_pretrained / hub download calls.
model_id = "Scaryplasmon96/DoodlePixV1"
# --- Model Loading Function ---
def load_pipeline():
    """Load the DoodlePix pipeline into the module-level ``pipeline`` (once).

    The optional Fidelity MLP is fetched from the ``fidelity_mlp/`` folder of
    the model repository and attached to the pipeline when available; any
    failure on that path is logged and loading continues without the MLP.

    Weights are loaded in float16 on CUDA and in float32 on CPU.

    Returns:
        bool: ``True`` when the pipeline is ready (already loaded, or loaded
        by this call).

    Raises:
        gr.Error: If the scheduler or the pipeline itself cannot be loaded.
    """
    global pipeline
    if pipeline is not None:
        return True

    print(f"Loading model {model_id} onto {device}...")
    # Half precision is only dependable on GPU; on CPU many kernels are
    # missing or extremely slow in float16, so fall back to float32 there.
    weights_dtype = torch.float16 if device == "cuda" else torch.float32
    try:
        hf_cache_dir = os.path.expanduser("~/.cache/huggingface/hub")
        local_model_path = model_id  # Let diffusers find/download

        # Load Fidelity MLP if possible (deliberately best-effort).
        fidelity_mlp_instance = None
        try:
            from huggingface_hub import snapshot_download, hf_hub_download
            # Attempt to download config first to check existence
            hf_hub_download(repo_id=model_id, filename="fidelity_mlp/config.json", cache_dir=hf_cache_dir)
            # If config exists, download the whole subfolder.
            # (local_dir_use_symlinks is deprecated and only ever mattered
            # together with local_dir, so it is no longer passed.)
            fidelity_mlp_path = snapshot_download(
                repo_id=model_id,
                allow_patterns="fidelity_mlp/*",
                cache_dir=hf_cache_dir,
            )
            fidelity_mlp_instance = FidelityMLP.from_pretrained(os.path.join(fidelity_mlp_path, "fidelity_mlp"))
            fidelity_mlp_instance = fidelity_mlp_instance.to(device=device, dtype=weights_dtype)
            print("Fidelity MLP loaded successfully.")
        except Exception as e:
            print(f"Fidelity MLP not found or failed to load for {model_id}: {e}. Proceeding without MLP.")
            fidelity_mlp_instance = None

        scheduler = EulerAncestralDiscreteScheduler.from_pretrained(local_model_path, subfolder="scheduler")
        pipeline = StableDiffusionInstructPix2PixPipeline.from_pretrained(
            local_model_path, torch_dtype=weights_dtype, scheduler=scheduler, safety_checker=None
        ).to(device)

        if fidelity_mlp_instance:
            pipeline.fidelity_mlp = fidelity_mlp_instance
            print("Attached Fidelity MLP to pipeline.")

        # Optimizations
        if device == "cuda" and hasattr(pipeline, "enable_xformers_memory_efficient_attention"):
            try:
                pipeline.enable_xformers_memory_efficient_attention()
                print("Enabled xformers.")
            except Exception:
                # xformers missing or incompatible: use the built-in fallback.
                print("Could not enable xformers. Using attention slicing.")
                pipeline.enable_attention_slicing()
        else:
            pipeline.enable_attention_slicing()
            print("Enabled attention slicing.")

        print("Pipeline loaded successfully.")
        return True
    except Exception as e:
        print(f"Error loading pipeline: {e}")
        traceback.print_exc()
        # Reset so a later call retries from scratch.
        pipeline = None
        raise gr.Error(f"Failed to load model: {e}") from e
# --- Image Generation Function (Corrected Input Handling) ---
def _flatten_on_white(image):
    """Return ``image`` as RGB with any transparency composited over white."""
    rgba = image.convert("RGBA")
    canvas = Image.new("RGBA", rgba.size, (255, 255, 255, 255))
    return Image.alpha_composite(canvas, rgba).convert("RGB")


def generate_image(drawing_input, prompt, fidelity_slider, steps, guidance, image_guidance, seed_val):
    """Turn a sketch plus a text prompt into a generated image.

    Args:
        drawing_input: Sketchpad payload; expected to be a dict holding a PIL
            image under the ``"composite"`` key.
        prompt: User text prompt.
        fidelity_slider: Fidelity level, prepended to the prompt as ``f<N>``.
        steps: Number of inference steps.
        guidance: Text guidance scale (CFG).
        image_guidance: Image guidance scale.
        seed_val: Seed for the random generator.

    Returns:
        tuple: ``(image, status)`` where ``image`` is the generated PIL image,
        or ``None`` on failure, and ``status`` is a human-readable message.
    """
    global pipeline
    if pipeline is None:
        if not load_pipeline():
            return None, "Model not loaded. Check logs."

    # --- Input processing ---
    print(f"DEBUG: Received drawing_input type: {type(drawing_input)}")
    if isinstance(drawing_input, dict):
        print(f"DEBUG: Received drawing_input keys: {drawing_input.keys()}")

    # Check if input is dict and get PIL image from 'composite' key
    if isinstance(drawing_input, dict) and "composite" in drawing_input and isinstance(drawing_input["composite"], Image.Image):
        # A plain convert("RGB") just drops alpha, so a transparent canvas
        # would expose its hidden RGB (typically black) as the background and
        # swallow black strokes. Flatten onto white first; fully opaque input
        # comes out unchanged.
        input_image_pil = _flatten_on_white(drawing_input["composite"])
        print("DEBUG: Using PIL Image from 'composite' key.")
    else:
        err_msg = "Drawing input format unexpected. Expected dict with PIL Image under 'composite' key."
        print(f"ERROR: {err_msg} Input: {drawing_input}")
        return None, err_msg
    # --- End input processing ---

    try:
        # Invert the image: White bg -> Black bg, Black lines -> White lines
        input_image_inverted = ImageOps.invert(input_image_pil)

        # Ensure image is 512x512
        if input_image_inverted.size != (512, 512):
            print(f"Resizing input image from {input_image_inverted.size} to (512, 512)")
            input_image_inverted = input_image_inverted.resize((512, 512), Image.Resampling.LANCZOS)

        # Prompt Construction
        final_prompt = f"f{int(fidelity_slider)}, {prompt}"
        if not final_prompt.endswith("background."):
            final_prompt += " background."
        negative_prompt = "artifacts, blur, jpg, uncanny, deformed, glow, shadow, text, words, letters, signature, watermark"

        # Generation
        print(f"Generating with: Prompt='{final_prompt[:100]}...', Fidelity={int(fidelity_slider)}, Steps={steps}, Guidance={guidance}, ImageGuidance={image_guidance}, Seed={seed_val}")
        seed_val = int(seed_val)
        generator = torch.Generator(device=device).manual_seed(seed_val)
        with torch.no_grad():
            output = pipeline(
                prompt=final_prompt, negative_prompt=negative_prompt, image=input_image_inverted,
                num_inference_steps=int(steps), guidance_scale=float(guidance),
                image_guidance_scale=float(image_guidance), generator=generator,
            ).images[0]
        print("Generation complete.")
        return output, "Generation Complete"
    except Exception as e:
        # UI boundary: report the failure in the status box instead of crashing.
        print(f"Error during generation: {e}")
        traceback.print_exc()
        return None, f"Error during generation: {str(e)}"
# --- Gradio Interface ---
def _current_status():
    """Return a status line reflecting whether the pipeline is loaded."""
    if pipeline is not None:
        return "Model loaded. Ready to generate."
    return "Model not loaded yet. It will be loaded on first generation."


with gr.Blocks(theme=gr.themes.Soft(primary_hue="orange", secondary_hue="blue")) as demo:
    gr.Markdown("# DoodlePix Gradio App")
    gr.Markdown(f"Using model: `{model_id}`.")
    # Placeholder text; replaced by the page-load listener at the bottom.
    status_output = gr.Textbox(label="Status", interactive=False, value="App loading...")
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("## 1. Draw Something (Black on White)")
            # Keep type="pil" as it provides the composite key
            drawing = gr.Sketchpad(
                label="Drawing Canvas",
                type="pil",  # type="pil" gives dict output with 'composite' key
                height=512, width=512,
                brush=gr.Brush(colors=["#000000"], color_mode="fixed", default_size=5),
                show_label=True
            )
            prompt_input = gr.Textbox(label="2. Enter Prompt", placeholder="Describe the image you want...")
            fidelity = gr.Slider(0, 9, step=1, value=4, label="Fidelity (0=Creative, 9=Faithful)")
            num_steps = gr.Slider(10, 50, step=1, value=25, label="Inference Steps")
            guidance_scale = gr.Slider(1.0, 15.0, step=0.5, value=7.5, label="Guidance Scale (CFG)")
            image_guidance_scale = gr.Slider(0.5, 5.0, step=0.1, value=1.5, label="Image Guidance Scale")
            seed = gr.Number(label="Seed", value=42, precision=0)
            generate_button = gr.Button("🚀 Generate Image!", variant="primary")
        with gr.Column(scale=1):
            gr.Markdown("## 3. Generated Image")
            output_image = gr.Image(label="Result", type="pil", height=512, width=512, show_label=True)
    generate_button.click(
        fn=generate_image,
        inputs=[drawing, prompt_input, fidelity, num_steps, guidance_scale, image_guidance_scale, seed],
        outputs=[output_image, status_output]
    )
    # Without this the status box would keep showing "App loading..." until
    # the first generation; refresh it from the real pipeline state whenever
    # the page is opened.
    demo.load(fn=_current_status, inputs=None, outputs=status_output)
# --- Launch App ---
if __name__ == "__main__":
    # Warm the model up before serving so the first request does not pay for
    # the download/load; a failure here is only reported, since
    # generate_image() tries to load again on demand.
    print("Attempting to pre-load pipeline...")
    try:
        preloaded = load_pipeline()
    except Exception as e:
        print(f"Pre-loading failed: {e}")
        initial_status = f"Model pre-loading failed: {e}. Will retry on first generation."
    else:
        if preloaded:
            initial_status = "Model pre-loaded successfully."
        else:
            initial_status = "Model pre-loading failed. Will retry on first generation."
    print(f"Pre-loading status: {initial_status}")
    demo.launch()