import gc
import os

import torch
from torch import Generator
from PIL.Image import Image
from diffusers import AutoencoderKL, FluxPipeline
from pipelines.models import TextToImageRequest
from transformers import T5EncoderModel

# Allocator and precision settings (must be set before the first CUDA allocation).
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:False,garbage_collection_threshold:0.001"
os.environ["TOKENIZERS_PARALLELISM"] = "True"
torch.set_float32_matmul_precision("medium")

ckpt_id = "black-forest-labs/FLUX.1-schnell"
dtype = torch.bfloat16
Pipeline = FluxPipeline  # Alias used in the type annotations below.

# Configure CUDA settings
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
torch.cuda.set_per_process_memory_fraction(0.999)


class BasicQuantization:
    """Affine (asymmetric) fake quantization of a single tensor."""

    def __init__(self, bits=16):
        self.bits = bits
        self.qmin = -(2 ** (bits - 1))
        self.qmax = 2 ** (bits - 1) - 1

    def quantize_tensor(self, tensor):
        # Quantize and immediately dequantize, so the returned tensor keeps its
        # original dtype and shape but only takes values on the quantized grid.
        scale = (tensor.max() - tensor.min()) / (self.qmax - self.qmin)
        if scale == 0:
            # Constant tensor: nothing to quantize; avoid division by zero.
            return tensor, scale, torch.zeros_like(scale)
        zero_point = self.qmin - torch.round(tensor.min() / scale)
        qtensor = torch.round(tensor / scale + zero_point)
        qtensor = torch.clamp(qtensor, self.qmin, self.qmax)
        return (qtensor - zero_point) * scale, scale, zero_point


class ModelQuantization:
    """Applies BasicQuantization to every Linear layer of a model in place."""

    def __init__(self, model, bits=16):
        self.model = model
        self.quant = BasicQuantization(bits)

    def quantize_model(self):
        for _, module in self.model.named_modules():
            if isinstance(module, torch.nn.Linear):
                if hasattr(module, "weight") and module.weight is not None:
                    quantized_weight, _, _ = self.quant.quantize_tensor(module.weight)
                    module.weight = torch.nn.Parameter(quantized_weight)
                if hasattr(module, "bias") and module.bias is not None:
                    quantized_bias, _, _ = self.quant.quantize_tensor(module.bias)
                    module.bias = torch.nn.Parameter(quantized_bias)


def empty_cache():
    """Release cached CUDA memory and reset the allocator statistics."""
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.reset_max_memory_allocated()
    torch.cuda.reset_peak_memory_stats()


def load_pipeline() -> Pipeline:
    empty_cache()

    # Load and quantize the VAE.
    vae = AutoencoderKL.from_pretrained(ckpt_id, subfolder="vae", torch_dtype=dtype)
    quantizer = ModelQuantization(vae)
    quantizer.quantize_model()

    # Pre-converted bf16 T5 encoder (avoids converting fp32 weights at load time).
    text_encoder_2 = T5EncoderModel.from_pretrained(
        "city96/t5-v1_1-xxl-encoder-bf16", torch_dtype=torch.bfloat16
    )

    # Initialize the pipeline with the quantized VAE and bf16 text encoder.
    pipeline = FluxPipeline.from_pretrained(
        ckpt_id,
        text_encoder_2=text_encoder_2,
        vae=vae,
        torch_dtype=dtype,
    )

    # Optimize memory format.
    for component in [pipeline.text_encoder, pipeline.text_encoder_2, pipeline.transformer, pipeline.vae]:
        component.to(memory_format=torch.channels_last)

    # Compile the VAE and keep it on the GPU; offload everything else sequentially.
    pipeline.vae = torch.compile(pipeline.vae, fullgraph=True, dynamic=False, mode="max-autotune")
    pipeline._exclude_from_cpu_offload = ["vae"]
    pipeline.enable_sequential_cpu_offload()

    # Warmup runs to trigger compilation and autotuning before the first real request.
    empty_cache()
    for _ in range(3):
        pipeline(
            prompt="posteroexternal, eurythmical, inspection, semicotton, specification, Mercatorial, ethylate, misprint",
            width=1024,
            height=1024,
            guidance_scale=0.0,
            num_inference_steps=4,
            max_sequence_length=256,
        )

    return pipeline


_inference_count = 0


@torch.inference_mode()
def infer(request: TextToImageRequest, pipeline: Pipeline) -> Image:
    global _inference_count

    # Clear the cache on the first inference.
    if _inference_count == 0:
        empty_cache()

    # Increment the counter and empty the cache every 4 inferences.
    _inference_count += 1
    if _inference_count >= 4:
        empty_cache()
        _inference_count = 0

    torch.cuda.reset_peak_memory_stats()
    generator = Generator("cuda").manual_seed(request.seed)

    return pipeline(
        prompt=request.prompt,
        generator=generator,
        guidance_scale=0.0,
        num_inference_steps=4,
        max_sequence_length=256,
        height=request.height,
        width=request.width,
        output_type="pil",
    ).images[0]
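

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the serving harness).
# It assumes TextToImageRequest can be constructed with the prompt/seed/width/
# height fields that infer() reads above, and that a CUDA device is available;
# the harness normally calls load_pipeline() once and then infer() per request.
if __name__ == "__main__":
    _pipeline = load_pipeline()
    _request = TextToImageRequest(
        prompt="a photograph of a lighthouse at dusk",
        seed=0,
        width=1024,
        height=1024,
    )
    infer(_request, _pipeline).save("sample.png")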