File size: 6,516 Bytes
8588d27
f4114be
 
 
 
2ab86fb
c035c39
f4114be
 
 
8588d27
 
 
ab00ded
 
f4114be
ab00ded
8588d27
ab00ded
8588d27
 
 
 
 
2ab86fb
8588d27
c035c39
8588d27
 
 
 
 
 
 
82cc8c9
8588d27
 
 
82cc8c9
18ba928
f669463
8588d27
 
 
 
 
 
 
 
 
c035c39
8588d27
55f9bde
f4114be
8588d27
c035c39
8588d27
 
88f1bc6
55f9bde
8588d27
 
 
 
c035c39
f669463
c035c39
8588d27
 
 
2ab86fb
55f9bde
8588d27
 
 
c035c39
8588d27
 
 
55f9bde
8588d27
 
82cc8c9
f669463
c035c39
55f9bde
c035c39
8588d27
c035c39
8588d27
 
 
 
 
 
f4114be
c035c39
 
 
 
2ab86fb
 
c035c39
 
8588d27
 
 
c035c39
8588d27
 
 
2ab86fb
 
c035c39
f4114be
c035c39
 
2ab86fb
c035c39
8588d27
 
 
 
 
 
 
 
 
 
 
 
 
 
c035c39
8588d27
 
 
 
f4114be
8588d27
 
f4114be
2ab86fb
8588d27
 
f4114be
c035c39
8588d27
 
 
 
 
 
 
 
c035c39
8588d27
f4114be
2ab86fb
8588d27
c035c39
8588d27
c035c39
 
 
f4114be
8588d27
 
 
 
f4114be
8588d27
c035c39
8588d27
 
c035c39
 
8588d27
 
 
c035c39
8588d27
2ab86fb
c035c39
2ab86fb
8588d27
c035c39
 
 
 
 
2ab86fb
 
8588d27
c035c39
8588d27
c035c39
 
8588d27
 
 
2ab86fb
8588d27
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
# app.py

import os
import sys

# --- Install Dependencies ---
# Install through the interpreter that is running this script (`python -m pip`)
# so the packages land in the environment the imports below resolve from; a
# bare `pip` found on PATH may belong to a different Python. Passing an
# argument list also avoids routing the command through a shell.
import subprocess

print("Installing required packages: diffusers, gradio_imageslider, huggingface-hub…")
_pip_status = subprocess.run(
    [
        sys.executable, "-m", "pip", "install", "--no-input",
        "diffusers", "gradio_imageslider", "huggingface-hub",
    ],
    check=False,
).returncode
if _pip_status != 0:
    # Stay best-effort: the packages may already be present, so report the
    # failure instead of aborting; a real problem surfaces at import time.
    print(f"WARNING: pip install exited with status {_pip_status}", file=sys.stderr)

# --- Standard Imports ---
import logging
import random
import warnings
import io
import base64

import gradio as gr
import numpy as np
import spaces
import torch
from diffusers import FluxControlNetModel
from diffusers.pipelines import FluxControlNetPipeline
from gradio_imageslider import ImageSlider
from PIL import Image, ImageOps
from huggingface_hub import snapshot_download

# --- Logging & Device Setup ---
# Silence Python warnings and emit INFO-level log records for the process.
warnings.filterwarnings("ignore")
logging.basicConfig(level=logging.INFO)

# Custom stylesheet handed to gr.Blocks: centers the layout and caps its width.
css = """
#col-container {
    margin: 0 auto;
    max-width: 512px;
}
.gradio-container {
    max-width: 900px !important;
    margin: auto !important;
}
"""

# Choose the compute device once: bfloat16 on CUDA, float32 on CPU.
# `power_device` is the human-readable label shown in the UI header.
power_device, device, torch_dtype = (
    ("GPU", "cuda", torch.bfloat16)
    if torch.cuda.is_available()
    else ("CPU", "cpu", torch.float32)
)

logging.info(f"Running on device={device} with dtype={torch_dtype}")

# --- Model IDs & Download (no token) ---
flux_model_id = "black-forest-labs/FLUX.1-dev"
controlnet_model_id = "jasperai/Flux.1-dev-Controlnet-Upscaler"
# The download directory is the last path component of the base repo id.
local_model_dir = flux_model_id.rsplit("/", 1)[-1]
# Remains None until the pipeline below has been built and moved to the device.
pipe = None

# Any failure while fetching or loading the models is fatal: it is logged with
# its traceback, echoed to stdout, and the process exits with status 1.
try:
    # 1) Fetch the base model snapshot into the local directory.
    logging.info(f"Downloading base model: {flux_model_id}")
    model_path = snapshot_download(
        repo_id=flux_model_id,
        repo_type="model",
        local_dir=local_model_dir,
        ignore_patterns=["*.md", "*.gitattributes"],
    )
    logging.info(f"Downloaded base model to: {model_path}")

    # 2) Load the upscaler ControlNet weights and move them to the device.
    logging.info(f"Loading ControlNet: {controlnet_model_id}")
    controlnet = FluxControlNetModel.from_pretrained(
        controlnet_model_id, torch_dtype=torch_dtype
    ).to(device)
    logging.info("ControlNet loaded.")

    # 3) Build the pipeline from the downloaded snapshot plus the ControlNet.
    logging.info("Initializing FluxControlNetPipeline…")
    pipe = FluxControlNetPipeline.from_pretrained(
        model_path, controlnet=controlnet, torch_dtype=torch_dtype
    ).to(device)
    logging.info("Pipeline ready.")
except Exception as exc:
    logging.error(f"Error loading models: {exc}", exc_info=True)
    print(f"FATAL: could not load models: {exc}")
    sys.exit(1)

# --- Constants & Helpers ---
# Upper bound of the seed slider and of randomly drawn seeds (2**32 - 1).
MAX_SEED = (1 << 32) - 1
# Maximum pixel count allowed for the internally upscaled control image.
MAX_PIXEL_BUDGET = 1280 ** 2
# Factor by which the prepared input is enlarged before it is fed to the pipeline.
INTERNAL_PROCESSING_FACTOR = 4

def process_input(input_image):
    """Normalize an uploaded image so it can be fed to the upscaling pipeline.

    The image is EXIF-transposed, converted to RGB, shrunk when its internally
    upscaled size would exceed ``MAX_PIXEL_BUDGET``, and finally resized so
    both sides are multiples of 8 (never smaller than 8).

    Args:
        input_image: PIL image coming from the UI, or ``None``.

    Returns:
        A tuple ``(img, w, h, was_resized)`` where ``img`` is the prepared
        image, ``w`` and ``h`` are the dimensions after EXIF transposition but
        before any resizing, and ``was_resized`` tells whether the image had
        to be shrunk to fit the pixel budget.

    Raises:
        gr.Error: if ``input_image`` is ``None``.
    """
    if input_image is None:
        raise gr.Error("No input image provided!")
    img = ImageOps.exif_transpose(input_image)
    if img.mode != "RGB":
        img = img.convert("RGB")
    w, h = img.size

    # Enforce the intermediate-scale budget: the image enlarged by
    # INTERNAL_PROCESSING_FACTOR must not exceed MAX_PIXEL_BUDGET pixels.
    target_px = (w * INTERNAL_PROCESSING_FACTOR) * (h * INTERNAL_PROCESSING_FACTOR)
    was_resized = target_px > MAX_PIXEL_BUDGET
    if was_resized:
        max_in = MAX_PIXEL_BUDGET / (INTERNAL_PROCESSING_FACTOR ** 2)
        scale = (max_in / (w * h)) ** 0.5
        shrunk = (max(8, int(w * scale)), max(8, int(h * scale)))
        img = img.resize(shrunk, Image.Resampling.LANCZOS)

    # Round both sides down to a multiple of 8. A side shorter than 8 px would
    # round down to 0 and make resize() fail, so clamp to 8 — the same floor
    # the budget branch above already applies.
    w2, h2 = img.size
    w2 = max(8, w2 - w2 % 8)
    h2 = max(8, h2 - h2 % 8)
    if img.size != (w2, h2):
        img = img.resize((w2, h2), Image.Resampling.LANCZOS)

    return img, w, h, was_resized

@spaces.GPU(duration=75)
def infer(
    seed,
    randomize_seed,
    input_image,
    num_inference_steps,
    final_upscale_factor,
    controlnet_conditioning_scale,
    progress=gr.Progress(track_tqdm=True),
):
    """Upscale ``input_image`` with the module-level ControlNet pipeline.

    Args:
        seed: Seed for the generator; replaced by a random one when
            ``randomize_seed`` is truthy.
        randomize_seed: Whether to draw a fresh seed in ``[0, MAX_SEED]``.
        input_image: Image to upscale (validated by ``process_input``).
        num_inference_steps: Number of pipeline steps (cast to ``int``).
        final_upscale_factor: Factor applied to obtain the output size
            (cast to ``int``).
        controlnet_conditioning_scale: ControlNet strength (cast to ``float``).
        progress: Gradio progress tracker; not referenced in the body.

    Returns:
        A list ``[[input_image, result], seed, data_uri]`` where ``result`` is
        the upscaled image, ``seed`` is the seed actually used, and
        ``data_uri`` is the result encoded as a base64 WEBP data URI.

    Raises:
        gr.Error: if the pipeline is not loaded (or, via ``process_input``,
            when no image was supplied).
    """
    if pipe is None:
        raise gr.Error("Pipeline not loaded.")

    seed = int(random.randint(0, MAX_SEED) if randomize_seed else seed)
    final_upscale_factor = int(final_upscale_factor)

    prepared, orig_w, orig_h, was_downscaled = process_input(input_image)
    proc_w, proc_h = prepared.size

    # The control image is the prepared input enlarged to the INTERNAL scale.
    control_w = proc_w * INTERNAL_PROCESSING_FACTOR
    control_h = proc_h * INTERNAL_PROCESSING_FACTOR
    control_img = prepared.resize((control_w, control_h), Image.Resampling.LANCZOS)

    generator = torch.Generator(device=device).manual_seed(seed)
    with torch.inference_mode():
        output = pipe(
            prompt="",
            control_image=control_img,
            controlnet_conditioning_scale=float(controlnet_conditioning_scale),
            num_inference_steps=int(num_inference_steps),
            guidance_scale=0.0,
            height=control_h,
            width=control_w,
            generator=generator,
        )
        result = output.images[0]

    # Final size is the user factor applied to the processed size when the
    # input had to be shrunk for the budget, otherwise to the original size.
    base_w, base_h = (proc_w, proc_h) if was_downscaled else (orig_w, orig_h)
    target_size = (base_w * final_upscale_factor, base_h * final_upscale_factor)
    if result.size != target_size:
        result = result.resize(target_size, Image.Resampling.LANCZOS)

    # Encode the result as a WEBP data URI for API consumers.
    webp_buffer = io.BytesIO()
    result.save(webp_buffer, format="WEBP", quality=90)
    encoded = base64.b64encode(webp_buffer.getvalue()).decode("utf-8")

    return [[input_image, result], seed, "data:image/webp;base64," + encoded]

# --- Gradio UI ---
# Layout: input image on the left, controls on the right, then the
# before/after slider and the output text fields.
#
# Every component receives its caption through `label=` and its range through
# `minimum=`/`maximum=`. The first positional parameter of gr.Slider is
# `minimum` and that of gr.Checkbox / gr.Textbox / ImageSlider is `value`, so a
# caption passed positionally would be taken as a bound or an initial value
# (and collide with an explicit `value=` keyword).
with gr.Blocks(css=css, theme=gr.themes.Soft(), title="Flux Upscaler Demo") as demo:
    gr.Markdown(f"""
    # ⚡ Flux.1‑dev Upscaler  
    **Device:** {power_device} · **Internal scale:** {INTERNAL_PROCESSING_FACTOR}x · **Budget:** {MAX_PIXEL_BUDGET} px  
    """)
    with gr.Row():
        with gr.Column(scale=2):
            inp = gr.Image(label="Input Image", type="pil", sources=["upload", "clipboard"], height=350)
        with gr.Column(scale=1):
            upf = gr.Slider(
                label="Final Upscale Factor",
                minimum=1,
                maximum=INTERNAL_PROCESSING_FACTOR,
                step=1,
                value=2,
            )
            steps = gr.Slider(label="Inference Steps", minimum=4, maximum=50, step=1, value=15)
            cscale = gr.Slider(label="ControlNet Scale", minimum=0.0, maximum=1.5, step=0.05, value=0.6)
            with gr.Row():
                sld = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=42)
                rnd = gr.Checkbox(label="Randomize", value=True, scale=0, min_width=80)
            btn = gr.Button("⚡ Upscale Image", variant="primary")

    slider = ImageSlider(label="Input / Output", type="pil", interactive=False, show_label=True, position=0.5)
    out_seed = gr.Textbox(label="Seed Used", interactive=False, visible=True)
    # Hidden field: carries the base64 data URI returned by infer().
    out_b64 = gr.Textbox(label="API Base64 Output", interactive=False, visible=False)

    # The input order must match the positional parameters of infer().
    btn.click(
        fn=infer,
        inputs=[sld, rnd, inp, steps, upf, cscale],
        outputs=[slider, out_seed, out_b64],
        api_name="upscale",
    )

# Queue requests (queue size capped at 10) and start the server; the click
# handler above is exposed as the named API endpoint "upscale".
demo.queue(max_size=10).launch(share=False, show_api=True)