GFPGAN-Fix / app.py
Gemini899's picture
Update app.py
55f9bde verified
raw
history blame
6.52 kB
# app.py
import os
import sys
# --- Install Dependencies ---
# The extra packages are installed at startup, before they are imported below.
# pip is invoked through the running interpreter (``sys.executable -m pip``) so
# the packages land in the environment this process imports from, and the
# arguments are passed as a list so no shell is involved.
import subprocess

print("Installing required packages: diffusers, gradio_imageslider, huggingface-hub…")
_pip_install = subprocess.run(
    [
        sys.executable,
        "-m",
        "pip",
        "install",
        "--no-input",
        "diffusers",
        "gradio_imageslider",
        "huggingface-hub",
    ],
    # Best effort: a failed install does not abort here; the imports that
    # follow will raise if a package is genuinely missing.
    check=False,
)
if _pip_install.returncode != 0:
    print(f"WARNING: pip install exited with status {_pip_install.returncode}")
# --- Standard Imports ---
import logging
import random
import warnings
import io
import base64
import gradio as gr
import numpy as np
import spaces
import torch
from diffusers import FluxControlNetModel
from diffusers.pipelines import FluxControlNetPipeline
from gradio_imageslider import ImageSlider
from PIL import Image, ImageOps
from huggingface_hub import snapshot_download
# --- Logging & Device Setup ---
# INFO-level logging for the startup messages; every warning is suppressed.
logging.basicConfig(level=logging.INFO)
warnings.filterwarnings("ignore")

# Stylesheet string later passed to gr.Blocks(css=...).
css = """
#col-container {
margin: 0 auto;
max-width: 512px;
}
.gradio-container {
max-width: 900px !important;
margin: auto !important;
}
"""

# Choose the compute target once at import time: CUDA with bfloat16 when a GPU
# is visible to torch, otherwise CPU with float32.
power_device, device, torch_dtype = (
    ("GPU", "cuda", torch.bfloat16)
    if torch.cuda.is_available()
    else ("CPU", "cpu", torch.float32)
)
logging.info(f"Running on device={device} with dtype={torch_dtype}")
# --- Model IDs & Download (no token) ---
flux_model_id = "black-forest-labs/FLUX.1-dev"
controlnet_model_id = "jasperai/Flux.1-dev-Controlnet-Upscaler"
# The local download directory is named after the repo part of the base model id.
local_model_dir = flux_model_id.rsplit("/", 1)[-1]

# Stays None until the pipeline has been fully built and moved to the device.
pipe = None
try:
    # 1) Fetch the base model snapshot into the local directory.
    logging.info(f"Downloading base model: {flux_model_id}")
    model_path = snapshot_download(
        repo_id=flux_model_id,
        repo_type="model",
        local_dir=local_model_dir,
        ignore_patterns=["*.md", "*.gitattributes"],
    )
    logging.info(f"Downloaded base model to: {model_path}")

    # 2) Load the ControlNet weights and place them on the target device.
    logging.info(f"Loading ControlNet: {controlnet_model_id}")
    controlnet = FluxControlNetModel.from_pretrained(
        controlnet_model_id,
        torch_dtype=torch_dtype,
    ).to(device)
    logging.info("ControlNet loaded.")

    # 3) Assemble the pipeline from the downloaded snapshot plus the ControlNet.
    logging.info("Initializing FluxControlNetPipeline…")
    pipe = FluxControlNetPipeline.from_pretrained(
        model_path,
        controlnet=controlnet,
        torch_dtype=torch_dtype,
    ).to(device)
    logging.info("Pipeline ready.")
except Exception as exc:
    # Any failure while downloading or loading is fatal: log the traceback,
    # print a short message, and terminate with a non-zero exit status.
    logging.error(f"Error loading models: {exc}", exc_info=True)
    print(f"FATAL: could not load models: {exc}")
    sys.exit(1)
# --- Constants & Helpers ---
# Largest seed value (upper bound for the random seed and the seed slider).
MAX_SEED = (1 << 32) - 1
# Cap on the pixel count of the internally upscaled image.
MAX_PIXEL_BUDGET = 1280 ** 2
# Factor by which the prepared input is enlarged before it is fed to the pipeline.
INTERNAL_PROCESSING_FACTOR = 4
def process_input(input_image):
    """Normalize an uploaded image so it can be upscaled within the pixel budget.

    The image is EXIF-transposed, converted to RGB, shrunk if its size at
    ``INTERNAL_PROCESSING_FACTOR`` would exceed ``MAX_PIXEL_BUDGET``, and
    finally snapped down to dimensions that are multiples of 8 (never below 8).

    Args:
        input_image: PIL image to prepare, or ``None``.

    Returns:
        A tuple ``(img, w, h, was_resized)`` where ``img`` is the prepared
        image, ``w``/``h`` are the dimensions after EXIF transposition but
        before any resizing, and ``was_resized`` tells whether the budget
        forced a downscale.

    Raises:
        gr.Error: If ``input_image`` is ``None``.
    """
    if input_image is None:
        raise gr.Error("No input image provided!")

    img = ImageOps.exif_transpose(input_image)
    if img.mode != "RGB":
        img = img.convert("RGB")
    w, h = img.size

    # Enforce the intermediate-scale budget: the image will later be enlarged
    # by INTERNAL_PROCESSING_FACTOR on each axis, so check that enlarged area.
    target_px = (w * INTERNAL_PROCESSING_FACTOR) * (h * INTERNAL_PROCESSING_FACTOR)
    was_resized = target_px > MAX_PIXEL_BUDGET
    if was_resized:
        # Largest input area whose enlarged version still fits the budget.
        max_in = MAX_PIXEL_BUDGET / (INTERNAL_PROCESSING_FACTOR ** 2)
        scale = (max_in / (w * h)) ** 0.5
        shrunk = (max(8, int(w * scale)), max(8, int(h * scale)))
        img = img.resize(shrunk, Image.Resampling.LANCZOS)

    # Round dimensions down to multiples of 8, but never below 8: a side
    # shorter than 8 px would otherwise round to 0 and make resize() fail.
    w2, h2 = img.size
    w2 = max(8, w2 - w2 % 8)
    h2 = max(8, h2 - h2 % 8)
    if img.size != (w2, h2):
        img = img.resize((w2, h2), Image.Resampling.LANCZOS)

    return img, w, h, was_resized
@spaces.GPU(duration=75)
def infer(
    seed,
    randomize_seed,
    input_image,
    num_inference_steps,
    final_upscale_factor,
    controlnet_conditioning_scale,
    progress=gr.Progress(track_tqdm=True),
):
    """Upscale ``input_image`` with the Flux ControlNet pipeline.

    Args:
        seed: Seed to use when ``randomize_seed`` is falsy.
        randomize_seed: When truthy, a random seed in ``[0, MAX_SEED]`` replaces
            ``seed``.
        input_image: PIL image to upscale; validated by ``process_input``.
        num_inference_steps: Number of steps, cast to ``int`` for the pipeline.
        final_upscale_factor: Multiplier for the output size, cast to ``int``.
        controlnet_conditioning_scale: Conditioning scale, cast to ``float``.
        progress: Gradio progress object; not referenced in the body
            (presumably lets Gradio mirror tqdm progress — confirm).

    Returns:
        A list ``[[input_image, result], seed, data_url]``: the before/after
        image pair, the seed actually used, and the result encoded as a
        base64 WebP data URL.

    Raises:
        gr.Error: If the pipeline is not loaded or no input image was given.
    """
    if pipe is None:
        raise gr.Error("Pipeline not loaded.")

    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    seed = int(seed)
    factor = int(final_upscale_factor)

    prepared, orig_w, orig_h, was_downscaled = process_input(input_image)
    proc_w, proc_h = prepared.size

    # Build the control image at the internal processing scale.
    ctrl_w = proc_w * INTERNAL_PROCESSING_FACTOR
    ctrl_h = proc_h * INTERNAL_PROCESSING_FACTOR
    control_image = prepared.resize((ctrl_w, ctrl_h), Image.Resampling.LANCZOS)

    generator = torch.Generator(device=device).manual_seed(seed)
    with torch.inference_mode():
        output = pipe(
            prompt="",
            control_image=control_image,
            controlnet_conditioning_scale=float(controlnet_conditioning_scale),
            num_inference_steps=int(num_inference_steps),
            guidance_scale=0.0,
            height=ch if False else ctrl_h,
            width=ctrl_w,
            generator=generator,
        )
        upscaled = output.images[0]

    # Final size: the user factor applied to the processed dimensions when the
    # budget forced a downscale, otherwise to the original dimensions.
    base_w, base_h = (proc_w, proc_h) if was_downscaled else (orig_w, orig_h)
    final_size = (base_w * factor, base_h * factor)
    if upscaled.size != final_size:
        upscaled = upscaled.resize(final_size, Image.Resampling.LANCZOS)

    # Encode the result as WebP and wrap it in a data URL for API consumers.
    with io.BytesIO() as buffer:
        upscaled.save(buffer, format="WEBP", quality=90)
        encoded = base64.b64encode(buffer.getvalue()).decode("utf-8")

    return [[input_image, upscaled], seed, f"data:image/webp;base64,{encoded}"]
# --- Gradio UI ---
# Labels are passed via ``label=`` (and slider bounds via ``minimum=`` /
# ``maximum=``): the first positional parameter of gr.Slider is ``minimum`` and
# that of gr.Checkbox / gr.Textbox / ImageSlider is ``value``, so a positional
# label string would collide with the ``value=`` keyword or be used as the
# component's value instead of its label.
with gr.Blocks(css=css, theme=gr.themes.Soft(), title="Flux Upscaler Demo") as demo:
    gr.Markdown(f"""
# ⚡ Flux.1‑dev Upscaler
**Device:** {power_device} · **Internal scale:** {INTERNAL_PROCESSING_FACTOR}x · **Budget:** {MAX_PIXEL_BUDGET} px
""")
    with gr.Row():
        with gr.Column(scale=2):
            inp = gr.Image(
                label="Input Image",
                type="pil",
                sources=["upload", "clipboard"],
                height=350,
            )
        with gr.Column(scale=1):
            upf = gr.Slider(
                label="Final Upscale Factor",
                minimum=1,
                maximum=INTERNAL_PROCESSING_FACTOR,
                step=1,
                value=2,
            )
            steps = gr.Slider(
                label="Inference Steps",
                minimum=4,
                maximum=50,
                step=1,
                value=15,
            )
            cscale = gr.Slider(
                label="ControlNet Scale",
                minimum=0.0,
                maximum=1.5,
                step=0.05,
                value=0.6,
            )
            with gr.Row():
                sld = gr.Slider(
                    label="Seed",
                    minimum=0,
                    maximum=MAX_SEED,
                    step=1,
                    value=42,
                )
                rnd = gr.Checkbox(label="Randomize", value=True, scale=0, min_width=80)
            btn = gr.Button("⚡ Upscale Image", variant="primary")
    slider = ImageSlider(
        label="Input / Output",
        type="pil",
        interactive=False,
        show_label=True,
        position=0.5,
    )
    out_seed = gr.Textbox(label="Seed Used", interactive=False, visible=True)
    # Hidden component: carries the base64 data URL for API callers only.
    out_b64 = gr.Textbox(label="API Base64 Output", interactive=False, visible=False)

    # api_name="upscale" publishes this handler as the named "upscale" API endpoint.
    btn.click(
        fn=infer,
        inputs=[sld, rnd, inp, steps, upf, cscale],
        outputs=[slider, out_seed, out_b64],
        api_name="upscale",
    )

# Queue at most 10 pending requests, then start the server with the API page shown.
demo.queue(max_size=10).launch(share=False, show_api=True)