randydev committed · Commit 37b42d2 · verified · 1 Parent(s): 19600bf

Create server.py

Files changed (1): server.py +69 -0
server.py ADDED
@@ -0,0 +1,69 @@
+ from io import BytesIO
+ from fastapi import Response
+ import torch
+ import time
+ import litserve as ls
+ from optimum.quanto import freeze, qfloat8, quantize
+ from diffusers import FlowMatchEulerDiscreteScheduler, AutoencoderKL
+ from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
+ from diffusers.pipelines.flux.pipeline_flux import FluxPipeline
+ from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
+
+ class FluxLitAPI(ls.LitAPI):
+     def setup(self, device):
+         # Load the FLUX.1-schnell components individually so the two largest can be quantized before assembly
+         scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained("black-forest-labs/FLUX.1-schnell", subfolder="scheduler", revision="refs/pr/1")
+         text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14", torch_dtype=torch.bfloat16)
+         tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
+         text_encoder_2 = T5EncoderModel.from_pretrained("black-forest-labs/FLUX.1-schnell", subfolder="text_encoder_2", torch_dtype=torch.bfloat16, revision="refs/pr/1")
+         tokenizer_2 = T5TokenizerFast.from_pretrained("black-forest-labs/FLUX.1-schnell", subfolder="tokenizer_2", revision="refs/pr/1")
+         vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-schnell", subfolder="vae", torch_dtype=torch.bfloat16, revision="refs/pr/1")
+         transformer = FluxTransformer2DModel.from_pretrained("black-forest-labs/FLUX.1-schnell", subfolder="transformer", torch_dtype=torch.bfloat16, revision="refs/pr/1")
+
+         # Quantize the transformer and T5 encoder to 8-bit so the model fits on an L4
+         quantize(transformer, weights=qfloat8)
+         freeze(transformer)
+         quantize(text_encoder_2, weights=qfloat8)
+         freeze(text_encoder_2)
+
+         # Assemble the pipeline; the quantized modules are attached after construction
+         self.pipe = FluxPipeline(
+             scheduler=scheduler,
+             text_encoder=text_encoder,
+             tokenizer=tokenizer,
+             text_encoder_2=None,
+             tokenizer_2=tokenizer_2,
+             vae=vae,
+             transformer=None,
+         )
+         self.pipe.text_encoder_2 = text_encoder_2
+         self.pipe.transformer = transformer
+         self.pipe.enable_model_cpu_offload()
+
+     def decode_request(self, request):
+         # Extract the prompt from the JSON request body
+         prompt = request["prompt"]
+         return prompt
+
+     def predict(self, prompt):
+         # Generate a 1024x1024 image in 4 inference steps, seeded from the current time
+         image = self.pipe(
+             prompt=prompt,
+             width=1024,
+             height=1024,
+             num_inference_steps=4,
+             generator=torch.Generator().manual_seed(int(time.time())),
+             guidance_scale=3.5,
+         ).images[0]
+         return image
+
+     def encode_response(self, image):
+         # Serialize the PIL image to PNG bytes and return it as the HTTP response
+         buffered = BytesIO()
+         image.save(buffered, format="PNG")
+         return Response(content=buffered.getvalue(), headers={"Content-Type": "image/png"})
+
+ if __name__ == "__main__":
+     api = FluxLitAPI()
+     server = ls.LitServer(api, timeout=False)  # timeout=False disables the per-request timeout for long generations
+     server.run(port=8000)
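
Once the server is running, LitServe exposes the API at its default /predict route, so a client can POST a JSON body with a "prompt" key and write the returned PNG bytes to disk. A minimal client sketch, assuming the server is reachable on localhost:8000 (the prompt and output filename are illustrative):

import requests

# POST a prompt to the LitServe endpoint (default route: /predict)
response = requests.post(
    "http://localhost:8000/predict",
    json={"prompt": "a watercolor lighthouse at dawn"},
    timeout=300,
)
response.raise_for_status()

# The server replies with raw PNG bytes rather than JSON; save them to a file
with open("output.png", "wb") as f:
    f.write(response.content)

The top-level "prompt" key matches what decode_request reads, and the raw-bytes handling matches encode_response, which returns the image directly with a Content-Type of image/png.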