# Flux.1-schnell text-to-image pipeline: a pre-quantized transformer is loaded
# from a local Hugging Face cache snapshot, first-block caching is applied via
# para_attn, and the VAE weights are quantized to int4 with torchao.
from diffusers import FluxPipeline, FluxTransformer2DModel
from huggingface_hub.constants import HF_HUB_CACHE
import torch
import torch._dynamo
import gc
import os
from PIL import Image
from pipelines.models import TextToImageRequest
from torch import Generator
from torchao.quantization import quantize_, int4_weight_only
from para_attn.first_block_cache.diffusers_adapters import apply_cache_on_pipe as cacher

# Allocator and runtime tuning. PYTORCH_CUDA_ALLOC_CONF is read lazily, so
# setting it here (before the first CUDA allocation) still takes effect.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
os.environ["TOKENIZERS_PARALLELISM"] = "True"
torch._dynamo.config.suppress_errors = True
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True
torch.cuda.set_per_process_memory_fraction(0.99)

# globals
Pipeline = FluxPipeline  # type alias used in the signatures below
ckpt_id = "manbeast3b/flux.1-schnell-full1"
ckpt_revision = "cb1b599b0d712b9aab2c4df3ad27b050a27ec146"


def load_pipeline() -> Pipeline:
    # The transformer is loaded directly from its snapshot directory in the
    # local HF hub cache rather than via a repo id, so the exact revision must
    # already be downloaded.
    model_name = "manbeast3b/Flux.1.Schnell-full-quant1"
    revision = "e7ddf488a4ea8a3cba05db5b8d06e7e0feb826a2"
    hub_model_dir = os.path.join(
        HF_HUB_CACHE,
        f"models--{model_name.replace('/', '--')}",
        "snapshots",
        revision,
        "transformer",
    )
    transformer = FluxTransformer2DModel.from_pretrained(
        hub_model_dir,
        torch_dtype=torch.bfloat16,
        use_safetensors=False,
    ).to(memory_format=torch.channels_last)

    pipeline = FluxPipeline.from_pretrained(
        ckpt_id,
        revision=ckpt_revision,
        transformer=transformer,
        torch_dtype=torch.bfloat16,
    )
    pipeline.to("cuda")

    # First-block caching: reuse a transformer step's output when its residual
    # differs from the previous step's by less than the threshold.
    pipeline = cacher(pipeline, residual_diff_threshold=0.56)

    # Quantize the VAE weights to int4 to cut memory use.
    quantize_(pipeline.vae, int4_weight_only())

    # One warmup pass to populate caches and trigger cudnn autotuning.
    warmup_prompt = "controllable varied focus thai warriors entertainment claude still goat gang gang yeah"
    pipeline(
        prompt=warmup_prompt,
        width=1024,
        height=1024,
        guidance_scale=0.0,
        num_inference_steps=4,
        max_sequence_length=256,
    )
    return pipeline


sample = 1


@torch.no_grad()
def infer(request: TextToImageRequest, pipeline: Pipeline, generator: Generator) -> Image.Image:
    # Memory-cleanup guard: the block below only runs if `sample` has been
    # cleared elsewhere; with `sample` initialized to 1 it is a no-op.
    global sample
    if not sample:
        sample = 1
        gc.collect()
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()  # reset_max_memory_allocated() is a deprecated alias

    return pipeline(
        request.prompt,
        generator=generator,
        guidance_scale=0.0,
        num_inference_steps=4,
        max_sequence_length=256,
        height=request.height,
        width=request.width,
        output_type="pil",
    ).images[0]
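

# --- Usage sketch (assumption): shows how load_pipeline/infer fit together.
# TextToImageRequest's constructor arguments are assumed to mirror the fields
# read in infer() above (prompt/height/width); a CUDA GPU and the cached model
# snapshots are required for this to run.
if __name__ == "__main__":
    pipe = load_pipeline()
    request = TextToImageRequest(  # hypothetical constructor call
        prompt="a watercolor fox in a snowy forest",
        height=1024,
        width=1024,
    )
    generator = Generator("cuda").manual_seed(0)
    image = infer(request, pipe, generator)
    image.save("sample.png")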