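"""Gradio demo for DoodlePix (Scaryplasmon96/DoodlePixV1): draw a black-on-white
sketch, pick a fidelity level, and an InstructPix2Pix-style pipeline turns the
drawing into an image. An optional FidelityMLP conditions generation on fidelity."""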
# --- Fix 1: Set Matplotlib backend ---
import matplotlib
matplotlib.use('Agg')  # Set backend BEFORE importing pyplot or other conflicting libs
# --- End Fix 1 ---

import gradio as gr
import torch
from diffusers import EulerAncestralDiscreteScheduler
from DoodlePix_pipeline import StableDiffusionInstructPix2PixPipeline
from PIL import Image, ImageOps  # ImageOps is used to invert the sketch
import os
import traceback  # For detailed error printing
# --- FidelityMLP Class (ensure this matches the version used in training) ---
class FidelityMLP(torch.nn.Module):
    def __init__(self, hidden_size, output_size=None):
        super().__init__()
        self.hidden_size = hidden_size
        self.output_size = output_size or hidden_size
        self.net = torch.nn.Sequential(
            torch.nn.Linear(1, 128), torch.nn.LayerNorm(128), torch.nn.SiLU(),
            torch.nn.Linear(128, 256), torch.nn.LayerNorm(256), torch.nn.SiLU(),
            torch.nn.Linear(256, hidden_size), torch.nn.LayerNorm(hidden_size), torch.nn.Tanh()
        )
        self.output_proj = torch.nn.Linear(hidden_size, self.output_size)
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, torch.nn.Linear):
            module.weight.data.normal_(mean=0.0, std=0.01)
            if module.bias is not None:
                module.bias.data.zero_()
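
    # forward() maps a scalar fidelity value (shape (B, 1)) to an embedding;
    # target_dim lets callers pad or truncate the output to match another
    # module's hidden width via _adjust_dimension below.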
    def forward(self, x, target_dim=None):
        features = self.net(x)
        outputs = self.output_proj(features)
        if target_dim is not None and target_dim != self.output_size:
            return self._adjust_dimension(outputs, target_dim)
        return outputs

    def _adjust_dimension(self, embeddings, target_dim):
        current_dim = embeddings.shape[-1]
        if target_dim > current_dim:
            pad_size = target_dim - current_dim
            padding = torch.zeros((*embeddings.shape[:-1], pad_size), device=embeddings.device, dtype=embeddings.dtype)
            return torch.cat([embeddings, padding], dim=-1)
        elif target_dim < current_dim:
            return embeddings[..., :target_dim]
        return embeddings
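
    # NOTE: despite the "config.json" filename, both files below are written
    # with torch.save (pickle), and from_pretrained reads them back with torch.load.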
    def save_pretrained(self, save_directory):
        os.makedirs(save_directory, exist_ok=True)
        config = {"hidden_size": self.hidden_size, "output_size": self.output_size}
        torch.save(config, os.path.join(save_directory, "config.json"))
        torch.save(self.state_dict(), os.path.join(save_directory, "pytorch_model.bin"))
    @classmethod
    def from_pretrained(cls, pretrained_model_path):
        config_file = os.path.join(pretrained_model_path, "config.json")
        model_file = os.path.join(pretrained_model_path, "pytorch_model.bin")
        if not os.path.exists(config_file):
            raise FileNotFoundError(f"Config file not found at {config_file}")
        if not os.path.exists(model_file):
            raise FileNotFoundError(f"Model file not found at {model_file}")
        try:
            config = torch.load(config_file, map_location=torch.device('cpu'))
            if not isinstance(config, dict):
                raise TypeError(f"Expected config dict, got {type(config)}")
        except Exception as e:
            print(f"Error loading config {config_file}: {e}")
            raise
        model = cls(hidden_size=config["hidden_size"], output_size=config.get("output_size", config["hidden_size"]))
        try:
            state_dict = torch.load(model_file, map_location=torch.device('cpu'))
            model.load_state_dict(state_dict)
            print(f"Successfully loaded FidelityMLP state dict from {model_file}")
        except Exception as e:
            print(f"Error loading state dict {model_file}: {e}")
            raise
        return model
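
# Minimal usage sketch (illustrative only; the real MLP is loaded from the hub
# below and consumed inside DoodlePix_pipeline):
#   mlp = FidelityMLP(hidden_size=768)
#   emb = mlp(torch.tensor([[0.4]]))                    # -> shape (1, 768)
#   emb = mlp(torch.tensor([[0.4]]), target_dim=1024)   # zero-padded to (1, 1024)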
# --- Global Variables ---
pipeline = None
device = "cuda" if torch.cuda.is_available() else "cpu"
model_id = "Scaryplasmon96/DoodlePixV1"
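
# The pipeline is loaded lazily: __main__ tries to pre-load it at startup, and
# generate_image() retries on the first request if that failed.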
# --- Model Loading Function ---
def load_pipeline():
    global pipeline
    if pipeline is not None:
        return True
    print(f"Loading model {model_id} onto {device}...")
    try:
        hf_cache_dir = os.path.expanduser("~/.cache/huggingface/hub")
        local_model_path = model_id  # Let diffusers find/download the weights

        # Load the Fidelity MLP if the repo ships one
        fidelity_mlp_instance = None
        try:
            from huggingface_hub import snapshot_download, hf_hub_download
            # Attempt to download the config first to check existence
            hf_hub_download(repo_id=model_id, filename="fidelity_mlp/config.json", cache_dir=hf_cache_dir)
            # If the config exists, download the whole subfolder
            fidelity_mlp_path = snapshot_download(repo_id=model_id, allow_patterns="fidelity_mlp/*", cache_dir=hf_cache_dir)
            fidelity_mlp_instance = FidelityMLP.from_pretrained(os.path.join(fidelity_mlp_path, "fidelity_mlp"))
            fidelity_mlp_instance = fidelity_mlp_instance.to(device=device, dtype=torch.float16)
            print("Fidelity MLP loaded successfully.")
        except Exception as e:
            print(f"Fidelity MLP not found or failed to load for {model_id}: {e}. Proceeding without MLP.")
            fidelity_mlp_instance = None

        scheduler = EulerAncestralDiscreteScheduler.from_pretrained(local_model_path, subfolder="scheduler")
        pipeline = StableDiffusionInstructPix2PixPipeline.from_pretrained(
            local_model_path, torch_dtype=torch.float16, scheduler=scheduler, safety_checker=None
        ).to(device)
        if fidelity_mlp_instance:
            pipeline.fidelity_mlp = fidelity_mlp_instance
            print("Attached Fidelity MLP to pipeline.")

        # Optimizations
        if device == "cuda" and hasattr(pipeline, "enable_xformers_memory_efficient_attention"):
            try:
                pipeline.enable_xformers_memory_efficient_attention()
                print("Enabled xformers.")
            except Exception:
                print("Could not enable xformers. Using attention slicing.")
                pipeline.enable_attention_slicing()
        else:
            pipeline.enable_attention_slicing()
            print("Enabled attention slicing.")
        print("Pipeline loaded successfully.")
        return True
    except Exception as e:
        print(f"Error loading pipeline: {e}")
        traceback.print_exc()
        pipeline = None
        raise gr.Error(f"Failed to load model: {e}")
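
# In recent Gradio versions, the Sketchpad (built on ImageEditor) returns an
# EditorValue dict with "background", "layers", and "composite" entries; with
# type="pil" the "composite" entry is the flattened PIL drawing used below.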
# --- Image Generation Function ---
def generate_image(drawing_input, prompt, fidelity_slider, steps, guidance, image_guidance, seed_val):
    global pipeline
    if pipeline is None:
        if not load_pipeline():
            return None, "Model not loaded. Check logs."

    # --- Input processing ---
    print(f"DEBUG: Received drawing_input type: {type(drawing_input)}")
    if isinstance(drawing_input, dict):
        print(f"DEBUG: Received drawing_input keys: {drawing_input.keys()}")
    # The Sketchpad delivers a dict; the flattened drawing lives under 'composite'
    if isinstance(drawing_input, dict) and isinstance(drawing_input.get("composite"), Image.Image):
        input_image_pil = drawing_input["composite"].convert("RGB")
        print("DEBUG: Using PIL Image from 'composite' key.")
    else:
        err_msg = "Drawing input format unexpected. Expected dict with a PIL Image under the 'composite' key."
        print(f"ERROR: {err_msg} Input: {drawing_input}")
        return None, err_msg
    # --- End input processing ---
    try:
        # Invert the image: white bg -> black bg, black lines -> white lines
        input_image_inverted = ImageOps.invert(input_image_pil)
        # Uncomment to save the inverted image for debugging:
        # input_image_inverted.save("input_image_inverted.png")

        # Ensure the image is 512x512
        if input_image_inverted.size != (512, 512):
            print(f"Resizing input image from {input_image_inverted.size} to (512, 512)")
            input_image_inverted = input_image_inverted.resize((512, 512), Image.Resampling.LANCZOS)

        # Prompt construction: prepend the fidelity token (e.g. "f4, ...") and
        # make sure the prompt ends with "background."
        final_prompt = f"f{int(fidelity_slider)}, {prompt}"
        if not final_prompt.endswith("background."):
            final_prompt += " background."
        negative_prompt = "artifacts, blur, jpg, uncanny, deformed, glow, shadow, text, words, letters, signature, watermark"

        # Generation
        print(f"Generating with: Prompt='{final_prompt[:100]}...', Fidelity={int(fidelity_slider)}, Steps={steps}, Guidance={guidance}, ImageGuidance={image_guidance}, Seed={seed_val}")
        seed_val = int(seed_val)
        generator = torch.Generator(device=device).manual_seed(seed_val)
        with torch.no_grad():
            output = pipeline(
                prompt=final_prompt, negative_prompt=negative_prompt, image=input_image_inverted,
                num_inference_steps=int(steps), guidance_scale=float(guidance),
                image_guidance_scale=float(image_guidance), generator=generator,
            ).images[0]
        print("Generation complete.")
        return output, "Generation Complete"
    except Exception as e:
        print(f"Error during generation: {e}")
        traceback.print_exc()
        return None, f"Error during generation: {str(e)}"
# --- Gradio Interface ---
with gr.Blocks(theme=gr.themes.Soft(primary_hue="orange", secondary_hue="blue")) as demo:
    gr.Markdown("# DoodlePix Gradio App")
    gr.Markdown(f"Using model: `{model_id}`.")
    status_output = gr.Textbox(label="Status", interactive=False, value="App loading...")
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("## 1. Draw Something (Black on White)")
            # Keep type="pil": it yields the dict with the 'composite' key
            drawing = gr.Sketchpad(
                label="Drawing Canvas",
                type="pil",
                height=512, width=512,
                brush=gr.Brush(colors=["#000000"], color_mode="fixed", default_size=5),
                show_label=True
            )
            prompt_input = gr.Textbox(label="2. Enter Prompt", placeholder="Describe the image you want...")
            fidelity = gr.Slider(0, 9, step=1, value=4, label="Fidelity (0=Creative, 9=Faithful)")
            num_steps = gr.Slider(10, 50, step=1, value=25, label="Inference Steps")
            guidance_scale = gr.Slider(1.0, 15.0, step=0.5, value=7.5, label="Guidance Scale (CFG)")
            image_guidance_scale = gr.Slider(0.5, 5.0, step=0.1, value=1.5, label="Image Guidance Scale")
            seed = gr.Number(label="Seed", value=42, precision=0)
            generate_button = gr.Button("🚀 Generate Image!", variant="primary")
        with gr.Column(scale=1):
            gr.Markdown("## 3. Generated Image")
            output_image = gr.Image(label="Result", type="pil", height=512, width=512, show_label=True)

    generate_button.click(
        fn=generate_image,
        inputs=[drawing, prompt_input, fidelity, num_steps, guidance_scale, image_guidance_scale, seed],
        outputs=[output_image, status_output]
    )
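
# The click handler maps the seven inputs positionally onto generate_image's
# parameters and routes its (image, status) return into the two outputs.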
# --- Launch App ---
if __name__ == "__main__":
    initial_status = "App loading..."
    print("Attempting to pre-load pipeline...")
    try:
        if load_pipeline():
            initial_status = "Model pre-loaded successfully."
        else:
            initial_status = "Model pre-loading failed. Will retry on first generation."
    except Exception as e:
        print(f"Pre-loading failed: {e}")
        initial_status = f"Model pre-loading failed: {e}. Will retry on first generation."
    print(f"Pre-loading status: {initial_status}")
    demo.launch()