## Setup & Installation

In [1]:
%%writefile requirements.txt
diffusers==0.2.4

Overwriting requirements.txt


In [None]:
!pip install -r requirements.txt --upgrade

## 3. Create Custom Handler for Inference Endpoints


In [10]:
import torch

# resolve the compute device once: CUDA when torch can see a GPU, otherwise CPU
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
device

device(type='cuda')

In [11]:
# stop early unless a CUDA device was selected; handler.py (written further down)
# performs the same check when it is imported
if device.type != 'cuda':
    raise ValueError("need to run on GPU")

In [5]:
%%writefile handler.py
from typing import  Dict, List, Any
import torch
from torch import autocast
from diffusers import StableDiffusionPipeline
import base64
from io import BytesIO


# set device: CUDA when torch can see a GPU, otherwise CPU
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# importing this module fails on purpose when no GPU was selected
if device.type != 'cuda':
    raise ValueError("need to run on GPU")

class EndpointHandler():
    """Custom Inference Endpoints handler wrapping a ``StableDiffusionPipeline``.

    The pipeline is loaded once in ``__init__`` and reused for every call;
    each call turns the request's inputs into a base64-encoded JPEG.
    """

    def __init__(self, path=""):
        """Load the pipeline from ``path`` and move it to the module-level ``device``.

        Args:
            path (:obj:`str`):
                forwarded to ``StableDiffusionPipeline.from_pretrained``.
        """
        # load the weights as float16
        self.pipe = StableDiffusionPipeline.from_pretrained(path, torch_dtype=torch.float16)
        self.pipe = self.pipe.to(device)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, str]:
        """Generate one image for the request and return it base64-encoded.

        Args:
            data (:obj:`dict`):
                request payload. The prompt is read from the ``"inputs"`` key;
                when that key is absent, ``data`` itself is handed to the pipeline.
                The payload is not modified.
        Return:
            A :obj:`dict` of the form ``{"image": <str>}`` whose value is the
            generated image, saved as JPEG and then base64-encoded.
        """
        # read (do not pop) the prompt so the caller's payload stays intact
        inputs = data.get("inputs", data)

        # run inference pipeline
        with autocast(device.type):
            image = self.pipe(inputs, guidance_scale=7.5)["sample"][0]

        # encode image as base 64
        buffered = BytesIO()
        image.save(buffered, format="JPEG")
        img_str = base64.b64encode(buffered.getvalue())

        # postprocess the prediction: bytes -> str so the result is JSON-friendly
        return {"image": img_str.decode()}

Overwriting handler.py


Test the custom handler locally:

In [6]:
import torch

# display the installed torch version
torch.__version__

'1.11.0+cu113'

In [1]:
from handler import EndpointHandler

# init handler; path="." is forwarded to StableDiffusionPipeline.from_pretrained
my_handler = EndpointHandler(path=".")

ftfy or spacy is not installed using BERT BasicTokenizer instead of ftfy.


In [6]:
import base64
from PIL import Image
from io import BytesIO
import json

# helper decoder
def decode_base64_image(image_string):
    """Base64-decode ``image_string`` and open the resulting bytes with ``Image.open``."""
    raw_bytes = base64.b64decode(image_string)
    return Image.open(BytesIO(raw_bytes))

# prepare sample payload
# NOTE(review): "resulotion" looks like a typo for "resolution" — confirm before
# changing it, since the string is the prompt sent to the model
request = {"inputs": "a high resulotion image of a macbook"}

# test the handler; the returned dict carries the base64 string under "image"
pred = my_handler(request)

0it [00:00, ?it/s]

In [4]:
# decode the base64 string back into an image and write it to sample.jpg
decode_base64_image(pred["image"]).save("sample.jpg")

![test](sample.jpg)