# upscaler / app.py
# Author: aronsaras - "Create app.py" (commit 1d1ede2, verified), 7.15 kB
import functools
import os
import subprocess
import sys
import time

import cv2
import gradio as gr
import numpy as np
import torch
from diffusers import StableDiffusionControlNetImg2ImgPipeline, ControlNetModel, DDIMScheduler
from diffusers.models import AutoencoderKL
from gradio_imageslider import ImageSlider
from huggingface_hub import hf_hub_download
from PIL import Image
# Install Real-ESRGAN and its dependencies into the interpreter running this app.
# ``sys.executable -m pip`` targets this exact environment (a bare ``pip`` on
# PATH may belong to another one), and an argument list avoids spawning a shell.
subprocess.run(
    [
        sys.executable,
        "-m",
        "pip",
        "install",
        "git+https://github.com/inference-sh/Real-ESRGAN.git",
        "basicsr",
        "opencv-python-headless",
    ]
)
from RealESRGAN import RealESRGAN  # noqa: E402 -- importable only after the install above

# Force CPU usage
device = torch.device("cpu")
ENABLE_CPU_OFFLOAD = True  # Request diffusers model CPU offloading when building the pipeline
USE_TORCH_COMPILE = False  # Disable torch.compile for CPU compatibility

# Create the directories the checkpoints are downloaded into.
for _model_dir in (
    "models/Stable-diffusion",
    "models/ControlNet",
    "models/VAE",
    "models/upscalers",
):
    os.makedirs(_model_dir, exist_ok=True)
# Download essential models (reduced set to save storage)
def download_models():
    """Fetch every checkpoint the app needs from the Hugging Face Hub.

    Each file is stored in the ``models/...`` directory where the loaders in
    this module read it from. Progress is printed per checkpoint.
    """
    checkpoints = [
        ("MODEL", "dantea1118/juggernaut_reborn", "juggernaut_reborn.safetensors", "models/Stable-diffusion"),
        ("CONTROLNET", "lllyasviel/ControlNet-v1-1", "control_v11f1e_sd15_tile.pth", "models/ControlNet"),
        ("VAE", "stabilityai/sd-vae-ft-mse-original", "vae-ft-mse-840000-ema-pruned.safetensors", "models/VAE"),
        ("UPSCALER_X2", "ai-forever/Real-ESRGAN", "RealESRGAN_x2.pth", "models/upscalers"),
    ]
    for model, repo, fname, target_dir in checkpoints:
        print(f"Downloading {model}...")
        hf_hub_download(repo_id=repo, filename=fname, local_dir=target_dir)


download_models()
# Timer decorator for performance tracking
def timer_func(func):
    """Wrap ``func`` so every call prints its wall-clock duration.

    The wrapped function's return value (and any exception) passes through
    unchanged; ``functools.wraps`` keeps its ``__name__`` and ``__doc__``.
    """
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # perf_counter is monotonic, unlike time.time, so it is safe for durations.
        start_time = time.perf_counter()
        result = func(*args, **kwargs)
        print(f"{func.__name__} took {time.perf_counter() - start_time:.2f} seconds")
        return result
    return wrapper
# Lazy pipeline for memory efficiency
class LazyLoadPipeline:
    """Tile-ControlNet img2img pipeline that is built on first use.

    Loading the checkpoints is slow and memory-hungry, so nothing is read
    from disk until the instance is first called (or ``load`` is invoked).
    """

    def __init__(self):
        # Populated (and cached) by load().
        self.pipe = None

    @timer_func
    def load(self):
        """Build the pipeline once, cache it on ``self.pipe`` and return it."""
        if self.pipe is not None:
            return self.pipe
        print("Setting up pipeline...")
        # float16 kernels are largely unsupported on the CPU (diffusers warns
        # that fp16 pipelines fail there), so use float32 when on the CPU.
        dtype = torch.float32 if device.type == "cpu" else torch.float16
        controlnet = ControlNetModel.from_single_file(
            "models/ControlNet/control_v11f1e_sd15_tile.pth", torch_dtype=dtype
        )
        model_path = "models/Stable-diffusion/juggernaut_reborn.safetensors"
        pipe = StableDiffusionControlNetImg2ImgPipeline.from_single_file(
            model_path,
            controlnet=controlnet,
            torch_dtype=dtype,
            use_safetensors=True,
        )
        # Swap in the fine-tuned MSE VAE.
        pipe.vae = AutoencoderKL.from_single_file(
            "models/VAE/vae-ft-mse-840000-ema-pruned.safetensors",
            torch_dtype=dtype,
        )
        pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
        # Model CPU offload shuttles sub-models onto an accelerator on demand;
        # it is meaningless (and targets CUDA) when the device is the CPU.
        if ENABLE_CPU_OFFLOAD and device.type != "cpu":
            print("Enabling CPU offloading...")
            pipe.enable_model_cpu_offload()
        else:
            pipe.to(device)
        self.pipe = pipe
        return self.pipe

    def __call__(self, *args, **kwargs):
        """Run the pipeline, loading it first if it has not been built yet."""
        if self.pipe is None:
            self.pipe = self.load()
        return self.pipe(*args, **kwargs)
# Lazy Real-ESRGAN upscaler
class LazyRealESRGAN:
    """Real-ESRGAN upscaler whose weights are loaded on first prediction.

    Args:
        device: device handed to ``RealESRGAN``.
        scale: upscale factor; selects ``models/upscalers/RealESRGAN_x{scale}.pth``.
    """

    def __init__(self, device, scale):
        self.device = device
        self.scale = scale
        self.model = None  # set by load_model() once weights are loaded

    def load_model(self):
        """Create the model and load its local weights, at most once.

        ``self.model`` is only assigned after ``load_weights`` succeeds, so a
        failed load (e.g. missing file) is retried on the next call instead
        of leaving an un-weighted model cached.
        """
        if self.model is not None:
            return
        model = RealESRGAN(self.device, scale=self.scale)
        model.load_weights(f'models/upscalers/RealESRGAN_x{self.scale}.pth', download=False)
        self.model = model

    def predict(self, img):
        """Upscale ``img`` (resize_and_upscale passes a PIL image) and return the result."""
        self.load_model()
        return self.model.predict(img)
# Shared 2x upscaler used by resize_and_upscale; weights load on first use.
lazy_realesrgan_x2 = LazyRealESRGAN(device=device, scale=2)
@timer_func
def resize_and_upscale(input_image, resolution):
    """Resize so the shorter side is about ``resolution`` px, then upscale 2x.

    Both target dimensions are snapped to the nearest multiple of 64 before
    the image (converted to RGB) is passed through the Real-ESRGAN x2 model.

    Args:
        input_image: PIL image.
        resolution: desired length of the shorter side, in pixels.

    Returns:
        The upscaled image returned by ``lazy_realesrgan_x2.predict``.
    """
    rgb = input_image.convert("RGB")
    width, height = rgb.size
    factor = float(resolution) / min(height, width)

    def _snap64(length):
        # Scale, then round to the nearest multiple of 64.
        return int(round(length * factor / 64.0)) * 64

    resized = rgb.resize((_snap64(width), _snap64(height)), resample=Image.LANCZOS)
    return lazy_realesrgan_x2.predict(resized)
@timer_func
def create_hdr_effect(original_image, hdr):
    """Give the image a pseudo-HDR look via Mertens exposure fusion.

    Four brightness variants (two darker, the original, one brighter) are
    fused with OpenCV's MergeMertens and converted back to 8-bit.

    Args:
        original_image: PIL image; assumed 3-channel RGB (the RGB2BGR
            conversion requires it).
        hdr: effect strength; ``0`` returns ``original_image`` unchanged.

    Returns:
        A new RGB PIL image, or the input itself when ``hdr == 0``.
    """
    if hdr == 0:
        return original_image
    bgr = cv2.cvtColor(np.array(original_image), cv2.COLOR_RGB2BGR)
    alphas = (1.0 - 0.9 * hdr, 1.0 - 0.7 * hdr, 1.0, 1.0 + 0.2 * hdr)
    exposures = [cv2.convertScaleAbs(bgr, alpha=alpha) for alpha in alphas]
    fused = cv2.createMergeMertens().process(exposures)
    # MergeMertens yields floats around [0, 1]; rescale and clip to uint8.
    fused_u8 = np.clip(fused * 255, 0, 255).astype('uint8')
    rgb = cv2.cvtColor(fused_u8, cv2.COLOR_BGR2RGB)
    return Image.fromarray(rgb)
# Shared diffusion pipeline; checkpoints are only loaded on its first call.
lazy_pipe = LazyLoadPipeline()
@timer_func
def gradio_process_image(input_image, resolution, num_inference_steps, strength, hdr, guidance_scale):
    """Upscale ``input_image`` and refine it with the tile-ControlNet pipeline.

    Args:
        input_image: PIL image from the Gradio input, or ``None`` when the
            user clicked without uploading anything.
        resolution: shorter-side target passed to ``resize_and_upscale``.
        num_inference_steps: scheduler step count (cast to ``int``).
        strength: img2img strength passed to the pipeline.
        hdr: HDR effect amount passed to ``create_hdr_effect`` (0 disables it).
        guidance_scale: passed to the pipeline as ``guidance_scale``.

    Returns:
        ``[original, enhanced]`` as numpy arrays for the before/after slider.

    Raises:
        gr.Error: if no image was given, or the settings would run zero
            denoising steps.
    """
    if input_image is None:
        # Without this, resize_and_upscale fails with an AttributeError on None.
        raise gr.Error("Please upload an image first.")
    # Slider values may arrive as floats; the scheduler needs an int step count.
    num_inference_steps = int(num_inference_steps)
    # The img2img pipeline denoises for int(steps * strength) steps; with zero
    # steps it cannot produce an image, so report that clearly up front.
    if int(num_inference_steps * strength) < 1:
        raise gr.Error(
            "Strength x Inference Steps must be at least 1; "
            "increase the strength or the number of steps."
        )
    print("Starting image processing...")
    condition_image = resize_and_upscale(input_image, resolution)
    condition_image = create_hdr_effect(condition_image, hdr)
    width, height = condition_image.size
    options = {
        "prompt": "masterpiece, best quality, highres",
        "negative_prompt": "low quality, normal quality, blurry, lowres",
        # The upscaled image is both the img2img init image and the tile
        # ControlNet conditioning image.
        "image": condition_image,
        "control_image": condition_image,
        "width": width,
        "height": height,
        "strength": strength,
        "num_inference_steps": num_inference_steps,
        "guidance_scale": guidance_scale,
        # Fixed seed for repeatable results.
        "generator": torch.Generator(device=device).manual_seed(0),
    }
    print("Running inference...")
    result = lazy_pipe(**options).images[0]
    print("Image processing completed successfully")
    return [np.array(input_image), np.array(result)]
# Gradio interface
# HTML header rendered at the top of the page.
title = """<h1 align="center">Image Upscaler with Tile ControlNet</h1>
<p align="center">CPU-optimized version for Hugging Face Spaces</p>"""
# Layout: input column and before/after slider side by side in one row, with
# the generation settings in an accordion that starts collapsed.
with gr.Blocks() as demo:
    gr.HTML(title)
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type="pil", label="Input Image")
            run_button = gr.Button("Enhance Image")
        with gr.Column():
            # type="numpy" matches the arrays returned by gradio_process_image.
            output_slider = ImageSlider(label="Before / After", type="numpy")
    with gr.Accordion("Advanced Options", open=False):
        # Shorter-side target (px) that gradio_process_image forwards to resize_and_upscale.
        resolution = gr.Slider(minimum=256, maximum=1024, value=512, step=64, label="Resolution")
        num_inference_steps = gr.Slider(minimum=1, maximum=20, value=10, step=1, label="Inference Steps")
        strength = gr.Slider(minimum=0, maximum=1, value=0.4, step=0.01, label="Strength")
        # A value of 0 makes create_hdr_effect return its input unchanged.
        hdr = gr.Slider(minimum=0, maximum=1, value=0, step=0.1, label="HDR Effect")
        guidance_scale = gr.Slider(minimum=0, maximum=10, value=3, step=0.5, label="Guidance Scale")
    # Inputs are passed positionally, so this list must follow the order of
    # gradio_process_image's parameters.
    run_button.click(fn=gradio_process_image,
                     inputs=[input_image, resolution, num_inference_steps, strength, hdr, guidance_scale],
                     outputs=output_slider)
# Launch the app
demo.launch()