Boese0601 commited on
Commit
2a887f4
·
verified ·
1 Parent(s): 80c1007

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -1
app.py CHANGED
@@ -14,7 +14,7 @@ from src.flux.xflux_pipeline import XFluxSampler
14
  args = OmegaConf.load("inference_configs/inference.yaml")
15
  # is_schnell = args.model_name == "flux-schnell"
16
  # sampler = None
17
- # device = torch.device("cuda")
18
  # dtype = torch.bfloat16
19
  # dit = load_flow_model2(args.model_name, device="cpu").to(device, dtype=dtype)
20
  # vae = load_ae(args.model_name, device="cpu").to(device, dtype=dtype)
@@ -91,6 +91,7 @@ def generate(image: Image.Image, edit_prompt: str):
91
  img = img.permute(2, 0, 1).unsqueeze(0).to(device, dtype=dtype)
92
 
93
  result = sampler(
 
94
  prompt=edit_prompt,
95
  width=args.sample_width,
96
  height=args.sample_height,
 
14
  args = OmegaConf.load("inference_configs/inference.yaml")
15
  # is_schnell = args.model_name == "flux-schnell"
16
  # sampler = None
17
+ device = torch.device("cuda")
18
  # dtype = torch.bfloat16
19
  # dit = load_flow_model2(args.model_name, device="cpu").to(device, dtype=dtype)
20
  # vae = load_ae(args.model_name, device="cpu").to(device, dtype=dtype)
 
91
  img = img.permute(2, 0, 1).unsqueeze(0).to(device, dtype=dtype)
92
 
93
  result = sampler(
94
+ device='cuda',
95
  prompt=edit_prompt,
96
  width=args.sample_width,
97
  height=args.sample_height,