File size: 4,662 Bytes
bff7ef2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
import os
import gc
import time
import torch
from PIL import Image as img
from PIL.Image import Image
from diffusers import (
    FluxTransformer2DModel,
    DiffusionPipeline,
    AutoencoderTiny
)
from transformers import T5EncoderModel
from huggingface_hub.constants import HF_HUB_CACHE
from torchao.quantization import quantize_, int8_weight_only
from first_block_cache.diffusers_adapters import apply_cache_on_pipe
from pipelines.models import TextToImageRequest
from torch import Generator

# Allocator option handed to PyTorch through the environment.
# NOTE(review): presumably intended to reduce CUDA memory fragmentation, and
# presumably has to be set before the CUDA allocator is first used — confirm.
os.environ['PYTORCH_CUDA_ALLOC_CONF']="expandable_segments:True"

# Placeholder name used only in the annotations of load_pipeline() and
# infer(); its runtime value is None.
Pipeline = None
# Process-wide backend switches: allow TF32 in CUDA matmuls, keep cuDNN
# enabled, and turn on cuDNN benchmark mode.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True

# Repo id and pinned revision passed to DiffusionPipeline.from_pretrained()
# in load_pipeline().
ckpt_id = "black-forest-labs/FLUX.1-schnell"
ckpt_revision = "741f7c3ce8b383c54771c7003378a50191e9efe9"

import torch.nn.functional as F

def are_two_tensors_similar(t1, t2, *, threshold=0.95):
    """
    Decide similarity from the cosine of the angle between two tensors.

    Both tensors are flattened to one dimension before the cosine
    similarity is taken, so only their element counts have to match.

    Returns True when the similarity is strictly greater than
    ``threshold``, otherwise False.
    """
    flat_a = t1.flatten()
    flat_b = t2.flatten()
    similarity = F.cosine_similarity(flat_a, flat_b, dim=0).item()
    return similarity > threshold

def are_two_tensors_similar_mse(t1, t2, *, threshold=0.85):
    """
    Decide similarity from the mean squared error between two tensors.

    Returns True when the MSE is strictly less than ``threshold``,
    otherwise False.
    """
    squared_error = F.mse_loss(t1, t2).item()
    is_close = squared_error < threshold
    return is_close

def are_two_tensors_similar_relative(t1, t2, *, threshold=0.15):
    """
    Decide similarity from the mean absolute difference relative to ``t1``.

    The mean of ``|t1 - t2|`` is divided by the mean of ``|t1|`` (plus a
    1e-8 epsilon so an all-zero ``t1`` does not divide by zero). The
    measure is therefore not symmetric in its two arguments.

    Returns True when that ratio is strictly less than ``threshold``.
    """
    # Gradient tracking is switched off while the statistics are computed.
    with torch.no_grad():
        delta = torch.abs(t1 - t2)
        numerator = torch.mean(delta)
        denominator = torch.mean(torch.abs(t1)) + 1e-8
        ratio = numerator / denominator
    return ratio.item() < threshold

def are_two_tensors_similar_normalized(t1, t2, *, threshold=0.85):
    """
    Decide similarity after standardising each tensor independently.

    Each tensor has its own mean subtracted and is divided by its own
    standard deviation (plus a 1e-8 epsilon); the mean absolute
    difference of the two standardised tensors is then compared with
    ``threshold``.

    Returns True when that difference is strictly less than ``threshold``.
    """
    def _standardize(t):
        # Shift to zero mean and scale by the tensor's own std.
        return (t - t.mean()) / (t.std() + 1e-8)

    with torch.no_grad():
        gap = torch.mean(torch.abs(_standardize(t1) - _standardize(t2)))
    return gap.item() < threshold

def are_two_tensors_similar_l1(t1, t2, *, threshold=0.85):
    """
    Decide similarity from the L1 loss between two tensors.

    The loss is computed with ``torch.nn.L1Loss`` in its default
    configuration.

    Returns True when the loss is strictly less than ``threshold``.
    """
    criterion = torch.nn.L1Loss()
    with torch.no_grad():
        distance = criterion(t1, t2).item()
    return distance < threshold

def are_two_tensors_similar_max_diff(t1, t2, *, threshold=0.85):
    """
    Decide similarity from the largest element-wise absolute difference.

    Returns True when the maximum of ``|t1 - t2|`` is strictly less than
    ``threshold``, otherwise False.
    """
    with torch.no_grad():
        gap = torch.abs(t1 - t2)
        worst = torch.max(gap).item()
    return worst < threshold



def empty_cache():
    """
    Free unreferenced Python objects and reset CUDA memory bookkeeping.

    Runs a garbage-collection pass, then, when CUDA is available, calls
    ``torch.cuda.empty_cache()`` and resets the peak memory statistics.
    On a host without CUDA only the garbage collection is performed.

    Returns None.
    """
    gc.collect()

    if not torch.cuda.is_available():
        # No CUDA allocator to flush and no statistics to reset.
        return

    torch.cuda.empty_cache()
    # torch.cuda.reset_max_memory_allocated() is deprecated and now simply
    # calls reset_peak_memory_stats(), so a single reset is sufficient.
    torch.cuda.reset_peak_memory_stats()

def load_pipeline() -> Pipeline:
    """
    Assemble the diffusion pipeline on the CUDA device and run it three times.

    Steps, in order:
      1. Call ``empty_cache()``.
      2. Load the T5 text encoder (pinned revision) in bfloat16.
      3. Load the transformer from a snapshot directory inside the local
         Hugging Face hub cache, in bfloat16, with ``use_safetensors=False``.
      4. Build the ``DiffusionPipeline`` for ``ckpt_id`` / ``ckpt_revision``
         around those two components and move it to ``"cuda"``.
      5. Apply ``apply_cache_on_pipe`` to the pipeline.
      6. Run three 1024x1024, 4-step generations whose outputs are discarded
         (presumably a warm-up — confirm).

    Returns the ready pipeline object.
    """
    empty_cache()

    dtype = torch.bfloat16
    device = "cuda"
    layout = torch.channels_last

    t5_encoder = T5EncoderModel.from_pretrained(
        "city96/t5-v1_1-xxl-encoder-bf16",
        revision="1b9c856aadb864af93c1dcdc226c2774fa67bc86",
        torch_dtype=dtype,
    )
    t5_encoder = t5_encoder.to(memory_format=layout)

    # Transformer weights are read from a fixed snapshot in the hub cache.
    snapshot_dir = os.path.join(
        HF_HUB_CACHE,
        "models--RobertML--FLUX.1-schnell-int8wo/snapshots/307e0777d92df966a3c0f99f31a6ee8957a9857a",
    )
    transformer = FluxTransformer2DModel.from_pretrained(
        snapshot_dir,
        torch_dtype=dtype,
        use_safetensors=False,
    )
    transformer = transformer.to(memory_format=layout)

    pipe = DiffusionPipeline.from_pretrained(
        ckpt_id,
        revision=ckpt_revision,
        transformer=transformer,
        text_encoder_2=t5_encoder,
        torch_dtype=dtype,
    )
    pipe = pipe.to(device)

    # VAE int8 quantization is not applied; only the cache adapter is.
    apply_cache_on_pipe(pipe)

    dry_run_kwargs = {
        "prompt": "onomancy, aftergo, spirantic, Platyhelmia, modificator, drupaceous, jobbernowl, hereness",
        "width": 1024,
        "height": 1024,
        "guidance_scale": 0.0,
        "num_inference_steps": 4,
        "max_sequence_length": 256,
    }
    for _ in range(3):
        pipe(**dry_run_kwargs)

    return pipe

@torch.no_grad()
def infer(request: TextToImageRequest, pipeline: Pipeline, generator: Generator) -> Image:
    """
    Generate one image for ``request`` using ``pipeline``.

    The pipeline is called with the request's prompt, height and width,
    the supplied ``generator``, ``guidance_scale=0.0``, 4 inference steps,
    ``max_sequence_length=256`` and PIL output; the first returned image
    is used.

    If generation raises an ordinary exception, the error is logged and
    the image at ``./RobertML.png`` is opened and returned instead.
    ``KeyboardInterrupt`` and ``SystemExit`` are not intercepted.

    Returns a PIL image.
    """
    try:
        outputs = pipeline(
            request.prompt,
            generator=generator,
            guidance_scale=0.0,
            num_inference_steps=4,
            max_sequence_length=256,
            height=request.height,
            width=request.width,
            output_type="pil",
        )
        return outputs.images[0]
    except Exception:
        # Best-effort fallback: keep serving a placeholder image, but record
        # why generation failed instead of hiding it. Imported here so the
        # success path pays nothing for it.
        import logging

        logging.getLogger(__name__).exception(
            "Image generation failed; returning fallback image"
        )
        return img.open("./RobertML.png")