import os
import gc
import torch
from torch import Generator
from PIL.Image import Image
from diffusers import AutoencoderKL, FluxPipeline
from diffusers.image_processor import VaeImageProcessor
from pipelines.models import TextToImageRequest
from transformers import T5EncoderModel
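
# Allocator, matmul-precision, and tokenizer-parallelism configuration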
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:False,garbage_collection_threshold:0.001"
torch.set_float32_matmul_precision("medium")
os.environ["TOKENIZERS_PARALLELISM"] = "True"
ckpt_id = "black-forest-labs/FLUX.1-schnell"
dtype = torch.bfloat16
Pipeline = FluxPipeline  # type alias used by the annotations below
# Configure CUDA: enable cuDNN autotuning and TF32 matmuls, and let this
# process use nearly all of the GPU's memory.
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
torch.cuda.set_per_process_memory_fraction(0.999)

class BasicQuantization:
    """Affine (min/max) fake quantization: round values onto a signed n-bit grid, then dequantize."""

    def __init__(self, bits=16):
        self.bits = bits
        self.qmin = -(2 ** (bits - 1))
        self.qmax = 2 ** (bits - 1) - 1

    def quantize_tensor(self, tensor):
        # Map the tensor's [min, max] range onto [qmin, qmax]; guard against a zero
        # scale when the tensor is constant, which would otherwise divide by zero.
        scale = (tensor.max() - tensor.min()) / (self.qmax - self.qmin)
        scale = torch.clamp(scale, min=torch.finfo(tensor.dtype).tiny)
        zero_point = self.qmin - torch.round(tensor.min() / scale)
        qtensor = torch.round(tensor / scale + zero_point)
        qtensor = torch.clamp(qtensor, self.qmin, self.qmax)
        return (qtensor - zero_point) * scale, scale, zero_point

class ModelQuantization:
    """Applies BasicQuantization in place to the weights and biases of every Linear layer."""

    def __init__(self, model, bits=16):
        self.model = model
        self.quant = BasicQuantization(bits)

    def quantize_model(self):
        for name, module in self.model.named_modules():
            if isinstance(module, torch.nn.Linear):
                if hasattr(module, 'weight') and module.weight is not None:
                    quantized_weight, _, _ = self.quant.quantize_tensor(module.weight)
                    module.weight = torch.nn.Parameter(quantized_weight)
                if hasattr(module, 'bias') and module.bias is not None:
                    quantized_bias, _, _ = self.quant.quantize_tensor(module.bias)
                    module.bias = torch.nn.Parameter(quantized_bias)

def empty_cache():
    # Release Python garbage, return cached CUDA blocks to the driver,
    # and reset the peak-memory counters used for monitoring.
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()

def load_pipeline() -> Pipeline:
    empty_cache()
    
    # Load and quantize VAE
    vae = AutoencoderKL.from_pretrained(ckpt_id, subfolder="vae", torch_dtype=dtype)
    quantizer = ModelQuantization(vae)
    quantizer.quantize_model()
    
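    # Pre-converted bf16 copy of the T5-XXL text encoder.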
    text_encoder_2 = T5EncoderModel.from_pretrained(
        "city96/t5-v1_1-xxl-encoder-bf16", torch_dtype=torch.bfloat16
    )

    # Initialize pipeline
    pipeline = FluxPipeline.from_pretrained(
        ckpt_id,
        text_encoder_2=text_encoder_2,
        vae=vae,
        torch_dtype=dtype
    )

    # Optimize memory format
    for component in [pipeline.text_encoder, pipeline.text_encoder_2, pipeline.transformer, pipeline.vae]:
        component.to(memory_format=torch.channels_last)

    # Compile the VAE for faster decoding and keep it resident on the GPU; the other
    # components are sequentially offloaded to CPU to reduce peak memory.
    pipeline.vae = torch.compile(pipeline.vae, fullgraph=True, dynamic=False, mode="max-autotune")
    pipeline._exclude_from_cpu_offload = ["vae"]
    pipeline.enable_sequential_cpu_offload()

    # Warmup passes so torch.compile autotuning and kernel selection happen before serving requests
    empty_cache()
    for _ in range(3):
        pipeline(
            prompt="posteroexternal, eurythmical, inspection, semicotton, specification, Mercatorial, ethylate, misprint",
            width=1024,
            height=1024,
            guidance_scale=0.0,
            num_inference_steps=4,
            max_sequence_length=256
        )
    
    return pipeline

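# Number of infer() calls since the CUDA cache was last emptied.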
_inference_count = 0

@torch.inference_mode()
def infer(request: TextToImageRequest, pipeline: Pipeline) -> Image:
    global _inference_count
    
    # Clear on first inference
    if _inference_count == 0:
        empty_cache()
    
    # Increment counter and empty cache every 4 inferences
    _inference_count += 1
    if _inference_count >= 4:
        empty_cache()
        _inference_count = 0
    
    torch.cuda.reset_peak_memory_stats()
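    # Seed a CUDA generator so results are reproducible for a given request seed.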
    generator = Generator("cuda").manual_seed(request.seed)
    return pipeline(
        prompt=request.prompt,
        generator=generator,
        guidance_scale=0.0,
        num_inference_steps=4,
        max_sequence_length=256,
        height=request.height,
        width=request.width,
        output_type="pil",
    ).images[0]