|
import time
from io import BytesIO

import litserve as ls
import torch
from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler
from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
from diffusers.pipelines.flux.pipeline_flux import FluxPipeline
from fastapi import HTTPException, Response
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
|
|
|
class FluxLitAPI(ls.LitAPI):
    """LitServe API that turns a text prompt into a PNG image with FLUX.1-schnell.

    Request body: a JSON object with a string ``"prompt"`` field.
    Response: the generated image as ``image/png`` bytes.

    Model sources and generation settings are class attributes, so a subclass
    can override them without rewriting ``setup`` or ``predict``.
    """

    # Hugging Face Hub sources for the pipeline components.
    FLUX_REPO = "black-forest-labs/FLUX.1-schnell"
    FLUX_REVISION = "refs/pr/1"
    CLIP_REPO = "openai/clip-vit-large-patch14"

    # Generation settings (defaults match the values this server always used).
    IMAGE_WIDTH = 1024
    IMAGE_HEIGHT = 1024
    NUM_INFERENCE_STEPS = 4
    # NOTE(review): FLUX.1-schnell is guidance-distilled; diffusers only feeds this
    # to the transformer when transformer.config.guidance_embeds is set — confirm
    # whether it has any effect for this checkpoint.
    GUIDANCE_SCALE = 3.5

    def setup(self, device):
        """Load every pipeline component and build ``self.pipe``.

        Args:
            device: Device LitServe assigned to this worker (e.g. ``"cuda:0"``).
                It is forwarded to ``enable_model_cpu_offload`` so each worker
                offloads onto its own accelerator rather than the default one.
        """
        dtype = torch.bfloat16

        def from_flux(component_cls, subfolder, **kwargs):
            # Load one component from the pinned FLUX repo revision.
            return component_cls.from_pretrained(
                self.FLUX_REPO,
                subfolder=subfolder,
                revision=self.FLUX_REVISION,
                **kwargs,
            )

        scheduler = from_flux(FlowMatchEulerDiscreteScheduler, "scheduler")
        text_encoder = CLIPTextModel.from_pretrained(self.CLIP_REPO, torch_dtype=dtype)
        # Tokenizers hold no tensors, so no torch_dtype is passed to them.
        tokenizer = CLIPTokenizer.from_pretrained(self.CLIP_REPO)
        text_encoder_2 = from_flux(T5EncoderModel, "text_encoder_2", torch_dtype=dtype)
        tokenizer_2 = from_flux(T5TokenizerFast, "tokenizer_2")
        vae = from_flux(AutoencoderKL, "vae", torch_dtype=dtype)
        transformer = from_flux(FluxTransformer2DModel, "transformer", torch_dtype=dtype)

        # Register every component at construction time instead of passing None
        # and patching the attributes on afterwards.
        self.pipe = FluxPipeline(
            scheduler=scheduler,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            text_encoder_2=text_encoder_2,
            tokenizer_2=tokenizer_2,
            vae=vae,
            transformer=transformer,
        )
        # Per diffusers docs, sub-models stay on CPU and are moved to `device`
        # only while they run, trading speed for lower accelerator memory.
        self.pipe.enable_model_cpu_offload(device=device)

    def decode_request(self, request):
        """Return the ``"prompt"`` string from the decoded JSON request body.

        Raises:
            HTTPException: 400 when the body is not an object with a string
                ``"prompt"`` field (a missing key used to surface as a KeyError).
        """
        prompt = request.get("prompt") if isinstance(request, dict) else None
        if not isinstance(prompt, str):
            raise HTTPException(
                status_code=400,
                detail='Request body must be a JSON object with a string "prompt" field.',
            )
        return prompt

    def predict(self, prompt):
        """Generate one image for ``prompt`` and return it.

        Returns the first entry of the pipeline output's ``images`` (a PIL image
        under the pipeline's default output type).

        A fresh non-deterministic seed is drawn on every call. Seeding from
        ``int(time.time())`` gave every request within the same wall-clock
        second identical initial noise.
        """
        generator = torch.Generator()
        generator.seed()  # non-deterministic seed (std::random_device, clock fallback)
        output = self.pipe(
            prompt=prompt,
            width=self.IMAGE_WIDTH,
            height=self.IMAGE_HEIGHT,
            num_inference_steps=self.NUM_INFERENCE_STEPS,
            generator=generator,
            guidance_scale=self.GUIDANCE_SCALE,
        )
        return output.images[0]

    def encode_response(self, image):
        """Serialize ``image`` to PNG bytes and wrap them in an HTTP response."""
        with BytesIO() as buffer:
            image.save(buffer, format="PNG")
            png_bytes = buffer.getvalue()
        return Response(content=png_bytes, media_type="image/png")
|
|
|
if __name__ == "__main__":
    # Script entry point: build the API and serve it over HTTP on port 8000.
    # timeout=False — per LitServe's docs this turns off the per-request timeout,
    # which slow image generation could otherwise hit; confirm for the pinned version.
    flux_api = FluxLitAPI()
    ls.LitServer(flux_api, timeout=False).run(port=8000)