manbeast3b committed
Commit 8f24c5c · verified · 1 Parent(s): 065ed72

Update src/pipeline.py

Files changed (1)
  1. src/pipeline.py +311 -136
src/pipeline.py CHANGED
@@ -1,141 +1,330 @@
  import os
- import gc
- import time
  import torch
- import torch.nn.functional as F
- from PIL import Image as img
  from PIL.Image import Image
- from typing import Optional, Type
- from dataclasses import dataclass
-
- from diffusers import (
-     FluxTransformer2DModel,
-     DiffusionPipeline,
-     AutoencoderTiny
- )
  from transformers import T5EncoderModel
  from huggingface_hub.constants import HF_HUB_CACHE
  from torchao.quantization import quantize_, int8_weight_only, float8_weight_only
  from pipelines.models import TextToImageRequest
- from torch import Generator
- from para_attn.first_block_cache.diffusers_adapters import apply_cache_on_pipe

- # Configuration
  @dataclass
- class Config:
-     CKPT_ID: str = "black-forest-labs/FLUX.1-schnell"
-     CKPT_REVISION: str = "741f7c3ce8b383c54771c7003378a50191e9efe9"
-     DEVICE: str = "cuda"
-     DTYPE = torch.bfloat16
-     PYTORCH_CUDA_ALLOC_CONF: str = "expandable_segments:True"
-
- # Initialize global settings
- def init_global_settings():
-     torch.backends.cuda.matmul.allow_tf32 = True
-     torch.backends.cudnn.enabled = True
-     torch.backends.cudnn.benchmark = True
-     os.environ['PYTORCH_CUDA_ALLOC_CONF'] = Config.PYTORCH_CUDA_ALLOC_CONF

- # Tensor comparison utilities
- class TensorComparator:
-     @staticmethod
-     def orig_comparison(t1, t2, *, threshold=0.85):
-         mean_diff = (t1 - t2).abs().mean()
-         mean_t1 = t1.abs().mean()
-         diff = mean_diff / mean_t1
-         return diff.item() < threshold

-     @staticmethod
-     def mse_comparison(t1, t2, threshold=0.95):
-         mse = F.mse_loss(t1, t2)
-         return mse.item() < threshold
-
-     @staticmethod
-     def relative_comparison(t1, t2, threshold=0.15):
-         with torch.no_grad():
-             mean_diff = torch.mean(torch.abs(t1 - t2))
-             mean_t1 = torch.mean(torch.abs(t1))
-             relative_diff = mean_diff / (mean_t1 + 1e-8)
-             return relative_diff.item() < threshold
-
-     @staticmethod
-     def normalized_comparison(t1, t2, threshold=0.85):
-         with torch.no_grad():
-             t1_norm = (t1 - t1.mean()) / (t1.std() + 1e-8)
-             t2_norm = (t2 - t2.mean()) / (t2.std() + 1e-8)
-             diff = torch.mean(torch.abs(t1_norm - t2_norm))
-             return diff.item() < threshold
-
-     @staticmethod
-     def l1_comparison(t1, t2, threshold=0.85):
-         with torch.no_grad():
-             l1_dist = torch.nn.L1Loss()(t1, t2)
-             return l1_dist.item() < threshold
-
-     @staticmethod
-     def max_diff_comparison(t1, t2, threshold=0.85):
-         with torch.no_grad():
-             max_diff = torch.max(torch.abs(t1 - t2))
-             return max_diff.item() < threshold
-
- # Memory management
- class MemoryManager:
-     @staticmethod
-     def empty_cache():
-         gc.collect()
-         torch.cuda.empty_cache()
-         torch.cuda.reset_max_memory_allocated()
-         torch.cuda.reset_peak_memory_stats()
-
- # Pipeline management
- class PipelineManager:
-     @staticmethod
-     def load_pipeline() -> DiffusionPipeline:
-         MemoryManager.empty_cache()
-
-         text_encoder_2 = T5EncoderModel.from_pretrained(
-             "city96/t5-v1_1-xxl-encoder-bf16",
-             revision="1b9c856aadb864af93c1dcdc226c2774fa67bc86",
-             torch_dtype=Config.DTYPE
-         ).to(memory_format=torch.channels_last)
-         vae = AutoencoderTiny.from_pretrained("RobertML/FLUX.1-schnell-vae_e3m2", revision="da0d2cd7815792fb40d084dbd8ed32b63f153d8d", torch_dtype=Config.DTYPE)
-         # vae.encoder=_load(vae.encoder, "E", dtype=torch.bfloat16); vae.decoder=_load(vae.decoder, "D", dtype=torch.bfloat16)
-
-         path = os.path.join(HF_HUB_CACHE, "models--RobertML--FLUX.1-schnell-int8wo/snapshots/307e0777d92df966a3c0f99f31a6ee8957a9857a")
-         model = FluxTransformer2DModel.from_pretrained(
-             path,
-             torch_dtype=Config.DTYPE,
-             use_safetensors=False
-         ).to(memory_format=torch.channels_last)

          pipeline = DiffusionPipeline.from_pretrained(
-             Config.CKPT_ID,
-             vae=vae,
-             revision=Config.CKPT_REVISION,
-             transformer=model,
-             text_encoder_2=text_encoder_2,
-             torch_dtype=Config.DTYPE,
-         ).to(Config.DEVICE)
-
-         apply_cache_on_pipe(pipeline)
-         pipeline.to(memory_format=torch.channels_last)
-         pipeline.vae = torch.compile(pipeline.vae, mode="max-autotune")
-         quantize_(pipeline.vae, int8_weight_only())
-         quantize_(pipeline.vae, float8_weight_only())
-         PipelineManager._warmup(pipeline)

          return pipeline

-     @staticmethod
-     def _warmup(pipeline):
-         for _ in range(3):
-             pipeline(prompt=" ")

-     @staticmethod
-     @torch.no_grad()
-     def infer(request: TextToImageRequest, pipeline: DiffusionPipeline, generator: Generator) -> Image:
-         image = pipeline(
-             request.prompt,
              generator=generator,
              guidance_scale=0.0,
              num_inference_steps=4,
@@ -143,18 +332,4 @@ class PipelineManager:
              height=request.height,
              width=request.width,
              output_type="pil"
-         ).images[0]
-         return image
-
- # Initialize global settings
- init_global_settings()
-
- # Keep original interface
- load_pipeline = PipelineManager.load_pipeline
- infer = PipelineManager.infer
- are_two_tensors_similar = TensorComparator.orig_comparison
- are_two_tensors_similar_relative = TensorComparator.relative_comparison
- are_two_tensors_similar_normalized = TensorComparator.normalized_comparison
- are_two_tensors_similar_l1 = TensorComparator.l1_comparison
- are_two_tensors_similar_max_diff = TensorComparator.max_diff_comparison
- empty_cache = MemoryManager.empty_cache
+ # import os
+ # import gc
+ # import time
+ # import torch
+ # import torch.nn.functional as F
+ # from PIL import Image as img
+ # from PIL.Image import Image
+ # from typing import Optional, Type
+ # from dataclasses import dataclass
+
+ # from diffusers import (
+ # FluxTransformer2DModel,
+ # DiffusionPipeline,
+ # AutoencoderTiny
+ # )
+ # from transformers import T5EncoderModel
+ # from huggingface_hub.constants import HF_HUB_CACHE
+ # from torchao.quantization import quantize_, int8_weight_only, float8_weight_only
+ # from pipelines.models import TextToImageRequest
+ # from torch import Generator
+ # from para_attn.first_block_cache.diffusers_adapters import apply_cache_on_pipe
+
+ # # Configuration
+ # @dataclass
+ # class Config:
+ # CKPT_ID: str = "black-forest-labs/FLUX.1-schnell"
+ # CKPT_REVISION: str = "741f7c3ce8b383c54771c7003378a50191e9efe9"
+ # DEVICE: str = "cuda"
+ # DTYPE = torch.bfloat16
+ # PYTORCH_CUDA_ALLOC_CONF: str = "expandable_segments:True"
+
+ # # Initialize global settings
+ # def init_global_settings():
+ # torch.backends.cuda.matmul.allow_tf32 = True
+ # torch.backends.cudnn.enabled = True
+ # torch.backends.cudnn.benchmark = True
+ # os.environ['PYTORCH_CUDA_ALLOC_CONF'] = Config.PYTORCH_CUDA_ALLOC_CONF
+
+ # # Tensor comparison utilities
+ # class TensorComparator:
+ # @staticmethod
+ # def orig_comparison(t1, t2, *, threshold=0.85):
+ # mean_diff = (t1 - t2).abs().mean()
+ # mean_t1 = t1.abs().mean()
+ # diff = mean_diff / mean_t1
+ # return diff.item() < threshold
+
+ # @staticmethod
+ # def mse_comparison(t1, t2, threshold=0.95):
+ # mse = F.mse_loss(t1, t2)
+ # return mse.item() < threshold
+
+ # @staticmethod
+ # def relative_comparison(t1, t2, threshold=0.15):
+ # with torch.no_grad():
+ # mean_diff = torch.mean(torch.abs(t1 - t2))
+ # mean_t1 = torch.mean(torch.abs(t1))
+ # relative_diff = mean_diff / (mean_t1 + 1e-8)
+ # return relative_diff.item() < threshold
+
+ # @staticmethod
+ # def normalized_comparison(t1, t2, threshold=0.85):
+ # with torch.no_grad():
+ # t1_norm = (t1 - t1.mean()) / (t1.std() + 1e-8)
+ # t2_norm = (t2 - t2.mean()) / (t2.std() + 1e-8)
+ # diff = torch.mean(torch.abs(t1_norm - t2_norm))
+ # return diff.item() < threshold
+
+ # @staticmethod
+ # def l1_comparison(t1, t2, threshold=0.85):
+ # with torch.no_grad():
+ # l1_dist = torch.nn.L1Loss()(t1, t2)
+ # return l1_dist.item() < threshold
+
+ # @staticmethod
+ # def max_diff_comparison(t1, t2, threshold=0.85):
+ # with torch.no_grad():
+ # max_diff = torch.max(torch.abs(t1 - t2))
+ # return max_diff.item() < threshold
+
+ # # Memory management
+ # class MemoryManager:
+ # @staticmethod
+ # def empty_cache():
+ # gc.collect()
+ # torch.cuda.empty_cache()
+ # torch.cuda.reset_max_memory_allocated()
+ # torch.cuda.reset_peak_memory_stats()
+
+ # # Pipeline management
+ # class PipelineManager:
+ # @staticmethod
+ # def load_pipeline() -> DiffusionPipeline:
+ # MemoryManager.empty_cache()
+
+ # text_encoder_2 = T5EncoderModel.from_pretrained(
+ # "city96/t5-v1_1-xxl-encoder-bf16",
+ # revision="1b9c856aadb864af93c1dcdc226c2774fa67bc86",
+ # torch_dtype=Config.DTYPE
+ # ).to(memory_format=torch.channels_last)
+ # vae = AutoencoderTiny.from_pretrained("RobertML/FLUX.1-schnell-vae_e3m2", revision="da0d2cd7815792fb40d084dbd8ed32b63f153d8d", torch_dtype=Config.DTYPE)
+ # # vae.encoder=_load(vae.encoder, "E", dtype=torch.bfloat16); vae.decoder=_load(vae.decoder, "D", dtype=torch.bfloat16)
+
+ # path = os.path.join(HF_HUB_CACHE, "models--RobertML--FLUX.1-schnell-int8wo/snapshots/307e0777d92df966a3c0f99f31a6ee8957a9857a")
+ # model = FluxTransformer2DModel.from_pretrained(
+ # path,
+ # torch_dtype=Config.DTYPE,
+ # use_safetensors=False
+ # ).to(memory_format=torch.channels_last)
+
+ # pipeline = DiffusionPipeline.from_pretrained(
+ # Config.CKPT_ID,
+ # vae=vae,
+ # revision=Config.CKPT_REVISION,
+ # transformer=model,
+ # text_encoder_2=text_encoder_2,
+ # torch_dtype=Config.DTYPE,
+ # ).to(Config.DEVICE)
+
+ # apply_cache_on_pipe(pipeline)
+ # pipeline.to(memory_format=torch.channels_last)
+ # pipeline.vae = torch.compile(pipeline.vae, mode="max-autotune")
+ # quantize_(pipeline.vae, int8_weight_only())
+ # quantize_(pipeline.vae, float8_weight_only())
+ # PipelineManager._warmup(pipeline)
+
+ # return pipeline
+
+ # @staticmethod
+ # def _warmup(pipeline):
+ # for _ in range(3):
+ # pipeline(prompt=" ")
+
+ # @staticmethod
+ # @torch.no_grad()
+ # def infer(request: TextToImageRequest, pipeline: DiffusionPipeline, generator: Generator) -> Image:
+ # image = pipeline(
+ # request.prompt,
+ # generator=generator,
+ # guidance_scale=0.0,
+ # num_inference_steps=4,
+ # max_sequence_length=256,
+ # height=request.height,
+ # width=request.width,
+ # output_type="pil"
+ # ).images[0]
+ # return image
+
+ # # Initialize global settings
+ # init_global_settings()
+
+ # # Keep original interface
+ # load_pipeline = PipelineManager.load_pipeline
+ # infer = PipelineManager.infer
+ # are_two_tensors_similar = TensorComparator.orig_comparison
+ # are_two_tensors_similar_relative = TensorComparator.relative_comparison
+ # are_two_tensors_similar_normalized = TensorComparator.normalized_comparison
+ # are_two_tensors_similar_l1 = TensorComparator.l1_comparison
+ # are_two_tensors_similar_max_diff = TensorComparator.max_diff_comparison
+ # empty_cache = MemoryManager.empty_cache
+
+
+ from __future__ import annotations
  import os
  import torch
+ import functools
+ from enum import Enum, auto
+ from contextlib import contextmanager
+ from typing import Protocol, TypeVar, Generic, Callable, Any
+ from dataclasses import dataclass, field
  from PIL.Image import Image
+ from torch import Generator
+ from diffusers import DiffusionPipeline, AutoencoderTiny, FluxTransformer2DModel
  from transformers import T5EncoderModel
  from huggingface_hub.constants import HF_HUB_CACHE
  from torchao.quantization import quantize_, int8_weight_only, float8_weight_only
+ from first_block_cache.diffusers_adapters import apply_cache_on_pipe
  from pipelines.models import TextToImageRequest

+ T = TypeVar('T')
+
+ class ModelComponent(Protocol):
+     def to(self, *args, **kwargs) -> Any: ...
+
+ class ModelState(Enum):
+     INITIALIZED = auto()
+     LOADED = auto()
+     OPTIMIZED = auto()
+     READY = auto()
+
+ class ResourceMonitor:
+     """Monitors and manages system resources."""
+
+     @contextmanager
+     def monitor_memory(self, threshold_mb: int = 1000):
+         initial_memory = torch.cuda.memory_allocated() / 1024**2
+         yield
+         final_memory = torch.cuda.memory_allocated() / 1024**2
+         if final_memory - initial_memory > threshold_mb:
+             torch.cuda.empty_cache()
+
  @dataclass
+ class ModelRegistry(Generic[T]):
+     """Type-safe registry for model components."""
+     _components: dict[str, T] = field(default_factory=dict)

+     def register(self, name: str, component: T) -> None:
+         self._components[name] = component

+     def get(self, name: str) -> T:
+         return self._components[name]
+
+     def __iter__(self):
+         return iter(self._components.values())

+ class PipelineBuilder:
+     """Fluent builder for pipeline construction."""
+
+     def __init__(self):
+         self.config = {
+             "model_id": "black-forest-labs/FLUX.1-schnell",
+             "revision": "741f7c3ce8b383c54771c7003378a50191e9efe9",
+             "device": "cuda",
+             "dtype": torch.bfloat16
+         }
+         self.registry = ModelRegistry[ModelComponent]()
+         self.state = ModelState.INITIALIZED
+         self.monitor = ResourceMonitor()
+
+     def with_torch_settings(self) -> PipelineBuilder:
+         torch.backends.cuda.matmul.allow_tf32 = True
+         torch.backends.cudnn.enabled = True
+         torch.backends.cudnn.benchmark = True
+         os.environ['PYTORCH_CUDA_ALLOC_CONF'] = "expandable_segments:True"
+         return self
+
+     def load_text_encoder(self) -> PipelineBuilder:
+         with self.monitor.monitor_memory():
+             encoder = T5EncoderModel.from_pretrained(
+                 "city96/t5-v1_1-xxl-encoder-bf16",
+                 revision="1b9c856aadb864af93c1dcdc226c2774fa67bc86",
+                 torch_dtype=self.config["dtype"]
+             ).to(memory_format=torch.channels_last)
+             self.registry.register("text_encoder", encoder)
+         return self
+
+     def load_vae(self) -> PipelineBuilder:
+         with self.monitor.monitor_memory():
+             vae = AutoencoderTiny.from_pretrained(
+                 "RobertML/FLUX.1-schnell-vae_e3m2",
+                 revision="da0d2cd7815792fb40d084dbd8ed32b63f153d8d",
+                 torch_dtype=self.config["dtype"]
+             )
+             self.registry.register("vae", vae)
+         return self
+
+     def load_transformer(self) -> PipelineBuilder:
+         with self.monitor.monitor_memory():
+             path = os.path.join(
+                 HF_HUB_CACHE,
+                 "models--RobertML--FLUX.1-schnell-int8wo/snapshots/307e0777d92df966a3c0f99f31a6ee8957a9857a"
+             )
+             transformer = FluxTransformer2DModel.from_pretrained(
+                 path,
+                 torch_dtype=self.config["dtype"],
+                 use_safetensors=False
+             ).to(memory_format=torch.channels_last)
+             self.registry.register("transformer", transformer)
+         return self
+
+     def optimize(self, pipeline: DiffusionPipeline) -> PipelineBuilder:
+         with self.monitor.monitor_memory():
+             pipeline.to(memory_format=torch.channels_last)
+             pipeline.vae = torch.compile(pipeline.vae, mode="max-autotune")
+             quantize_(pipeline.vae, int8_weight_only())
+             quantize_(pipeline.vae, float8_weight_only())
+             apply_cache_on_pipe(pipeline)
+         return self
+
+     def warmup(self, pipeline: DiffusionPipeline) -> PipelineBuilder:
+         with torch.no_grad(), self.monitor.monitor_memory():
+             for _ in range(3):
+                 pipeline(prompt=" ")
+         return self
+
+     def build(self) -> DiffusionPipeline:
          pipeline = DiffusionPipeline.from_pretrained(
+             self.config["model_id"],
+             vae=self.registry.get("vae"),
+             revision=self.config["revision"],
+             transformer=self.registry.get("transformer"),
+             text_encoder_2=self.registry.get("text_encoder"),
+             torch_dtype=self.config["dtype"],
+         ).to(self.config["device"])

+         self.optimize(pipeline)
+         self.warmup(pipeline)
+         self.state = ModelState.READY
          return pipeline

+ class InferenceContext:
+     """Context manager for inference operations."""
+
+     def __init__(self, pipeline: DiffusionPipeline):
+         self.pipeline = pipeline
+         self.monitor = ResourceMonitor()
+
+     @contextmanager
+     def inference_mode(self):
+         with torch.no_grad(), self.monitor.monitor_memory():
+             yield self.pipeline
+
+ def load() -> DiffusionPipeline:
+     """Build and configure the pipeline using the fluent builder pattern."""
+     return (PipelineBuilder()
+             .with_torch_settings()
+             .load_text_encoder()
+             .load_vae()
+             .load_transformer()
+             .build())

+ def infer(request: TextToImageRequest, pipeline: DiffusionPipeline, generator: Generator) -> Image:
+     """Generate image using the pipeline within a managed context."""
+     context = InferenceContext(pipeline)
+     with context.inference_mode() as p:
+         return p(
+             prompt=request.prompt,
              generator=generator,
              guidance_scale=0.0,
              num_inference_steps=4,

              height=request.height,
              width=request.width,
              output_type="pil"
+         ).images[0]
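
For reviewers trying the change locally, here is a minimal usage sketch of the new module-level `load()` and `infer()` entry points introduced by this commit. The import path and the `TextToImageRequest` constructor arguments are assumptions for illustration; only the two function signatures and the request fields the pipeline reads (`prompt`, `height`, `width`) come from the diff itself.

```python
import torch

from pipelines.models import TextToImageRequest
from pipeline import load, infer  # assumed import path; adjust to the repo layout

# load() runs the fluent builder: torch settings, component loading,
# quantization/compile optimizations, and the 3-step warmup pass.
pipe = load()

# Constructor arguments are assumed; the diff only shows that infer()
# reads request.prompt, request.height and request.width.
request = TextToImageRequest(prompt="a watercolor fox in the snow", height=1024, width=1024)
generator = torch.Generator("cuda").manual_seed(0)

image = infer(request, pipe, generator)  # returns a PIL.Image
image.save("out.png")
```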