import os
import re
import time
from dataclasses import dataclass
from glob import iglob

import torch
from mmgp import offload
from PIL import Image

from wan.utils.utils import calculate_new_dimensions
from flux.sampling import denoise, get_schedule, prepare_kontext, unpack
from flux.modules.layers import get_linear_split_map
from flux.util import (
    aspect_ratio_to_height_width,
    load_ae,
    load_clip,
    load_flow_model,
    load_t5,
    save_image,
)


def stitch_images(img1, img2):
    # Resize img2 to match img1's height, preserving its aspect ratio
    width1, height1 = img1.size
    width2, height2 = img2.size
    new_width2 = int(width2 * height1 / height2)
    img2_resized = img2.resize((new_width2, height1), Image.Resampling.LANCZOS)
    # Paste both images side by side onto a new canvas
    stitched = Image.new('RGB', (width1 + new_width2, height1))
    stitched.paste(img1, (0, 0))
    stitched.paste(img2_resized, (width1, 0))
    return stitched
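
# A minimal usage sketch for stitch_images (the file names are hypothetical,
# not part of this repo):
#
#   left = Image.open("left.png")
#   right = Image.open("right.png")
#   combined = stitch_images(left, right)  # right is rescaled to left's height
#   combined.save("stitched.png")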


class model_factory:
    def __init__(
        self,
        checkpoint_dir,
        model_filename=None,
        model_type=None,
        model_def=None,
        base_model_type=None,
        text_encoder_filename=None,
        quantizeTransformer=False,
        save_quantized=False,
        dtype=torch.bfloat16,
        VAE_dtype=torch.float32,
        mixed_precision_transformer=False,
    ):
        self.device = torch.device("cuda")
        self.VAE_dtype = VAE_dtype
        self.dtype = dtype
        self._interrupt = False  # may be toggled externally to abort generation early
        torch_device = "cpu"
        # model_filename = ["c:/temp/flux1-schnell.safetensors"]

        # Text encoders are loaded on CPU; offloading moves them to GPU on demand
        self.t5 = load_t5(torch_device, text_encoder_filename, max_length=512)
        self.clip = load_clip(torch_device)
        self.name = model_def.get("flux-model", "flux-dev")
        # self.name = "flux-dev-kontext"
        # self.name = "flux-dev"
        # self.name = "flux-schnell"
        source = model_def.get("source", None)
        self.model = load_flow_model(self.name, model_filename[0] if source is None else source, torch_device)
        self.vae = load_ae(self.name, device=torch_device)
        # offload.change_dtype(self.model, dtype, True)
        # offload.save_model(self.model, "flux-dev.safetensors")
        if source is not None:
            from wgp import save_model
            save_model(self.model, model_type, dtype, None)
        if save_quantized:
            from wgp import save_quantized_model
            save_quantized_model(self.model, model_type, model_filename[0], dtype, None)

        # Split large linear layers so mmgp offloading can swap them piecewise
        split_linear_modules_map = get_linear_split_map()
        self.model.split_linear_modules_map = split_linear_modules_map
        offload.split_linear_modules(self.model, split_linear_modules_map)
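
    # A construction sketch; the checkpoint paths below are assumptions
    # (in practice they are resolved by the hosting wgp app), while the
    # "flux-model" key is the one actually read above:
    #
    #   factory = model_factory(
    #       checkpoint_dir="ckpts",                                 # hypothetical
    #       model_filename=["ckpts/flux1-dev.safetensors"],         # hypothetical
    #       model_def={"flux-model": "flux-dev"},
    #       text_encoder_filename="ckpts/t5_xxl_bf16.safetensors",  # hypothetical
    #   )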

    def generate(
        self,
        seed: int | None = None,
        input_prompt: str = "replace the logo with the text 'Black Forest Labs'",
        sampling_steps: int = 20,
        input_ref_images=None,
        width=832,
        height=480,
        embedded_guidance_scale: float = 2.5,
        fit_into_canvas=None,
        callback=None,
        loras_slists=None,
        batch_size=1,
        video_prompt_type="",
        **bbargs,
    ):
        if self._interrupt:
            return None

        device = "cuda"
        if "I" in video_prompt_type and input_ref_images is not None and len(input_ref_images) > 0:
            if "K" in video_prompt_type and False:
                # image latents tiling method (currently disabled by the `and False` guard)
                w, h = input_ref_images[0].size
                height, width = calculate_new_dimensions(height, width, h, w, fit_into_canvas)
            else:
                # image stitching method: combine all reference images into one
                stitched = input_ref_images[0]
                if "K" in video_prompt_type:
                    w, h = input_ref_images[0].size
                    height, width = calculate_new_dimensions(height, width, h, w, fit_into_canvas)
                for new_img in input_ref_images[1:]:
                    stitched = stitch_images(stitched, new_img)
                input_ref_images = [stitched]
        else:
            input_ref_images = None

        inp, height, width = prepare_kontext(
            t5=self.t5,
            clip=self.clip,
            prompt=input_prompt,
            ae=self.vae,
            img_cond_list=input_ref_images,
            target_width=width,
            target_height=height,
            bs=batch_size,
            seed=seed,
            device=device,
        )

        # flux-schnell uses an unshifted schedule; the other variants shift timesteps
        timesteps = get_schedule(sampling_steps, inp["img"].shape[1], shift=(self.name != "flux-schnell"))

        def unpack_latent(x):
            return unpack(x.float(), height, width)

        # denoise initial noise
        x = denoise(self.model, **inp, timesteps=timesteps, guidance=embedded_guidance_scale, callback=callback, pipeline=self, loras_slists=loras_slists, unpack_latent=unpack_latent)
        if x is None:
            return None

        # decode latents to pixel space
        x = unpack_latent(x)
        with torch.autocast(device_type=device, dtype=torch.bfloat16):
            x = self.vae.decode(x)

        x = x.clamp(-1, 1)
        x = x.transpose(0, 1)  # swap batch and channel axes for the caller
        return x
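
    # A hedged sketch of turning generate()'s output into a PIL image; the
    # channel-first layout is inferred from the transpose above, and the
    # conversion below is illustrative, not part of this repo:
    #
    #   x = factory.generate(seed=42, input_prompt="a calm mountain lake")
    #   frame = ((x[:, 0].permute(1, 2, 0) + 1) / 2 * 255).to(torch.uint8)
    #   Image.fromarray(frame.cpu().numpy()).save("out.png")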


def query_model_def(model_type, model_def):
    flux_model = model_def.get("flux-model", "flux-dev")
    flux_schnell = flux_model == "flux-schnell"
    model_def_output = {
        "image_outputs": True,
    }
    if flux_schnell:
        # flux-schnell does not expose a guidance scale
        model_def_output["no_guidance"] = True
    return model_def_output
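

# Illustrative calls, following directly from the keys read above:
#
#   query_model_def("flux", {"flux-model": "flux-dev"})
#   # -> {"image_outputs": True}
#   query_model_def("flux", {"flux-model": "flux-schnell"})
#   # -> {"image_outputs": True, "no_guidance": True}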