File size: 4,282 Bytes
8eca80a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 |
import os
import gc
import torch
from torch import Generator
from PIL.Image import Image
from diffusers import AutoencoderKL, FluxPipeline
from diffusers.image_processor import VaeImageProcessor
from pipelines.models import TextToImageRequest
from transformers import T5EncoderModel
# Process-wide runtime configuration, applied once when the module is imported.
# Options for PyTorch's CUDA caching allocator.
# NOTE(review): this is assigned after `import torch`; presumably it is still
# honoured because no CUDA allocation has happened yet -- confirm.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:False,garbage_collection_threshold:0.001"
# Allow reduced-precision float32 matmuls.
torch.set_float32_matmul_precision("medium")
os.environ["TOKENIZERS_PARALLELISM"] = "True"
# Checkpoint id and dtype passed to the from_pretrained() calls in load_pipeline().
ckpt_id = "black-forest-labs/FLUX.1-schnell"
dtype = torch.bfloat16
# Placeholder name; in this file it is only used as an annotation on
# load_pipeline() and infer().
Pipeline = None
# Configure CUDA settings
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
# Runs at import time, so importing this module touches the CUDA runtime.
torch.cuda.set_per_process_memory_fraction(0.999)
class BasicQuantization:
    """Uniform affine fake-quantizer.

    Values are snapped onto the signed integer grid ``[qmin, qmax]`` implied
    by ``bits`` and immediately mapped back, so the result stays a
    floating-point tensor that only takes values lying on that grid.
    """

    def __init__(self, bits=16):
        """Store the bit width and derive the signed integer range from it."""
        self.bits = bits
        self.qmin = -(2 ** (bits - 1))
        self.qmax = 2 ** (bits - 1) - 1

    def quantize_tensor(self, tensor):
        """Quantize *tensor* to the integer grid and dequantize it again.

        Args:
            tensor: Floating-point tensor to process.

        Returns:
            A ``(dequantized, scale, zero_point)`` tuple. ``scale`` and
            ``zero_point`` are 0-dim tensors describing the affine mapping
            that was applied.
        """
        if tensor.numel() == 0:
            # min()/max() raise on an empty tensor; there is nothing to
            # quantize, so hand back a copy with an identity mapping.
            return tensor.clone(), tensor.new_ones(()), tensor.new_zeros(())

        lowest = tensor.min()
        highest = tensor.max()
        scale = (highest - lowest) / (self.qmax - self.qmin)

        if bool(scale == 0):
            # Constant tensor (e.g. an all-zero bias): dividing by a zero
            # scale would turn every element into NaN. A single value is
            # exactly representable, so return it unchanged.
            return tensor.clone(), torch.ones_like(scale), torch.zeros_like(scale)

        zero_point = self.qmin - torch.round(lowest / scale)
        levels = torch.round(tensor / scale + zero_point)
        levels = torch.clamp(levels, self.qmin, self.qmax)
        return (levels - zero_point) * scale, scale, zero_point
class ModelQuantization:
    """Run a ``BasicQuantization`` pass over the ``Linear`` layers of a model."""

    def __init__(self, model, bits=16):
        """Remember *model* and build the tensor quantizer for *bits* bits."""
        self.model = model
        self.quant = BasicQuantization(bits)

    def quantize_model(self):
        """Replace ``Linear`` parameters of the model with quantized copies, in place.

        Each replaced parameter keeps the ``requires_grad`` flag of the
        parameter it replaces, so frozen parameters stay frozen.
        """
        # The replacement values are plain data; no autograd graph is needed
        # while computing them.
        with torch.no_grad():
            for _, module in self.model.named_modules():
                if not isinstance(module, torch.nn.Linear):
                    continue

                # NOTE(review): 'weightML' is not an attribute that
                # torch.nn.Linear defines, so this branch presumably never
                # runs and weights are left unquantized. Kept as-is because
                # enabling it would change model output -- confirm intent.
                if hasattr(module, 'weightML'):
                    quantized_weight, _, _ = self.quant.quantize_tensor(module.weight)
                    module.weight = torch.nn.Parameter(
                        quantized_weight,
                        requires_grad=module.weight.requires_grad,
                    )

                bias = getattr(module, 'bias', None)
                if bias is not None:
                    quantized_bias, _, _ = self.quant.quantize_tensor(bias)
                    module.bias = torch.nn.Parameter(
                        quantized_bias,
                        requires_grad=bias.requires_grad,
                    )
def empty_cache():
    """Collect Python garbage, release cached CUDA memory and reset peak stats."""
    # Collect first so that tensors that are only kept alive by reference
    # cycles are freed before the cache is emptied.
    gc.collect()
    torch.cuda.empty_cache()
    # torch.cuda.reset_max_memory_allocated() is deprecated and simply
    # forwards to reset_peak_memory_stats(), so a single call is enough.
    torch.cuda.reset_peak_memory_stats()
def load_pipeline() -> Pipeline:
    """Build the FLUX pipeline, apply memory/compile settings and warm it up.

    Returns:
        The configured pipeline, after three warmup generations have run.
    """
    empty_cache()

    # VAE: load it, then run the in-place quantization pass over it.
    vae = AutoencoderKL.from_pretrained(ckpt_id, subfolder="vae", torch_dtype=dtype)
    ModelQuantization(vae).quantize_model()

    t5_encoder = T5EncoderModel.from_pretrained(
        "city96/t5-v1_1-xxl-encoder-bf16",
        torch_dtype=torch.bfloat16,
    )

    # Assemble the pipeline around the pre-loaded VAE and T5 encoder.
    pipe = FluxPipeline.from_pretrained(
        ckpt_id,
        text_encoder_2=t5_encoder,
        vae=vae,
        torch_dtype=dtype,
    )

    # Switch every sub-model to the channels_last memory format.
    sub_models = (pipe.text_encoder, pipe.text_encoder_2, pipe.transformer, pipe.vae)
    for sub_model in sub_models:
        sub_model.to(memory_format=torch.channels_last)

    # Compile the VAE and keep it out of sequential CPU offloading.
    pipe.vae = torch.compile(
        pipe.vae,
        mode="max-autotune",
        fullgraph=True,
        dynamic=False,
    )
    pipe._exclude_from_cpu_offload = ["vae"]
    pipe.enable_sequential_cpu_offload()

    # Warmup: three generations with the same settings infer() uses.
    empty_cache()
    warmup_args = {
        "prompt": "posteroexternal, eurythmical, inspection, semicotton, specification, Mercatorial, ethylate, misprint",
        "width": 1024,
        "height": 1024,
        "guidance_scale": 0.0,
        "num_inference_steps": 4,
        "max_sequence_length": 256,
    }
    for _ in range(3):
        pipe(**warmup_args)

    return pipe
# Number of infer() calls served so far; drives the periodic cache flush.
_inference_count = 0


@torch.inference_mode()
def infer(request: TextToImageRequest, pipeline: Pipeline) -> Image:
    """Generate one image for *request* with *pipeline*.

    The CUDA cache is flushed on the first call and on every fourth call
    after that. The counter is no longer reset to zero after a flush, which
    previously re-armed the first-call branch and caused a second flush on
    the very next request.

    Args:
        request: Supplies ``prompt``, ``seed``, ``height`` and ``width``.
        pipeline: Callable pipeline as returned by ``load_pipeline()``.

    Returns:
        The first image of the pipeline output.
    """
    global _inference_count

    _inference_count += 1
    if _inference_count == 1 or _inference_count % 4 == 0:
        empty_cache()

    torch.cuda.reset_peak_memory_stats()

    # Seed a CUDA generator from the request's seed.
    generator = Generator("cuda").manual_seed(request.seed)
    return pipeline(
        prompt=request.prompt,
        generator=generator,
        guidance_scale=0.0,
        num_inference_steps=4,
        max_sequence_length=256,
        height=request.height,
        width=request.width,
        output_type="pil"
    ).images[0]
|