bol committed on
Commit
47dbef4
·
1 Parent(s): de42ae8
.DS_Store CHANGED
Binary files a/.DS_Store and b/.DS_Store differ
 
app.py CHANGED
@@ -4,18 +4,17 @@ import spaces
4
  import os
5
  import numpy as np
6
  from PIL import Image
7
-
8
  from huggingface_hub import hf_hub_download
9
  from safetensors.torch import load_file
10
  from omegaconf import OmegaConf
 
 
11
  from src.flux.util import load_ae, load_clip, load_flow_model2, load_t5, tensor_to_pil_image
12
  from src.flux.xflux_pipeline import XFluxSampler
13
- from image_datasets.dataset import image_resize
14
-
15
- # ===== No CUDA/model initialization globally =====
16
  args = OmegaConf.load("inference_configs/inference.yaml")
17
  is_schnell = args.model_name == "flux-schnell"
18
-
 
19
  # sampler = None
20
  device = torch.device("cuda")
21
  dtype = torch.bfloat16
@@ -23,36 +22,39 @@ dit = load_flow_model2(args.model_name, device="cpu").to(device, dtype=dtype)
23
  vae = load_ae(args.model_name, device="cpu").to(device, dtype=dtype)
24
  t5 = load_t5(device="cpu", max_length=256 if is_schnell else 512).to(device, dtype=dtype)
25
  clip = load_clip("cpu").to(device, dtype=dtype)
26
-
27
- vae.requires_grad_(False)
28
- t5.requires_grad_(False)
29
- clip.requires_grad_(False)
30
-
31
- model_path = hf_hub_download(
32
- repo_id="Boese0601/ByteMorpher",
33
- filename="dit.safetensors",
34
- use_auth_token=os.getenv("HF_TOKEN")
35
- )
36
- state_dict = load_file(model_path)
37
- dit.load_state_dict(state_dict)
38
- dit.eval()
39
- dit.to(device, dtype=dtype)
40
-
41
- sampler = XFluxSampler(
42
- clip=clip,
43
- t5=t5,
44
- ae=vae,
45
- model=dit,
46
- device=device,
47
- ip_loaded=False,
48
- spatial_condition=False,
49
- clip_image_processor=None,
50
- image_encoder=None,
51
- improj=None
52
- )
53
  #test push
54
  @spaces.GPU
55
  def generate(image: Image.Image, edit_prompt: str):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
  # global sampler
57
  # device = torch.device("cuda")
58
  # dtype = torch.bfloat16
@@ -95,18 +97,17 @@ def generate(image: Image.Image, edit_prompt: str):
95
  img = torch.from_numpy((np.array(img) / 127.5) - 1)
96
  img = img.permute(2, 0, 1).unsqueeze(0).to(device, dtype=dtype)
97
 
98
- with torch.no_grad():
99
- result = sampler(
100
- prompt=edit_prompt,
101
- width=args.sample_width,
102
- height=args.sample_height,
103
- num_steps=args.sample_steps,
104
- image_prompt=None,
105
- true_gs=args.cfg_scale,
106
- seed=args.seed,
107
- ip_scale=args.ip_scale if args.use_ip else 1.0,
108
- source_image=img if args.use_spatial_condition else None,
109
- )
110
  return tensor_to_pil_image(result)
111
 
112
  def get_samples():
 
4
  import os
5
  import numpy as np
6
  from PIL import Image
 
7
  from huggingface_hub import hf_hub_download
8
  from safetensors.torch import load_file
9
  from omegaconf import OmegaConf
10
+
11
+ from image_datasets.dataset import image_resize
12
  from src.flux.util import load_ae, load_clip, load_flow_model2, load_t5, tensor_to_pil_image
13
  from src.flux.xflux_pipeline import XFluxSampler
 
 
 
14
  args = OmegaConf.load("inference_configs/inference.yaml")
15
  is_schnell = args.model_name == "flux-schnell"
16
+ '/home/user/app/assets/0_camera_zoom/20486354.png'
17
+ '/home/user/app/assets/0_camera_zoom/20486354.png'
18
  # sampler = None
19
  device = torch.device("cuda")
20
  dtype = torch.bfloat16
 
22
  vae = load_ae(args.model_name, device="cpu").to(device, dtype=dtype)
23
  t5 = load_t5(device="cpu", max_length=256 if is_schnell else 512).to(device, dtype=dtype)
24
  clip = load_clip("cpu").to(device, dtype=dtype)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  #test push
26
  @spaces.GPU
27
  def generate(image: Image.Image, edit_prompt: str):
28
+
29
+
30
+
31
+
32
+ vae.requires_grad_(False)
33
+ t5.requires_grad_(False)
34
+ clip.requires_grad_(False)
35
+
36
+ model_path = hf_hub_download(
37
+ repo_id="Boese0601/ByteMorpher",
38
+ filename="dit.safetensors",
39
+ use_auth_token=os.getenv("HF_TOKEN")
40
+ )
41
+ state_dict = load_file(model_path)
42
+ dit.load_state_dict(state_dict)
43
+ dit.eval()
44
+ dit.to(device, dtype=dtype)
45
+
46
+ sampler = XFluxSampler(
47
+ clip=clip,
48
+ t5=t5,
49
+ ae=vae,
50
+ model=dit,
51
+ device=device,
52
+ ip_loaded=False,
53
+ spatial_condition=False,
54
+ clip_image_processor=None,
55
+ image_encoder=None,
56
+ improj=None
57
+ )
58
  # global sampler
59
  # device = torch.device("cuda")
60
  # dtype = torch.bfloat16
 
97
  img = torch.from_numpy((np.array(img) / 127.5) - 1)
98
  img = img.permute(2, 0, 1).unsqueeze(0).to(device, dtype=dtype)
99
 
100
+ result = sampler(
101
+ prompt=edit_prompt,
102
+ width=args.sample_width,
103
+ height=args.sample_height,
104
+ num_steps=args.sample_steps,
105
+ image_prompt=None,
106
+ true_gs=args.cfg_scale,
107
+ seed=args.seed,
108
+ ip_scale=args.ip_scale if args.use_ip else 1.0,
109
+ source_image=img if args.use_spatial_condition else None,
110
+ )
 
111
  return tensor_to_pil_image(result)
112
 
113
  def get_samples():
assets/0_camera_zoom/20486354.json DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:7917201faf043e935ea4ddd94c7e570fe5ca51f8bed66ee4d4dabe480f8390b5
3
- size 1128
 
 
 
 
assets/0_camera_zoom/20486354.png DELETED

Git LFS Details

  • SHA256: b124690b006104dbde0b59fb88189cf27fb5ccb07d31bd86bb376f12e6c845b0
  • Pointer size: 128 Bytes
  • Size of remote file: 131 Bytes
assets/0_camera_zoom/20486354_2.png DELETED

Git LFS Details

  • SHA256: 122440474104f5e2cf739f5a4a8294d997abe90528737f384dc64d9860a05b9b
  • Pointer size: 128 Bytes
  • Size of remote file: 132 Bytes