Boese0601 commited on
Commit
8373499
·
verified ·
1 Parent(s): ba8544d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -26
app.py CHANGED
@@ -12,16 +12,14 @@ from image_datasets.dataset import image_resize
12
  from src.flux.util import load_ae, load_clip, load_flow_model2, load_t5, tensor_to_pil_image
13
  from src.flux.xflux_pipeline import XFluxSampler
14
  args = OmegaConf.load("inference_configs/inference.yaml")
15
- is_schnell = args.model_name == "flux-schnell"
16
- '/home/user/app/assets/0_camera_zoom/20486354.png'
17
- '/home/user/app/assets/0_camera_zoom/20486354.png'
18
  # sampler = None
19
- device = torch.device("cuda")
20
- dtype = torch.bfloat16
21
- dit = load_flow_model2(args.model_name, device="cpu").to(device, dtype=dtype)
22
- vae = load_ae(args.model_name, device="cpu").to(device, dtype=dtype)
23
- t5 = load_t5(device="cpu", max_length=256 if is_schnell else 512).to(device, dtype=dtype)
24
- clip = load_clip("cpu").to(device, dtype=dtype)
25
  #test push
26
  @spaces.GPU
27
  def generate(image: Image.Image, edit_prompt: str):
@@ -29,26 +27,21 @@ def generate(image: Image.Image, edit_prompt: str):
29
 
30
 
31
 
32
- vae.requires_grad_(False)
33
- t5.requires_grad_(False)
34
- clip.requires_grad_(False)
35
 
36
- model_path = hf_hub_download(
37
- repo_id="Boese0601/ByteMorpher",
38
- filename="dit.safetensors",
39
- use_auth_token=os.getenv("HF_TOKEN")
40
- )
41
- state_dict = load_file(model_path)
42
- dit.load_state_dict(state_dict)
43
- dit.eval()
44
- dit.to(device, dtype=dtype)
45
 
46
  sampler = XFluxSampler(
47
- clip=clip,
48
- t5=t5,
49
- ae=vae,
50
- model=dit,
51
- device=device,
52
  ip_loaded=False,
53
  spatial_condition=False,
54
  clip_image_processor=None,
 
12
  from src.flux.util import load_ae, load_clip, load_flow_model2, load_t5, tensor_to_pil_image
13
  from src.flux.xflux_pipeline import XFluxSampler
14
  args = OmegaConf.load("inference_configs/inference.yaml")
15
+ # is_schnell = args.model_name == "flux-schnell"
 
 
16
  # sampler = None
17
+ # device = torch.device("cuda")
18
+ # dtype = torch.bfloat16
19
+ # dit = load_flow_model2(args.model_name, device="cpu").to(device, dtype=dtype)
20
+ # vae = load_ae(args.model_name, device="cpu").to(device, dtype=dtype)
21
+ # t5 = load_t5(device="cpu", max_length=256 if is_schnell else 512).to(device, dtype=dtype)
22
+ # clip = load_clip("cpu").to(device, dtype=dtype)
23
  #test push
24
  @spaces.GPU
25
  def generate(image: Image.Image, edit_prompt: str):
 
27
 
28
 
29
 
30
+ # vae.requires_grad_(False)
31
+ # t5.requires_grad_(False)
32
+ # clip.requires_grad_(False)
33
 
34
+ # model_path = hf_hub_download(
35
+ # repo_id="Boese0601/ByteMorpher",
36
+ # filename="dit.safetensors",
37
+ # use_auth_token=os.getenv("HF_TOKEN")
38
+ # )
39
+ # state_dict = load_file(model_path)
40
+ # dit.load_state_dict(state_dict)
41
+ # dit.eval()
42
+ # dit.to(device, dtype=dtype)
43
 
44
  sampler = XFluxSampler(
 
 
 
 
 
45
  ip_loaded=False,
46
  spatial_condition=False,
47
  clip_image_processor=None,