randydev committed
Commit 944d63d · verified · 1 Parent(s): 8d05188

Update server.py

Files changed (1):
  1. server.py +8 -58
server.py CHANGED
@@ -1,62 +1,12 @@
- from io import BytesIO
- from fastapi import Response
- import torch
- import time
- import litserve as ls
- from diffusers import FlowMatchEulerDiscreteScheduler, AutoencoderKL
- from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
- from diffusers.pipelines.flux.pipeline_flux import FluxPipeline
- from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
+ import os
+ import uvicorn
+ from fastapi import FastAPI

- class FluxLitAPI(ls.LitAPI):
-     def setup(self, device):
-         # Load the model
-         scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained("black-forest-labs/FLUX.1-schnell", subfolder="scheduler", revision="refs/pr/1")
-         text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14", torch_dtype=torch.bfloat16)
-         tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14", torch_dtype=torch.bfloat16)
-         text_encoder_2 = T5EncoderModel.from_pretrained("black-forest-labs/FLUX.1-schnell", subfolder="text_encoder_2", torch_dtype=torch.bfloat16, revision="refs/pr/1")
-         tokenizer_2 = T5TokenizerFast.from_pretrained("black-forest-labs/FLUX.1-schnell", subfolder="tokenizer_2", torch_dtype=torch.bfloat16, revision="refs/pr/1")
-         vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-schnell", subfolder="vae", torch_dtype=torch.bfloat16, revision="refs/pr/1")
-         transformer = FluxTransformer2DModel.from_pretrained("black-forest-labs/FLUX.1-schnell", subfolder="transformer", torch_dtype=torch.bfloat16, revision="refs/pr/1")
+ app = FastAPI()

-         self.pipe = FluxPipeline(
-             scheduler=scheduler,
-             text_encoder=text_encoder,
-             tokenizer=tokenizer,
-             text_encoder_2=None,
-             tokenizer_2=tokenizer_2,
-             vae=vae,
-             transformer=None,
-         )
-         self.pipe.text_encoder_2 = text_encoder_2
-         self.pipe.transformer = transformer
-         self.pipe.enable_model_cpu_offload()
+ @app.get("/")
+ def hello():
+     return {"message": "running"}

-
-     def decode_request(self, request):
-         # Extract prompt from request
-         prompt = request["prompt"]
-         return prompt
-
-     def predict(self, prompt):
-         # Generate image from prompt
-         image = self.pipe(
-             prompt=prompt,
-             width=1024,
-             height=1024,
-             num_inference_steps=4,
-             generator=torch.Generator().manual_seed(int(time.time())),
-             guidance_scale=3.5,
-         ).images[0]
-
-         return image
-
-     def encode_response(self, image):
-         buffered = BytesIO()
-         image.save(buffered, format="PNG")
-         return Response(content=buffered.getvalue(), headers={"Content-Type": "image/png"})
-
  if __name__ == "__main__":
-     api = FluxLitAPI()
-     server = ls.LitServer(api, timeout=False)
-     server.run(port=8000)
+     uvicorn.run(app, host="0.0.0.0", port=7860)
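
Not part of the commit, but a minimal smoke-test sketch against the updated server, assuming it is running locally via "python server.py" and that the third-party requests package is installed:

import requests

# The updated server exposes only GET / (see the diff above),
# which answers with the JSON body {"message": "running"}.
resp = requests.get("http://127.0.0.1:7860/")
resp.raise_for_status()
print(resp.json())  # expected: {'message': 'running'}

For contrast, the removed FluxLitAPI would have been exercised roughly as below; POST /predict is litserve's default route and port 8000 comes from the old server.run(port=8000), but both are assumptions about how that server was actually deployed:

# Hypothetical client for the removed LitServe image endpoint.
resp = requests.post(
    "http://127.0.0.1:8000/predict",
    json={"prompt": "a watercolor fox"},  # decode_request read request["prompt"]
)
with open("out.png", "wb") as f:
    f.write(resp.content)  # encode_response returned raw PNG bytes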