from io import BytesIO
import time

import torch
import litserve as ls
from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler
from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
from diffusers.pipelines.flux.pipeline_flux import FluxPipeline
from fastapi import Response
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast

class FluxLitAPI(ls.LitAPI):
    def setup(self, device):
        # Load each FLUX.1-schnell component in bfloat16 to roughly halve memory use.
        scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained("black-forest-labs/FLUX.1-schnell", subfolder="scheduler", revision="refs/pr/1")
        text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14", torch_dtype=torch.bfloat16)
        # Tokenizers carry no weights, so torch_dtype does not apply to them.
        tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
        text_encoder_2 = T5EncoderModel.from_pretrained("black-forest-labs/FLUX.1-schnell", subfolder="text_encoder_2", torch_dtype=torch.bfloat16, revision="refs/pr/1")
        tokenizer_2 = T5TokenizerFast.from_pretrained("black-forest-labs/FLUX.1-schnell", subfolder="tokenizer_2", revision="refs/pr/1")
        vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-schnell", subfolder="vae", torch_dtype=torch.bfloat16, revision="refs/pr/1")
        transformer = FluxTransformer2DModel.from_pretrained("black-forest-labs/FLUX.1-schnell", subfolder="transformer", torch_dtype=torch.bfloat16, revision="refs/pr/1")

        # Assemble the pipeline with the two largest components left out, then
        # attach them afterwards; together with the CPU offload below, this
        # avoids having to fit the full model in GPU memory at once.
        self.pipe = FluxPipeline(
            scheduler=scheduler,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            text_encoder_2=None,
            tokenizer_2=tokenizer_2,
            vae=vae,
            transformer=None,
        )
        self.pipe.text_encoder_2 = text_encoder_2
        self.pipe.transformer = transformer

        # Offload idle components to CPU between forward passes
        # (requires the accelerate package).
        self.pipe.enable_model_cpu_offload()

    def decode_request(self, request):
        # Pull the text prompt out of the JSON request body.
        prompt = request["prompt"]
        return prompt

    def predict(self, prompt):
        # Generate a 1024x1024 image; 4 inference steps suits the
        # timestep-distilled schnell variant. Seeding the generator from the
        # clock makes each request non-deterministic.
        image = self.pipe(
            prompt=prompt,
            width=1024,
            height=1024,
            num_inference_steps=4,
            generator=torch.Generator().manual_seed(int(time.time())),
            guidance_scale=3.5,
        ).images[0]
        return image

    def encode_response(self, image):
        # Serialize the PIL image to PNG and return the bytes as the HTTP body.
        buffered = BytesIO()
        image.save(buffered, format="PNG")
        return Response(content=buffered.getvalue(), media_type="image/png")

if __name__ == "__main__":
    api = FluxLitAPI()
    # timeout=False disables LitServe's per-request timeout; a cold start plus
    # a full generation can exceed the default limit.
    server = ls.LitServer(api, timeout=False)
    server.run(port=8000)
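
# A minimal client sketch, assuming the server above is running locally and
# the `requests` package is installed. LitServe serves the API at /predict by
# default; the prompt text and output filename here are just examples.
#
#   import requests
#
#   resp = requests.post(
#       "http://localhost:8000/predict",
#       json={"prompt": "a red fox in the snow"},
#   )
#   resp.raise_for_status()
#   with open("out.png", "wb") as f:
#       f.write(resp.content)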