|
import gc
import random
from typing import List

import gradio as gr
import torch
from PIL import Image
from controlnet_aux.processor import Processor
from safetensors.torch import load_file
from diffusers import (
    AutoPipelineForText2Image,
    AutoPipelineForImage2Image,
    AutoPipelineForInpainting,
    FluxPipeline,
    FluxImg2ImgPipeline,
    FluxInpaintPipeline,
    FluxControlNetPipeline,
    StableDiffusionXLPipeline,
    StableDiffusionXLImg2ImgPipeline,
    StableDiffusionXLInpaintPipeline,
    StableDiffusionXLControlNetPipeline,
    StableDiffusionXLControlNetImg2ImgPipeline,
    StableDiffusionXLControlNetInpaintPipeline,
)
from diffusers.schedulers import *
from sd_embed.embedding_funcs import get_weighted_text_embeddings_flux1, get_weighted_text_embeddings_sdxl
from huggingface_hub import hf_hub_download

from .models import *  # BaseReq, BaseImg2ImgReq, BaseInpaintReq, ControlNetReq, ...
from .load_models import device, models, flux_vae, sdxl_vae, refiner, controlnets
|
sd_pipes = (
    StableDiffusionXLPipeline,
    StableDiffusionXLImg2ImgPipeline,
    StableDiffusionXLInpaintPipeline,
    StableDiffusionXLControlNetPipeline,
    StableDiffusionXLControlNetImg2ImgPipeline,
    StableDiffusionXLControlNetInpaintPipeline,
)
flux_pipes = (FluxPipeline, FluxImg2ImgPipeline, FluxInpaintPipeline, FluxControlNetPipeline)
|
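# Assumed (illustrative) shape of the registries imported from .load_models,
# inferred from how get_pipe consumes them -- not a definitive schema:
#   models      -> [{'repo_id': str, 'loader': 'sdxl' | 'flux' | 'flux-multi', 'pipeline': DiffusionPipeline}, ...]
#   controlnets -> [{'loader': str, 'layers': List[str], 'controlnet': ControlNetModel}, ...]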
def get_pipe(request: BaseReq | BaseImg2ImgReq | BaseInpaintReq):
    for model in models:
        if model['repo_id'] == request.model:
            pipe_args = {
                "pipeline": model['pipeline'],
            }

            # Collect the ControlNet models matching the requested layers.
            if request.controlnet_config:
                pipe_args["controlnet"] = []
                if model['loader'] in ('sdxl', 'flux'):
                    for controlnet in controlnets:
                        if any(layer in controlnet['layers'] for layer in request.controlnet_config.controlnets):
                            pipe_args["controlnet"].append(controlnet['controlnet'])
                elif model['loader'] == 'flux-multi':
                    controlnet = next((c for c in controlnets if c['loader'] == 'flux-multi'), None)
                    if controlnet is not None:
                        pipe_args['control_mode'] = [controlnet['layers'].index(layer) for layer in request.controlnet_config.controlnets]
                        pipe_args['controlnet'].append(controlnet['controlnet'])

            # control_mode is a call-time argument rather than a pipeline component,
            # so keep it out of the from_pipe kwargs and restore it afterwards.
            control_mode = pipe_args.pop('control_mode', None)

            # Pick the task-specific pipeline. Order matters: BaseInpaintReq and
            # BaseImg2ImgReq subclass BaseReq, so the most specific check comes first.
            if not request.custom_addons:
                if isinstance(request, BaseInpaintReq):
                    pipe_args['pipeline'] = AutoPipelineForInpainting.from_pipe(**pipe_args)
                elif isinstance(request, BaseImg2ImgReq):
                    pipe_args['pipeline'] = AutoPipelineForImage2Image.from_pipe(**pipe_args)
                elif isinstance(request, BaseReq):
                    pipe_args['pipeline'] = AutoPipelineForText2Image.from_pipe(**pipe_args)
            else:
                pipe_args['pipeline'] = None

            if control_mode is not None:
                pipe_args['control_mode'] = control_mode

            # With custom addons the pipeline is assembled elsewhere, so skip the
            # shared configuration below.
            if pipe_args['pipeline'] is None:
                return pipe_args

            if request.vae:
                pipe_args["pipeline"].vae = sdxl_vae if model['loader'] == 'sdxl' else flux_vae
            else:
                pipe_args["pipeline"].vae = None if model['loader'] == 'sdxl' else flux_vae

            pipe_args["pipeline"].scheduler = load_scheduler(pipe_args["pipeline"], request.scheduler)

            # Load user LoRAs and, if requested, the Hyper-SD acceleration LoRA, then
            # activate them all in a single set_adapters call.
            adapter_names = []
            adapter_weights = []
            if request.loras:
                for i, lora in enumerate(request.loras):
                    pipe_args["pipeline"].load_lora_weights(lora['repo_id'], adapter_name=f"lora_{i}")
                    adapter_names.append(f"lora_{i}")
                    adapter_weights.append(lora['weight'])

            if request.fast_generation:
                hyper_lora = hf_hub_download("ByteDance/Hyper-SD", "Hyper-FLUX.1-dev-8steps-lora.safetensors") if model['loader'] == 'flux' \
                    else hf_hub_download("ByteDance/Hyper-SD", "Hyper-SDXL-8steps-lora.safetensors")
                hyper_weight = 0.125 if model['loader'] == 'flux' else 1.0
                pipe_args["pipeline"].load_lora_weights(hyper_lora, adapter_name="hyper_lora")
                adapter_names.append("hyper_lora")
                adapter_weights.append(hyper_weight)

            if adapter_names:
                pipe_args["pipeline"].set_adapters(adapter_names, adapter_weights)

            # SDXL textual-inversion embeddings ship separate CLIP-G and CLIP-L tensors.
            if request.embeddings and model['loader'] == 'sdxl':
                for embedding in request.embeddings:
                    # hf_hub_download requires a filename; 'filename' is an assumed key
                    # on the embedding entry.
                    state_dict = load_file(hf_hub_download(embedding['repo_id'], embedding['filename']))
                    pipe_args["pipeline"].load_textual_inversion(state_dict['clip_g'], token=embedding['token'], text_encoder=pipe_args["pipeline"].text_encoder_2, tokenizer=pipe_args["pipeline"].tokenizer_2)
                    pipe_args["pipeline"].load_textual_inversion(state_dict["clip_l"], token=embedding['token'], text_encoder=pipe_args["pipeline"].text_encoder, tokenizer=pipe_args["pipeline"].tokenizer)

            return pipe_args
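# get_pipe usage sketch (hypothetical request object): the returned dict carries
# the configured pipeline plus any extra call-time kwargs such as control_mode:
#   pipe_args = get_pipe(request)
#   pipeline = pipe_args["pipeline"]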
|
def load_scheduler(pipeline, scheduler):
    schedulers = {
        "dpmpp_2m": (DPMSolverMultistepScheduler, {}),
        "dpmpp_2m_k": (DPMSolverMultistepScheduler, {"use_karras_sigmas": True}),
        "dpmpp_2m_sde": (DPMSolverMultistepScheduler, {"algorithm_type": "sde-dpmsolver++"}),
        "dpmpp_2m_sde_k": (DPMSolverMultistepScheduler, {"algorithm_type": "sde-dpmsolver++", "use_karras_sigmas": True}),
        "dpmpp_sde": (DPMSolverSinglestepScheduler, {}),
        "dpmpp_sde_k": (DPMSolverSinglestepScheduler, {"use_karras_sigmas": True}),
        "dpm2": (KDPM2DiscreteScheduler, {}),
        "dpm2_k": (KDPM2DiscreteScheduler, {"use_karras_sigmas": True}),
        "dpm2_a": (KDPM2AncestralDiscreteScheduler, {}),
        "dpm2_a_k": (KDPM2AncestralDiscreteScheduler, {"use_karras_sigmas": True}),
        "euler": (EulerDiscreteScheduler, {}),
        "euler_a": (EulerAncestralDiscreteScheduler, {}),
        "heun": (HeunDiscreteScheduler, {}),
        "lms": (LMSDiscreteScheduler, {}),
        "lms_k": (LMSDiscreteScheduler, {"use_karras_sigmas": True}),
        "deis": (DEISMultistepScheduler, {}),
        "unipc": (UniPCMultistepScheduler, {}),
        "fm_euler": (FlowMatchEulerDiscreteScheduler, {}),
    }
    scheduler_class, kwargs = schedulers.get(scheduler, (None, {}))

    if scheduler_class is None:
        raise ValueError(f"Unknown scheduler: {scheduler}")

    # Rebuild from the pipeline's existing scheduler config so model-specific
    # settings (prediction type, timestep spacing, etc.) are preserved.
    return scheduler_class.from_config(pipeline.scheduler.config, **kwargs)
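# load_scheduler usage sketch: swap samplers by key while keeping the model's
# scheduler config, e.g.
#   pipeline.scheduler = load_scheduler(pipeline, "dpmpp_2m_k")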
|
def resize_images(images: List[Image.Image], height: int, width: int, resize_mode: str):
    resized = []
    for image in images:
        if resize_mode == "resize_only":
            image = image.resize((width, height))
        elif resize_mode == "crop_and_resize":
            image = image.crop((0, 0, width, height))
        elif resize_mode == "resize_and_fill":
            image = image.resize((width, height), Image.Resampling.LANCZOS)
        resized.append(image)

    return resized
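# Note for resize_images: PIL sizes are (width, height) tuples, while the request
# and pipeline arguments pass height and width separately -- hence the swapped
# order in the resize calls above.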
|
def get_controlnet_images(controlnets: List[str], control_images: List[Image.Image], height: int, width: int, resize_mode: str):
    response_images = []
    control_images = resize_images(control_images, height, width, resize_mode)
    for controlnet, image in zip(controlnets, control_images):
        if controlnet == "canny":
            processor = Processor('canny')
        elif controlnet == "depth":
            processor = Processor('depth_midas')
        elif controlnet == "pose":
            processor = Processor('openpose_full')
        elif controlnet == "scribble":
            processor = Processor('scribble')
        else:
            raise ValueError(f"Invalid ControlNet: {controlnet}")

        response_images.append(processor(image, to_pil=True))

    return response_images
|
def get_control_mode(controlnet_config: ControlNetReq):
    control_mode = []
    for controlnet in controlnets:
        if controlnet['loader'] == 'flux-multi':
            layers = controlnet['layers']

            for c in controlnet_config.controlnets:
                if c in layers:
                    control_mode.append(layers.index(c))

    return control_mode
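# get_control_mode usage sketch: maps the requested layer names to the indices
# the flux-multi ControlNet expects, e.g. ['canny', 'depth'] -> [0, 2] (actual
# indices depend on the loaded model's layer order).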
|
def cleanup(pipeline, loras=None, embeddings=None):
    if loras:
        pipeline.unload_lora_weights()
    if embeddings:
        pipeline.unload_textual_inversion()
    gc.collect()
    torch.cuda.empty_cache()
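# cleanup runs on both the success and error paths of gen_img, so LoRA and
# textual-inversion weights never leak into the next request.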
|
def gen_img(request: BaseReq | BaseImg2ImgReq | BaseInpaintReq, progress=gr.Progress(track_tqdm=True)):
    progress(0.1, "Loading Pipeline")
    pipeline_args = get_pipe(request)
    pipeline = pipeline_args["pipeline"]
    try:
        progress(0.3, "Getting Prompt Embeddings")

        # sd_embed handles prompt weighting and prompts longer than the CLIP token limit.
        if isinstance(pipeline, flux_pipes):
            positive_prompt_embeds, positive_prompt_pooled = get_weighted_text_embeddings_flux1(pipeline, request.prompt)
        elif isinstance(pipeline, sd_pipes):
            positive_prompt_embeds, negative_prompt_embeds, positive_prompt_pooled, negative_prompt_pooled = get_weighted_text_embeddings_sdxl(pipeline, request.prompt, request.negative_prompt)

        progress(0.5, "Configuring Pipeline")

        # Use the request seed when one is given; otherwise draw a random seed per image.
        args = {
            'prompt_embeds': positive_prompt_embeds,
            'pooled_prompt_embeds': positive_prompt_pooled,
            'height': request.height,
            'width': request.width,
            'num_images_per_prompt': request.num_images_per_prompt,
            'num_inference_steps': request.num_inference_steps,
            'guidance_scale': request.guidance_scale,
            'generator': [
                torch.Generator().manual_seed(request.seed + i) if request.seed not in (None, 0, -1)
                else torch.Generator().manual_seed(random.randint(0, 2**32 - 1))
                for i in range(request.num_images_per_prompt)
            ],
        }

        if isinstance(pipeline, sd_pipes):
            args['clip_skip'] = request.clip_skip
            args['negative_prompt_embeds'] = negative_prompt_embeds
            args['negative_pooled_prompt_embeds'] = negative_prompt_pooled

        if request.controlnet_config:
            args['control_image'] = get_controlnet_images(request.controlnet_config.controlnets, request.controlnet_config.control_images, request.height, request.width, request.resize_mode)
            args['controlnet_conditioning_scale'] = request.controlnet_config.controlnet_conditioning_scale

        # Flux multi-ControlNets additionally need the control-mode indices.
        if request.controlnet_config and isinstance(pipeline, flux_pipes):
            args['control_mode'] = get_control_mode(request.controlnet_config)

        if isinstance(request, (BaseImg2ImgReq, BaseInpaintReq)):
            args['image'] = resize_images([request.image], request.height, request.width, request.resize_mode)[0]
            args['strength'] = request.strength

        if isinstance(request, BaseInpaintReq):
            args['mask_image'] = resize_images([request.mask_image], request.height, request.width, request.resize_mode)[0]

        progress(0.9, "Generating Images")
        gr.Info(f"Request {type(request)}: {str(request.__dict__)}", duration=60)
        images = pipeline(**args).images

        # Optionally run the SDXL refiner on the base output.
        if request.refiner:
            images = refiner(image=images, prompt=request.prompt, num_inference_steps=40, denoising_start=0.7).images

        progress(1.0, "Cleaning Up")
        cleanup(pipeline, request.loras, request.embeddings)
        return images
    except Exception as e:
        cleanup(pipeline, request.loras, request.embeddings)
        raise gr.Error(f"Error: {e}")
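# gen_img usage sketch (hypothetical field values; BaseReq comes from .models):
#   images = gen_img(BaseReq(model="some/repo-id", prompt="a photo of a cat",
#                            height=1024, width=1024, num_images_per_prompt=1, ...))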
|
|