File size: 3,149 Bytes
5e2d8cd
 
 
 
 
 
d3a4811
 
 
 
5e2d8cd
 
 
 
 
 
 
 
d3a4811
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5e2d8cd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d3a4811
5e2d8cd
d3a4811
5e2d8cd
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import os
from typing import Dict, List, Any
import sys
rootDir = os.path.abspath(os.path.dirname(__file__))
sys.path.append(rootDir)
from imageRequest import ImageRequest
from diffusers import StableDiffusionXLPipeline, UNet2DConditionModel, EulerDiscreteScheduler
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
import torch

class EndpointHandler:
    """Custom inference handler that serves an SDXL-Lightning text-to-image pipeline.

    The handler keeps one pipeline in ``self.pipe`` and remembers the name of
    the base model it was built from in ``self.modelName`` so that repeated
    ``LoadModel`` calls for the same model are no-ops.
    """

    def __init__(self, path=""):
        """Preload the default pipeline so it is ready at inference time.

        Args:
            path: Accepted for compatibility with the handler calling
                convention; it is not used by this implementation.
        """
        self.pipe = None
        self.modelName = ""
        baseReq = ImageRequest()
        baseReq.model = "SG161222/RealVisXL_V4.0"
        self.LoadModel(baseReq)

    def LoadModel(self, request):
        """Build the pipeline for ``request.model`` unless it is already loaded.

        A ``request.model`` of ``"default"`` is rewritten in place to the
        stock SDXL base checkpoint. Any other value is used as the base model
        repository for both the UNet config and the pipeline weights.

        Args:
            request: Object exposing a writable ``model`` attribute.
        """
        base = "stabilityai/stable-diffusion-xl-base-1.0"
        repo = "ByteDance/SDXL-Lightning"
        # The checkpoint must match the number of inference steps used later.
        ckpt = "sdxl_lightning_8step_unet.safetensors"
        if request.model == "default":
            request.model = base
        else:
            base = request.model

        # Nothing to do when the requested model is the one already loaded.
        if self.pipe is not None and self.modelName == request.model:
            return

        # Build the UNet from the base model's config, then overwrite its
        # weights with the Lightning-distilled checkpoint.
        unet = UNet2DConditionModel.from_config(base, subfolder="unet").to("cuda", torch.float16)
        unet.load_state_dict(load_file(hf_hub_download(repo, ckpt), device="cuda"))
        pipe = StableDiffusionXLPipeline.from_pretrained(
            base, unet=unet, torch_dtype=torch.float16, variant="fp16"
        ).to("cuda")

        # Ensure sampler uses "trailing" timesteps.
        pipe.scheduler = EulerDiscreteScheduler.from_config(
            pipe.scheduler.config, timestep_spacing="trailing"
        )
        self.pipe = pipe
        # Record what was loaded; without this the check above never matches
        # and every call would rebuild the pipeline.
        self.modelName = request.model

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Handle one inference request.

        Args:
            data: Request payload. If it has an ``"inputs"`` key, that value
                is handed to ``ImageRequest.FromDict``; otherwise the whole
                payload is. The original documentation lists these fields:
                ``input``, ``seed``, ``prompt``, ``negative_prompt``,
                ``steps``, ``guidance_scale``, ``width``, ``height``.
                NOTE: this class only forwards ``prompt``,
                ``negative_prompt`` and ``steps`` to the pipeline.

        Returns:
            A dict ``{"media": [...]}`` of base64-encoded PNG strings.
        """
        # Read without mutating the caller's dict.
        inputs = data.get("inputs", data)
        request = ImageRequest.FromDict(inputs)
        return self.__runProcess__(request)

    def ImageToBase64(self, image):
        """Encode ``image`` as PNG and return the bytes as a base64 string.

        Args:
            image: Any object with a PIL-style ``save(fp, format=...)`` method.

        Returns:
            The base64 text (``str``) of the PNG-encoded image.
        """
        import io
        import base64

        buffered = io.BytesIO()
        image.save(buffered, format="PNG")
        return base64.b64encode(buffered.getvalue()).decode()

    def __runProcess__(self, request: "ImageRequest") -> Dict[str, Any]:
        """Run the loaded SDXL Lightning pipeline for ``request``.

        Args:
            request: Object exposing ``prompt``, ``negative_prompt`` and
                ``steps``.

        Returns:
            A dict ``{"media": [...]}`` with one base64 PNG string per
            generated image.
        """
        # Ensure using the same inference steps as the loaded model and CFG set to 0.
        # The pipeline lives on the instance; a bare `pipe` is undefined here.
        images = self.pipe(
            request.prompt,
            negative_prompt=request.negative_prompt,
            num_inference_steps=request.steps,
            guidance_scale=0,
        ).images

        return {"media": [self.ImageToBase64(img) for img in images]}