yanze committed
Commit: ecd462b
1 parent: 122eb28

a better model

Files changed (3):
  1. app.py +15 -8
  2. dreamo/dreamo_pipeline.py +28 -6
  3. dreamo/utils.py +10 -0
app.py CHANGED
@@ -28,19 +28,20 @@ from PIL import Image
 from torchvision.transforms.functional import normalize
 
 from dreamo.dreamo_pipeline import DreamOPipeline
-from dreamo.utils import img2tensor, resize_numpy_image_area, tensor2img
+from dreamo.utils import img2tensor, resize_numpy_image_area, tensor2img, resize_numpy_image_long
 from tools import BEN2
 
 parser = argparse.ArgumentParser()
 parser.add_argument('--port', type=int, default=8080)
+parser.add_argument('--no_turbo', action='store_true')
 args = parser.parse_args()
 
 huggingface_hub.login(os.getenv('HF_TOKEN'))
 
-# try:
-#     shutil.rmtree('gradio_cached_examples')
-# except FileNotFoundError:
-#     print("cache folder does not exist")
+try:
+    shutil.rmtree('gradio_cached_examples')
+except FileNotFoundError:
+    print("cache folder does not exist")
 
 class Generator:
     def __init__(self):
@@ -63,7 +64,7 @@ class Generator:
         # load dreamo
         model_root = 'black-forest-labs/FLUX.1-dev'
         dreamo_pipeline = DreamOPipeline.from_pretrained(model_root, torch_dtype=torch.bfloat16)
-        dreamo_pipeline.load_dreamo_model(device, use_turbo=True)
+        dreamo_pipeline.load_dreamo_model(device, use_turbo=not args.no_turbo)
         self.dreamo_pipeline = dreamo_pipeline.to(device)
 
     @torch.no_grad()
@@ -126,10 +127,12 @@ def generate_image(
     for idx, (ref_image, ref_task) in enumerate(zip(ref_images, ref_tasks)):
         if ref_image is not None:
             if ref_task == "id":
+                ref_image = resize_numpy_image_long(ref_image, 1024)
                 ref_image = generator.get_align_face(ref_image)
             elif ref_task != "style":
                 ref_image = generator.bg_rm_model.inference(Image.fromarray(ref_image))
-            ref_image = resize_numpy_image_area(np.array(ref_image), ref_res * ref_res)
+            if ref_task != "id":
+                ref_image = resize_numpy_image_area(np.array(ref_image), ref_res * ref_res)
             debug_images.append(ref_image)
             ref_image = img2tensor(ref_image, bgr2rgb=False).unsqueeze(0) / 255.0
             ref_image = 2 * ref_image - 1.0
@@ -170,9 +173,13 @@ _HEADER_ = '''
 <p style="font-size: 1rem; margin-bottom: 1.5rem;">Paper: <a href='https://arxiv.org/abs/2504.16915' target='_blank'>DreamO: A Unified Framework for Image Customization</a> | Codes: <a href='https://github.com/bytedance/DreamO' target='_blank'>GitHub</a></p>
 </div>
 
+🚩 Update Notes:
+- 2025.05.11: We have updated the model to mitigate over-saturation and plastic-face issues. The new version shows consistent improvements over the previous release.
+
 ❗️❗️❗️**User Guide:**
 - The most important thing to do first is to try the examples provided below the demo, which will help you better understand the capabilities of the DreamO model and the types of tasks it currently supports.
 - For each input, please select the appropriate task type. For general objects, characters, or clothing, choose IP; we will remove the background from the input image. If you select ID, we will extract the face region from the input image (similar to PuLID). If you select Style, the background will be preserved, and you must prepend the prompt with the instruction: 'generate a same style image.' to activate the style task.
+- The most important hyperparameter in this demo is the guidance scale, which is set to 3.5 by default. If faces appear overly glossy or unrealistic (especially in ID tasks), lower the guidance scale (e.g., to 3). Conversely, if text rendering is poor or limbs are distorted, raising it (e.g., to 4) may help.
 - To accelerate inference, we adopt FLUX-turbo LoRA, which reduces the sampling steps from 25 to 12 compared to FLUX-dev. Additionally, we distill a CFG LoRA, achieving a nearly twofold reduction in steps by eliminating the need for true CFG.
 
 ''' # noqa E501
@@ -249,7 +256,7 @@ def create_demo():
             'ip',
             'ip',
             'a purple toy holding a sign saying "DreamO", on the mountain',
-            1563188099017016129,
+            10441727852953907380,
         ],
         [
             'example_inputs/perfume.png',
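
For readers skimming the diff, here is a minimal, self-contained sketch of the reference-image preprocessing that generate_image performs after this change. The standalone preprocess_ref function is hypothetical; align_face and remove_background stand in for the app's generator.get_align_face and generator.bg_rm_model.inference:

import numpy as np
import torch
from dreamo.utils import img2tensor, resize_numpy_image_area, resize_numpy_image_long

def preprocess_ref(image, task, ref_res, align_face, remove_background):
    if task == "id":
        # New in this commit: cap the long edge at 1024 px so face detection
        # runs on a bounded input before alignment.
        image = resize_numpy_image_long(image, 1024)
        image = align_face(image)
    elif task != "style":
        # IP task: strip the background before resizing.
        image = remove_background(image)
    if task != "id":
        # Non-ID references are rescaled to roughly ref_res**2 total pixels
        # (the aligned ID crop is presumably already a fixed size, so it skips this).
        image = resize_numpy_image_area(np.array(image), ref_res * ref_res)
    # HWC uint8 in [0, 255] -> 1xCxHxW float in [-1, 1], as the pipeline expects.
    tensor = img2tensor(image, bgr2rgb=False).unsqueeze(0) / 255.0
    return 2 * tensor - 1.0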
dreamo/dreamo_pipeline.py CHANGED
@@ -44,24 +44,35 @@ class DreamOPipeline(FluxPipeline):
         self.idx_embedding = nn.Embedding(10, 3072)
 
     def load_dreamo_model(self, device, use_turbo=True):
+        # download models and load files
         hf_hub_download(repo_id='ByteDance/DreamO', filename='dreamo.safetensors', local_dir='models')
         hf_hub_download(repo_id='ByteDance/DreamO', filename='dreamo_cfg_distill.safetensors', local_dir='models')
+        hf_hub_download(repo_id='ByteDance/DreamO', filename='dreamo_quality_lora_pos.safetensors', local_dir='models')
+        hf_hub_download(repo_id='ByteDance/DreamO', filename='dreamo_quality_lora_neg.safetensors', local_dir='models')
         dreamo_lora = load_file('models/dreamo.safetensors')
         cfg_distill_lora = load_file('models/dreamo_cfg_distill.safetensors')
+        quality_lora_pos = load_file('models/dreamo_quality_lora_pos.safetensors')
+        quality_lora_neg = load_file('models/dreamo_quality_lora_neg.safetensors')
+
+        # load embeddings
         self.t5_embedding.weight.data = dreamo_lora.pop('dreamo_t5_embedding.weight')[-10:]
         self.task_embedding.weight.data = dreamo_lora.pop('dreamo_task_embedding.weight')
         self.idx_embedding.weight.data = dreamo_lora.pop('dreamo_idx_embedding.weight')
         self._prepare_t5()
 
+        # main lora
         dreamo_diffuser_lora = convert_flux_lora_to_diffusers(dreamo_lora)
-        cfg_diffuser_lora = convert_flux_lora_to_diffusers(cfg_distill_lora)
         adapter_names = ['dreamo']
         adapter_weights = [1]
         self.load_lora_weights(dreamo_diffuser_lora, adapter_name='dreamo')
-        if cfg_diffuser_lora is not None:
-            self.load_lora_weights(cfg_diffuser_lora, adapter_name='cfg')
-            adapter_names.append('cfg')
-            adapter_weights.append(1)
+
+        # cfg lora to avoid true image cfg
+        cfg_diffuser_lora = convert_flux_lora_to_diffusers(cfg_distill_lora)
+        self.load_lora_weights(cfg_diffuser_lora, adapter_name='cfg')
+        adapter_names.append('cfg')
+        adapter_weights.append(1)
+
+        # turbo lora to speed up (from 25+ steps to 12 steps)
         if use_turbo:
             self.load_lora_weights(
                 hf_hub_download(
@@ -72,7 +83,18 @@ class DreamOPipeline(FluxPipeline):
             adapter_names.append('turbo')
             adapter_weights.append(1)
 
-        self.fuse_lora(adapter_names=adapter_names, adapter_weights=adapter_weights, lora_scale=1)
+        # quality loras, one positive, one negative
+        quality_lora_pos = convert_flux_lora_to_diffusers(quality_lora_pos)
+        self.load_lora_weights(quality_lora_pos, adapter_name='quality_pos')
+        adapter_names.append('quality_pos')
+        adapter_weights.append(0.15)
+        quality_lora_neg = convert_flux_lora_to_diffusers(quality_lora_neg)
+        self.load_lora_weights(quality_lora_neg, adapter_name='quality_neg')
+        adapter_names.append('quality_neg')
+        adapter_weights.append(-0.8)
+
+        self.set_adapters(adapter_names, adapter_weights)
+        self.fuse_lora(adapter_names=adapter_names, lora_scale=1)
 
         self.t5_embedding = self.t5_embedding.to(device)
         self.task_embedding = self.task_embedding.to(device)
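
The notable pattern in this file is the pair of quality LoRAs applied with opposite signs: the positive adapter at weight 0.15 and the negative one at -0.8, so the fused transformer weights become roughly W = W0 + 0.15*dW_pos - 0.8*dW_neg. A minimal sketch of that pattern with the diffusers adapter API (the .safetensors paths below are placeholders, not the actual DreamO checkpoints):

import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained('black-forest-labs/FLUX.1-dev', torch_dtype=torch.bfloat16)

# Register two adapters under names that set_adapters can scale later.
pipe.load_lora_weights('quality_pos.safetensors', adapter_name='quality_pos')
pipe.load_lora_weights('quality_neg.safetensors', adapter_name='quality_neg')

# A positive weight adds the adapter's delta, a negative weight subtracts it:
# W = W0 + 0.15 * dW_pos - 0.8 * dW_neg
pipe.set_adapters(['quality_pos', 'quality_neg'], [0.15, -0.8])

# Bake the weighted deltas into the base weights so sampling pays no
# per-step LoRA overhead.
pipe.fuse_lora(adapter_names=['quality_pos', 'quality_neg'], lora_scale=1)

This also appears to explain the small API fix folded into the commit: per-adapter weights now go through set_adapters, since diffusers' fuse_lora takes only adapter_names and a single global lora_scale, not an adapter_weights argument.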
dreamo/utils.py CHANGED
@@ -117,6 +117,16 @@ def resize_numpy_image_area(image, area=512 * 512):
     image = cv2.resize(image, (w, h), interpolation=cv2.INTER_AREA)
     return image
 
+def resize_numpy_image_long(image, long_edge=768):
+    h, w = image.shape[:2]
+    if max(h, w) <= long_edge:
+        return image
+    k = long_edge / max(h, w)
+    h = int(h * k)
+    w = int(w * k)
+    image = cv2.resize(image, (w, h), interpolation=cv2.INTER_AREA)
+    return image
+
 
 # reference: https://github.com/huggingface/diffusers/pull/9295/files
 def convert_flux_lora_to_diffusers(old_state_dict):
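
A quick illustration of the two resize helpers side by side; note that unlike resize_numpy_image_area, the new resize_numpy_image_long only downsizes, returning the input unchanged when the long edge is already within the limit. The shapes below assume a 2000x1000 input:

import numpy as np
from dreamo.utils import resize_numpy_image_area, resize_numpy_image_long

img = np.zeros((2000, 1000, 3), dtype=np.uint8)  # H=2000, W=1000

# Long-edge cap: k = 1024 / 2000 = 0.512, so (2000, 1000) -> (1024, 512).
print(resize_numpy_image_long(img, 1024).shape)  # (1024, 512, 3)

# Area target: rescale so H * W is roughly 512 * 512 while keeping aspect ratio.
print(resize_numpy_image_area(img, 512 * 512).shape)

# No-op when the image already fits within the long edge.
small = np.zeros((600, 400, 3), dtype=np.uint8)
assert resize_numpy_image_long(small, 768) is small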