# edge-flux-bad-2D/src/pipeline.py
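"""Text-to-image serving pipeline for FLUX.1-schnell.

Loads the pipeline with a pre-converted bf16 T5-XXL text encoder, applies a
simple fake-quantization pass to the VAE, compiles the VAE with torch.compile,
and serves 4-step, guidance-free generations with periodic CUDA cache cleanup.
"""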
import os
import gc
import torch
from torch import Generator
from PIL.Image import Image
from diffusers import AutoencoderKL, FluxPipeline
from diffusers.image_processor import VaeImageProcessor
from pipelines.models import TextToImageRequest
from transformers import T5EncoderModel
# Allocator, tokenizer, and matmul-precision tuning; the env vars must be set
# before the first CUDA allocation / tokenizer use
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:False,garbage_collection_threshold:0.001"
os.environ["TOKENIZERS_PARALLELISM"] = "True"
torch.set_float32_matmul_precision("medium")
ckpt_id = "black-forest-labs/FLUX.1-schnell"
dtype = torch.bfloat16
# Alias used purely in the type annotations below
Pipeline = FluxPipeline
# Configure CUDA settings: autotune cuDNN kernels, allow TF32 matmuls, and
# let this process claim nearly all of the GPU's memory
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
torch.cuda.set_per_process_memory_fraction(0.999)
class BasicQuantization:
    """Affine fake quantization: round to a b-bit grid, then dequantize."""

    def __init__(self, bits=16):
        self.bits = bits
        self.qmin = -(2 ** (bits - 1))
        self.qmax = 2 ** (bits - 1) - 1
    def quantize_tensor(self, tensor):
        # Map [min, max] onto the integer range [qmin, qmax], round, clamp,
        # and immediately dequantize so the tensor keeps its original dtype
        scale = (tensor.max() - tensor.min()) / (self.qmax - self.qmin)
        # Guard against a zero scale (and a divide-by-zero) on constant tensors
        scale = torch.clamp(scale, min=torch.finfo(tensor.dtype).eps)
        zero_point = self.qmin - torch.round(tensor.min() / scale)
        qtensor = torch.round(tensor / scale + zero_point)
        qtensor = torch.clamp(qtensor, self.qmin, self.qmax)
        return (qtensor - zero_point) * scale, scale, zero_point
class ModelQuantization:
    """Applies BasicQuantization to every Linear layer of a model, in place."""

    def __init__(self, model, bits=16):
        self.model = model
        self.quant = BasicQuantization(bits)
    @torch.no_grad()
    def quantize_model(self):
        # Fake-quantize the weights (and biases) of every Linear layer in place.
        # Note the original checked hasattr(module, 'weightML'), which is never
        # true for nn.Linear and silently skipped every weight.
        for module in self.model.modules():
            if isinstance(module, torch.nn.Linear):
                if hasattr(module, 'weight') and module.weight is not None:
                    quantized_weight, _, _ = self.quant.quantize_tensor(module.weight)
                    module.weight = torch.nn.Parameter(quantized_weight)
                if hasattr(module, 'bias') and module.bias is not None:
                    quantized_bias, _, _ = self.quant.quantize_tensor(module.bias)
                    module.bias = torch.nn.Parameter(quantized_bias)
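
# A minimal, self-contained sanity check for the fake-quantization round trip
# (illustrative only; it is not called anywhere in the serving path, and the
# tolerance below assumes a roughly unit-scale random input)
def _quantization_demo():
    quant = BasicQuantization(bits=8)
    original = torch.randn(4, 4)
    dequantized, scale, zero_point = quant.quantize_tensor(original)
    # For in-range values the reconstruction error is bounded by ~scale / 2
    assert (original - dequantized).abs().max() <= scale / 2 + 1e-6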
def empty_cache():
    # Drop Python garbage, return cached CUDA blocks to the allocator, and
    # reset the peak-memory counters (reset_max_memory_allocated is a
    # deprecated alias of reset_peak_memory_stats, so one call suffices)
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
def load_pipeline() -> Pipeline:
empty_cache()
# Load and quantize VAE
vae = AutoencoderKL.from_pretrained(ckpt_id, subfolder="vae", torch_dtype=dtype)
quantizer = ModelQuantization(vae)
quantizer.quantize_model()
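    # Use a pre-converted bf16 T5-XXL encoder so no on-load dtype cast is needed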
text_encoder_2 = T5EncoderModel.from_pretrained(
"city96/t5-v1_1-xxl-encoder-bf16", torch_dtype=torch.bfloat16
)
# Initialize pipeline
pipeline = FluxPipeline.from_pretrained(
ckpt_id,
text_encoder_2=text_encoder_2,
vae=vae,
torch_dtype=dtype
)
    # Optimize memory format: channels_last can improve CUDA kernel efficiency
for component in [pipeline.text_encoder, pipeline.text_encoder_2, pipeline.transformer, pipeline.vae]:
component.to(memory_format=torch.channels_last)
    # Compile the (quantized) VAE and keep it resident on the GPU, while the
    # remaining components are sequentially offloaded to CPU between uses
    pipeline.vae = torch.compile(pipeline.vae, fullgraph=True, dynamic=False, mode="max-autotune")
    pipeline._exclude_from_cpu_offload = ["vae"]
    pipeline.enable_sequential_cpu_offload()
    # Warmup: throwaway generations so torch.compile autotuning and cuDNN
    # benchmarking run before the first real request
empty_cache()
for _ in range(3):
pipeline(
prompt="posteroexternal, eurythmical, inspection, semicotton, specification, Mercatorial, ethylate, misprint",
width=1024,
height=1024,
guidance_scale=0.0,
num_inference_steps=4,
max_sequence_length=256
)
return pipeline
_inference_count = 0
@torch.inference_mode()
def infer(request: TextToImageRequest, pipeline: Pipeline) -> Image:
global _inference_count
# Clear on first inference
if _inference_count == 0:
empty_cache()
# Increment counter and empty cache every 4 inferences
_inference_count += 1
    if _inference_count >= 4:
        empty_cache()  # also resets the peak-memory stats
        _inference_count = 0
generator = Generator("cuda").manual_seed(request.seed)
return pipeline(
prompt=request.prompt,
generator=generator,
guidance_scale=0.0,
num_inference_steps=4,
max_sequence_length=256,
height=request.height,
width=request.width,
output_type="pil"
).images[0]
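
if __name__ == "__main__":
    # Smoke test, as a sketch: assumes a CUDA GPU is available and that
    # TextToImageRequest is constructible from these keyword fields (the
    # fields match what infer() reads; adjust if the real schema differs)
    pipe = load_pipeline()
    sample_request = TextToImageRequest(
        prompt="a watercolor fox in a snowy forest",
        seed=0,
        width=1024,
        height=1024,
    )
    infer(sample_request, pipe).save("sample.png")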