Commit 56a373a · 1 Parent(s): ecd462b
Author: yanze

update v1.1

Files changed (2):
  1. app.py (+7 −6)
  2. dreamo/dreamo_pipeline.py (+33 −14)
app.py CHANGED
@@ -33,6 +33,7 @@ from tools import BEN2
 
 parser = argparse.ArgumentParser()
 parser.add_argument('--port', type=int, default=8080)
+parser.add_argument('--version', type=str, default='v1.1')
 parser.add_argument('--no_turbo', action='store_true')
 args = parser.parse_args()
 
@@ -169,17 +170,17 @@ def generate_image(
 
 _HEADER_ = '''
 <div style="text-align: center; max-width: 650px; margin: 0 auto;">
-<h1 style="font-size: 2.5rem; font-weight: 700; margin-bottom: 1rem; display: contents;">DreamO</h1>
+<h1 style="font-size: 2.5rem; font-weight: 700; margin-bottom: 1rem; display: contents;">DreamO v1.1</h1>
 <p style="font-size: 1rem; margin-bottom: 1.5rem;">Paper: <a href='https://arxiv.org/abs/2504.16915' target='_blank'>DreamO: A Unified Framework for Image Customization</a> | Codes: <a href='https://github.com/bytedance/DreamO' target='_blank'>GitHub</a></p>
 </div>
 
 🚩 Update Notes:
+- 2025.06.24: Updated to version 1.1 with significant improvements in image quality, reduced likelihood of body composition errors, and enhanced aesthetics. <a href='https://github.com/bytedance/DreamO/blob/main/dreamo_v1.1.md' target='_blank'>Learn more about this model</a>
 - 2025.05.11: We have updated the model to mitigate over-saturation and plastic-face issues. The new version shows consistent improvements over the previous release.
 
 ❗️❗️❗️**User Guide:**
 - The most important thing to do first is to try the examples provided below the demo, which will help you better understand the capabilities of the DreamO model and the types of tasks it currently supports
 - For each input, please select the appropriate task type. For general objects, characters, or clothing, choose IP — we will remove the background from the input image. If you select ID, we will extract the face region from the input image (similar to PuLID). If you select Style, the background will be preserved, and you must prepend the prompt with the instruction: 'generate a same style image.' to activate the style task.
-- The most import hyperparameter in this demo is the guidance scale, which is set to 3.5 by default. If you notice that faces appear overly glossy or unrealistic—especially in ID tasks—you can lower the guidance scale (e.g., to 3). Conversely, if text rendering is poor or limb distortion occurs, increasing the guidance scale (e.g., to 4) may help.
 - To accelerate inference, we adopt FLUX-turbo LoRA, which reduces the sampling steps from 25 to 12 compared to FLUX-dev. Additionally, we distill a CFG LoRA, achieving nearly a twofold reduction in steps by eliminating the need for true CFG
 
 ''' # noqa E501
@@ -210,10 +211,10 @@ def create_demo():
 width = gr.Slider(768, 1024, 1024, step=16, label="Width")
 height = gr.Slider(768, 1024, 1024, step=16, label="Height")
 num_steps = gr.Slider(8, 30, 12, step=1, label="Number of steps")
-guidance = gr.Slider(1.0, 10.0, 3.5, step=0.1, label="Guidance")
+guidance = gr.Slider(1.0, 10.0, 4.5 if args.version == 'v1.1' else 3.5, step=0.1, label="Guidance")
 seed = gr.Textbox(label="Seed (-1 for random)", value="-1")
+ref_res = gr.Slider(512, 1024, 512, step=16, label="resolution for ref image, increase it if necessary")
 with gr.Accordion("Advanced Options", open=False, visible=False):
-ref_res = gr.Slider(512, 1024, 512, step=16, label="resolution for ref image")
 neg_prompt = gr.Textbox(label="Neg Prompt", value="")
 neg_guidance = gr.Slider(1.0, 10.0, 3.5, step=0.1, label="Neg Guidance")
 true_cfg = gr.Slider(1, 5, 1, step=0.1, label="true cfg")
@@ -304,7 +305,7 @@ def create_demo():
 'id',
 'ip',
 'the woman wearing a dress, In the banquet hall',
-7698454872441022867,
+42,
 ],
 [
 'example_inputs/dog1.png',
@@ -328,7 +329,7 @@ def create_demo():
 'ip',
 'ip',
 'a man is dancing with a woman in the room',
-8303780338601106219,
+42,
 ],
 ]
 gr.Examples(
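
Reviewer note: the new --version flag only matters if it reaches the model loader. Below is a minimal sketch of that wiring; the from_pretrained call, base model id, dtype, and device are illustrative assumptions, while the load_dreamo_model(device, use_turbo=..., version=...) signature comes from the dreamo_pipeline.py change in this commit.

# Sketch only: threading --version into model loading.
# Base model id, dtype, and device are assumptions, not part of this commit.
import argparse

import torch

from dreamo.dreamo_pipeline import DreamOPipeline

parser = argparse.ArgumentParser()
parser.add_argument('--port', type=int, default=8080)
parser.add_argument('--version', type=str, default='v1.1')
parser.add_argument('--no_turbo', action='store_true')
args = parser.parse_args()

pipeline = DreamOPipeline.from_pretrained('black-forest-labs/FLUX.1-dev', torch_dtype=torch.bfloat16)
# version='v1' loads the old quality LoRA pair; 'v1.1' loads the SFT/DPO LoRAs.
pipeline.load_dreamo_model('cuda', use_turbo=not args.no_turbo, version=args.version)

Note also that the Guidance slider default is now version-dependent (4.5 for v1.1, 3.5 otherwise), which is presumably why the user-guide bullet describing the fixed 3.5 default was dropped.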
dreamo/dreamo_pipeline.py CHANGED
@@ -43,16 +43,26 @@ class DreamOPipeline(FluxPipeline):
         self.task_embedding = nn.Embedding(2, 3072)
         self.idx_embedding = nn.Embedding(10, 3072)
 
-    def load_dreamo_model(self, device, use_turbo=True):
+    def load_dreamo_model(self, device, use_turbo=True, version='v1.1'):
         # download models and load file
         hf_hub_download(repo_id='ByteDance/DreamO', filename='dreamo.safetensors', local_dir='models')
         hf_hub_download(repo_id='ByteDance/DreamO', filename='dreamo_cfg_distill.safetensors', local_dir='models')
-        hf_hub_download(repo_id='ByteDance/DreamO', filename='dreamo_quality_lora_pos.safetensors', local_dir='models')
-        hf_hub_download(repo_id='ByteDance/DreamO', filename='dreamo_quality_lora_neg.safetensors', local_dir='models')
+        if version == 'v1':
+            hf_hub_download(repo_id='ByteDance/DreamO', filename='dreamo_quality_lora_pos.safetensors',
+                            local_dir='models')
+            hf_hub_download(repo_id='ByteDance/DreamO', filename='dreamo_quality_lora_neg.safetensors',
+                            local_dir='models')
+            quality_lora_pos = load_file('models/dreamo_quality_lora_pos.safetensors')
+            quality_lora_neg = load_file('models/dreamo_quality_lora_neg.safetensors')
+        elif version == 'v1.1':
+            hf_hub_download(repo_id='ByteDance/DreamO', filename='v1.1/dreamo_sft_lora.safetensors', local_dir='models')
+            hf_hub_download(repo_id='ByteDance/DreamO', filename='v1.1/dreamo_dpo_lora.safetensors', local_dir='models')
+            sft_lora = load_file('models/v1.1/dreamo_sft_lora.safetensors')
+            dpo_lora = load_file('models/v1.1/dreamo_dpo_lora.safetensors')
+        else:
+            raise ValueError(f'there is no {version}')
         dreamo_lora = load_file('models/dreamo.safetensors')
         cfg_distill_lora = load_file('models/dreamo_cfg_distill.safetensors')
-        quality_lora_pos = load_file('models/dreamo_quality_lora_pos.safetensors')
-        quality_lora_neg = load_file('models/dreamo_quality_lora_neg.safetensors')
 
         # load embedding
         self.t5_embedding.weight.data = dreamo_lora.pop('dreamo_t5_embedding.weight')[-10:]
@@ -83,18 +93,27 @@ class DreamOPipeline(FluxPipeline):
             adapter_names.append('turbo')
             adapter_weights.append(1)
 
-        # quality loras, one pos, one neg
-        quality_lora_pos = convert_flux_lora_to_diffusers(quality_lora_pos)
-        self.load_lora_weights(quality_lora_pos, adapter_name='quality_pos')
-        adapter_names.append('quality_pos')
-        adapter_weights.append(0.15)
-        quality_lora_neg = convert_flux_lora_to_diffusers(quality_lora_neg)
-        self.load_lora_weights(quality_lora_neg, adapter_name='quality_neg')
-        adapter_names.append('quality_neg')
-        adapter_weights.append(-0.8)
+        if version == 'v1':
+            # quality loras, one pos, one neg
+            quality_lora_pos = convert_flux_lora_to_diffusers(quality_lora_pos)
+            self.load_lora_weights(quality_lora_pos, adapter_name='quality_pos')
+            adapter_names.append('quality_pos')
+            adapter_weights.append(0.15)
+            quality_lora_neg = convert_flux_lora_to_diffusers(quality_lora_neg)
+            self.load_lora_weights(quality_lora_neg, adapter_name='quality_neg')
+            adapter_names.append('quality_neg')
+            adapter_weights.append(-0.8)
+        elif version == 'v1.1':
+            self.load_lora_weights(sft_lora, adapter_name='sft_lora')
+            adapter_names.append('sft_lora')
+            adapter_weights.append(1)
+            self.load_lora_weights(dpo_lora, adapter_name='dpo_lora')
+            adapter_names.append('dpo_lora')
+            adapter_weights.append(1.25)
 
         self.set_adapters(adapter_names, adapter_weights)
         self.fuse_lora(adapter_names=adapter_names, lora_scale=1)
+        self.unload_lora_weights()
 
         self.t5_embedding = self.t5_embedding.to(device)
         self.task_embedding = self.task_embedding.to(device)
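
Reviewer note: besides the version branch, the one behavioral addition here is self.unload_lora_weights() after fuse_lora. This follows the usual diffusers fuse-then-unload idiom: once the weighted adapter deltas are fused into the base weights, the separate LoRA modules are redundant and can be dropped. A minimal sketch of the pattern on a plain FluxPipeline, with placeholder model id and LoRA paths (the API calls mirror the ones used in this diff):

# Sketch of the fuse-then-unload idiom; model id and LoRA paths are placeholders.
import torch

from diffusers import FluxPipeline
from safetensors.torch import load_file

pipe = FluxPipeline.from_pretrained('black-forest-labs/FLUX.1-dev', torch_dtype=torch.bfloat16)

# Load two adapters from state dicts, mirroring the sft_lora/dpo_lora pair above.
pipe.load_lora_weights(load_file('sft_lora.safetensors'), adapter_name='sft')
pipe.load_lora_weights(load_file('dpo_lora.safetensors'), adapter_name='dpo')

# Activate a weighted mix; weights may exceed 1, and negative weights
# (like the v1 quality_neg adapter at -0.8) subtract an adapter's direction.
pipe.set_adapters(['sft', 'dpo'], [1.0, 1.25])

# Bake the active adapters into the base weights, then drop the now-redundant
# adapter modules; outputs are unchanged while memory and per-step LoRA
# dispatch overhead are reduced.
pipe.fuse_lora(adapter_names=['sft', 'dpo'], lora_scale=1)
pipe.unload_lora_weights()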