import os
from typing import Dict, List, Any
import sys

rootDir = os.path.abspath(os.path.dirname(__file__))
# Make modules next to this file (e.g. imageRequest) importable.
sys.path.append(rootDir)

from imageRequest import ImageRequest
from diffusers import StableDiffusionXLPipeline, UNet2DConditionModel, EulerDiscreteScheduler
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
import torch


class EndpointHandler:
    """Serve an SDXL pipeline whose UNet weights come from SDXL-Lightning.

    The pipeline is built once in ``__init__`` (for "SG161222/RealVisXL_V4.0")
    and reused for every request.
    """

    def __init__(self, path=""):
        """Build the default pipeline.

        Args:
            path: Currently unused; models are downloaded from the Hugging Face
                Hub by repo id instead of being loaded from this directory.
        """
        self.pipe = None  # StableDiffusionXLPipeline once LoadModel has run
        self.modelName = ""  # repo id of the model currently held in self.pipe
        baseReq = ImageRequest()
        baseReq.model = "SG161222/RealVisXL_V4.0"
        self.LoadModel(baseReq)

    def LoadModel(self, request):
        """Load ``request.model`` with the SDXL-Lightning 8-step UNet, unless already loaded.

        ``request.model == "default"`` is mapped to the SDXL base repo, and the
        request is updated in place to that repo id. Any other value is used as
        the repo id for both the pipeline and the UNet config.

        All weights are placed on "cuda" in float16.
        """
        base = "stabilityai/stable-diffusion-xl-base-1.0"
        repo = "ByteDance/SDXL-Lightning"
        ckpt = "sdxl_lightning_8step_unet.safetensors"  # Use the correct ckpt for your step setting!

        if request.model == "default":
            request.model = base
        else:
            base = request.model

        if self.pipe is not None and self.modelName == request.model:
            return  # requested model is already loaded

        # UNet architecture comes from the base model; its weights come from SDXL-Lightning.
        unet = UNet2DConditionModel.from_config(base, subfolder="unet").to("cuda", torch.float16)
        unet.load_state_dict(load_file(hf_hub_download(repo, ckpt), device="cuda"))
        pipe = StableDiffusionXLPipeline.from_pretrained(
            base, unet=unet, torch_dtype=torch.float16, variant="fp16"
        ).to("cuda")
        # Ensure sampler uses "trailing" timesteps.
        pipe.scheduler = EulerDiscreteScheduler.from_config(
            pipe.scheduler.config, timestep_spacing="trailing"
        )
        self.pipe = pipe
        # Record what is loaded so a repeat request for the same model is a no-op
        # (previously never set, so the cache check above could never succeed).
        self.modelName = request.model

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Handle one inference request.

        Args:
            data: Request payload. If it has an ``"inputs"`` key, that value is
                popped from ``data`` and parsed; otherwise ``data`` itself is
                parsed, via ``ImageRequest.FromDict``. Of the request fields,
                only ``prompt``, ``negative_prompt`` and ``steps`` are passed to
                the pipeline. ``seed``, ``guidance_scale``, ``width`` and
                ``height`` are currently not forwarded.

        Returns:
            ``{"media": [<base64 PNG string>, ...]}``, one entry per generated image.
        """
        inputs = data.pop("inputs", data)
        request = ImageRequest.FromDict(inputs)
        response = self.__runProcess__(request)
        return response

    def ImageToBase64(self, image):
        """Encode ``image`` (anything with a PIL-style ``save``) as a base64 PNG string."""
        import base64
        import io

        buffered = io.BytesIO()
        image.save(buffered, format="PNG")
        return base64.b64encode(buffered.getvalue()).decode()

    def __runProcess__(self, request: ImageRequest) -> Dict[str, Any]:
        """Run the SDXL-Lightning pipeline for ``request`` and base64-encode the images."""
        # Ensure using the same inference steps as the loaded model and CFG set to 0.
        # NOTE(review): request.steps should match the 8-step checkpoint loaded in
        # LoadModel, and in diffusers negative_prompt only takes effect when
        # guidance_scale > 1, so it is likely ignored here — confirm.
        images = self.pipe(
            request.prompt,
            negative_prompt=request.negative_prompt,
            num_inference_steps=request.steps,
            guidance_scale=0,
        ).images
        return {"media": [self.ImageToBase64(img) for img in images]}