fenfan commited on
Commit
0f74281
·
1 Parent(s): 44458a9

init commit

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +3 -0
  2. README copy.md +17 -0
  3. app.py +242 -0
  4. assets/gradio_examples/1subject/config.json +6 -0
  5. assets/gradio_examples/1subject/ref.jpg +3 -0
  6. assets/gradio_examples/2identity/config.json +6 -0
  7. assets/gradio_examples/2identity/ref.webp +3 -0
  8. assets/gradio_examples/3identity/config.json +6 -0
  9. assets/gradio_examples/3identity/ref.jpg +3 -0
  10. assets/gradio_examples/4identity/config.json +6 -0
  11. assets/gradio_examples/4identity/ref.webp +3 -0
  12. assets/gradio_examples/5style/config.json +6 -0
  13. assets/gradio_examples/5style/ref.webp +3 -0
  14. assets/gradio_examples/6style/config.json +6 -0
  15. assets/gradio_examples/6style/ref.webp +3 -0
  16. assets/gradio_examples/7style_subject/config.json +7 -0
  17. assets/gradio_examples/7style_subject/ref1.webp +3 -0
  18. assets/gradio_examples/7style_subject/ref2.webp +3 -0
  19. assets/gradio_examples/8style_subject/config.json +7 -0
  20. assets/gradio_examples/8style_subject/ref1.webp +3 -0
  21. assets/gradio_examples/8style_subject/ref2.webp +3 -0
  22. assets/gradio_examples/9mix_style/config.json +7 -0
  23. assets/gradio_examples/9mix_style/ref1.webp +3 -0
  24. assets/gradio_examples/9mix_style/ref2.webp +3 -0
  25. assets/gradio_examples/identity1.jpg +3 -0
  26. assets/gradio_examples/identity1_result.png +3 -0
  27. assets/gradio_examples/identity2.webp +3 -0
  28. assets/gradio_examples/identity2_style2_result.webp +3 -0
  29. assets/gradio_examples/style1.webp +3 -0
  30. assets/gradio_examples/style1_result.webp +3 -0
  31. assets/gradio_examples/style2.webp +3 -0
  32. assets/gradio_examples/style3.webp +3 -0
  33. assets/gradio_examples/style3_style4_result.webp +3 -0
  34. assets/gradio_examples/style4.webp +3 -0
  35. assets/gradio_examples/z_mix_style/config.json +7 -0
  36. assets/gradio_examples/z_mix_style/ref1.png +3 -0
  37. assets/gradio_examples/z_mix_style/ref2.png +3 -0
  38. assets/gradio_examples/zz_t2i/config.json +5 -0
  39. assets/teaser.webp +3 -0
  40. assets/uso.webp +3 -0
  41. assets/uso_logo.svg +0 -0
  42. assets/uso_text.svg +0 -0
  43. requirements.txt +19 -0
  44. uso/flux/math.py +45 -0
  45. uso/flux/model.py +258 -0
  46. uso/flux/modules/__pycache__/autoencoder.cpython-311.pyc +0 -0
  47. uso/flux/modules/__pycache__/conditioner.cpython-311.pyc +0 -0
  48. uso/flux/modules/__pycache__/layers.cpython-311.pyc +0 -0
  49. uso/flux/modules/autoencoder.py +327 -0
  50. uso/flux/modules/conditioner.py +53 -0
.gitattributes CHANGED
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ *.png filter=lfs diff=lfs merge=lfs -text
37
+ *.jpg filter=lfs diff=lfs merge=lfs -text
38
+ *.webp filter=lfs diff=lfs merge=lfs -text
README copy.md ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: USO
3
+ emoji: 💻
4
+ colorFrom: indigo
5
+ colorTo: purple
6
+ sdk: gradio
7
+ sdk_version: 5.23.3
8
+ app_file: app.py
9
+ pinned: false
10
+ license: apache-2.0
11
+ short_description: Freely Combining Any Subjects with Any Styles Across All Scenarios.
12
+ models:
13
+ - black-forest-labs/FLUX.1-dev
14
+ - bytedance-research/UNO
15
+ ---
16
+
17
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,242 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. All rights reserved.
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import dataclasses
16
+ import json
17
+ import os
18
+ from pathlib import Path
19
+
20
+ import gradio as gr
21
+ import torch
22
+ import spaces
23
+
24
+ from uso.flux.pipeline import USOPipeline
25
+ from transformers import SiglipVisionModel, SiglipImageProcessor
26
+
27
+
28
+ with open("assets/uso_text.svg", "r", encoding="utf-8") as svg_file:
29
+ text_content = svg_file.read()
30
+
31
+ with open("assets/uso_logo.svg", "r", encoding="utf-8") as svg_file:
32
+ logo_content = svg_file.read()
33
+
34
+ title = f"""
35
+ <div style="display: flex; align-items: center; justify-content: center;">
36
+ <span style="transform: scale(0.7);margin-right: -5px;">{text_content}</span>
37
+ <span style="font-size: 1.8em;margin-left: -10px;font-weight: bold; font-family: Gill Sans;">by UXO Team</span>
38
+ <span style="margin-left: 0px; transform: scale(0.85); display: inline-block;">{logo_content}</span>
39
+ </div>
40
+ """.strip()
41
+
42
+ badges_text = r"""
43
+ <div style="text-align: center; display: flex; justify-content: center; gap: 5px;">
44
+ <a href="https://github.com/bytedance/USO"><img src="https://img.shields.io/static/v1?label=GitHub&message=Code&color=green&logo=github"></a>
45
+ <a href="https://bytedance.github.io/USO/"><img alt="Build" src="https://img.shields.io/badge/Project%20Page-USO-yellow"></a>
46
+ <a href="https://arxiv.org/abs/2504.02160"><img alt="Build" src="https://img.shields.io/badge/arXiv%20paper-USO-b31b1b.svg"></a>
47
+ <a href="https://huggingface.co/bytedance-research/USO"><img src="https://img.shields.io/static/v1?label=%F0%9F%A4%97%20Hugging%20Face&message=Model&color=orange"></a>
48
+ </div>
49
+ """.strip()
50
+
51
+ tips = """
52
+ 📌 **What is USO?**
53
+ USO is a unified style-subject optimized customization model and the latest addition to the UXO family (<a href='https://github.com/bytedance/USO' target='_blank'> USO</a> and <a href='https://github.com/bytedance/UNO' target='_blank'> UNO</a>).
54
+ It can freely combine arbitrary subjects with arbitrary styles in any scenarios.
55
+
56
+ 💡 **How to use?**
57
+ We provide step-by-step instructions in our <a href='https://github.com/bytedance/USO' target='_blank'> Github Repo</a>.
58
+ Additionally, try the examples provided below the demo to quickly get familiar with USO and spark your creativity!
59
+
60
+ ⚡️ The model is trained on 1024x1024 resolution and supports 3 types of usage:
61
+ * **Only content img**: support following types:
62
+ * Subject/Identity-driven (supports natural prompt, e.g., *A clock on the table.* *The woman near the sea.*, excels in producing **photorealistic portraits**)
63
+ * Style edit (layout-preserved): *Transform the image into Ghibli style/Pixel style/Retro comic style/Watercolor painting style...*.
64
+ * Style edit (layout-shift): *Ghibli style, the man on the beach.*.
65
+ * **Only style img**: Reference input style and generate anything following prompt. Excelling in this and further support multiple style references (in beta).
66
+ * **Content img + style img**: Place the content into the desired style.
67
+ * Layout-preserved: set prompt to **empty**.
68
+ * Layout-shift: using natural prompt."""
69
+
70
+ star = r"""
71
+ If USO is helpful, please help to ⭐ our <a href='https://github.com/bytedance/USO' target='_blank'> Github Repo</a>. Thanks a lot!"""
72
+
73
def get_examples(examples_dir: str = "assets/examples") -> list:
    """Collect demo example rows from per-example config directories.

    Each subdirectory of *examples_dir* must contain a ``config.json`` with
    ``usage``, ``prompt`` and ``seed`` keys plus optional ``image_ref1`` /
    ``image_ref2`` / ``image_ref3`` image paths relative to that directory.

    Returns:
        A list of rows shaped ``[usage, prompt, ref1|None, ref2|None,
        ref3|None, seed]`` — exactly the component order wired into
        ``gr.Examples`` in ``create_demo``.
    """
    ans = []
    # Sort the directory entries: iterdir() order is filesystem-dependent,
    # and sorting keeps the example order in the UI stable across runs.
    for example in sorted(Path(examples_dir).iterdir()):
        # Skip stray files and empty directories.
        if not example.is_dir() or not any(example.iterdir()):
            continue
        with open(example / "config.json") as f:
            example_dict = json.load(f)

        example_list = [example_dict["usage"], example_dict["prompt"]]

        # Optional reference images; keep a None placeholder so every row
        # has the same length as the gr.Examples inputs list.
        for key in ("image_ref1", "image_ref2", "image_ref3"):
            ref = example_dict.get(key)
            example_list.append(str(example / ref) if ref is not None else None)

        example_list.append(example_dict["seed"])
        ans.append(example_list)
    return ans
97
+
98
+
99
def create_demo(
    model_type: str,
    device: str = "cuda" if torch.cuda.is_available() else "cpu",
    offload: bool = False,
):
    """Build the USO gradio demo.

    Loads the USO pipeline (LoRA-only, rank 128) plus a SigLIP vision
    encoder for style conditioning, then assembles the Blocks UI and wires
    the Generate button and example gallery to ``pipeline.gradio_generate``.

    Args:
        model_type: FLUX model flavour passed through to ``USOPipeline``.
        device: torch device string; defaults to CUDA when available.
        offload: forwarded to ``USOPipeline`` (CPU offloading of submodels).

    Returns:
        The constructed ``gr.Blocks`` demo (not yet launched).
    """
    pipeline = USOPipeline(
        model_type, device, offload, only_lora=True, lora_rank=128, hf_download=True
    )
    print("USOPipeline loaded successfully")

    # Style-image encoder: attached onto the pipeline's model so that
    # Flux.forward can embed style references (see uso/flux/model.py).
    siglip_processor = SiglipImageProcessor.from_pretrained(
        "google/siglip-so400m-patch14-384"
    )
    siglip_model = SiglipVisionModel.from_pretrained(
        "google/siglip-so400m-patch14-384"
    )
    siglip_model.eval()
    siglip_model.to(device)
    pipeline.model.vision_encoder = siglip_model
    pipeline.model.vision_encoder_processor = siglip_processor
    print("SigLIP model loaded successfully")

    # Wrap generation for HF Spaces ZeroGPU scheduling (120 s GPU budget).
    pipeline.gradio_generate = spaces.GPU(duration=120)(pipeline.gradio_generate)

    with gr.Blocks() as demo:
        gr.Markdown(title)
        gr.Markdown(badges_text)
        gr.Markdown(tips)
        with gr.Row():
            # Left column: all inputs and controls.
            with gr.Column():
                prompt = gr.Textbox(label="Prompt", value="A beautiful woman.")
                with gr.Row():
                    image_prompt1 = gr.Image(
                        label="Content Reference Img", visible=True, interactive=True, type="pil"
                    )
                    image_prompt2 = gr.Image(
                        label="Style Reference Img", visible=True, interactive=True, type="pil"
                    )
                    image_prompt3 = gr.Image(
                        label="Extra Style Reference Img (Beta)", visible=True, interactive=True, type="pil"
                    )

                with gr.Row():
                    with gr.Row():
                        width = gr.Slider(
                            512, 1536, 1024, step=16, label="Generation Width"
                        )
                        height = gr.Slider(
                            512, 1536, 1024, step=16, label="Generation Height"
                        )
                with gr.Row():
                    with gr.Row():
                        # When enabled, output keeps the content image's size
                        # (width/height sliders are then ignored downstream).
                        keep_size = gr.Checkbox(
                            label="Keep input size",
                            value=False,
                            interactive=True
                        )
                    with gr.Column():
                        gr.Markdown("Set it to True if you only need style editing or want to keep the layout.")

                with gr.Accordion("Advanced Options", open=True):
                    with gr.Row():
                        num_steps = gr.Slider(
                            1, 50, 25, step=1, label="Number of steps"
                        )
                        guidance = gr.Slider(
                            1.0, 5.0, 4.0, step=0.1, label="Guidance", interactive=True
                        )
                        content_long_size = gr.Slider(
                            0, 1024, 512, step=16, label="Content reference size"
                        )
                        seed = gr.Number(-1, label="Seed (-1 for random)")

                generate_btn = gr.Button("Generate")
                gr.Markdown(star)

            # Right column: outputs.
            with gr.Column():
                output_image = gr.Image(label="Generated Image")
                download_btn = gr.File(
                    label="Download full-resolution", type="filepath", interactive=False
                )

        # Argument order must match pipeline.gradio_generate's signature.
        inputs = [
            prompt,
            image_prompt1,
            image_prompt2,
            image_prompt3,
            seed,
            width,
            height,
            guidance,
            num_steps,
            keep_size,
            content_long_size,
        ]
        generate_btn.click(
            fn=pipeline.gradio_generate,
            inputs=inputs,
            outputs=[output_image, download_btn],
        )

        # Hidden text slot that displays each example's "usage" label.
        example_text = gr.Text("", visible=False, label="Case For:")
        examples = get_examples("./assets/gradio_examples")

        # NOTE(review): gr.Examples only fills 6 of gradio_generate's 11
        # inputs; with caching disabled the fn is not invoked on selection,
        # so remaining controls keep their current values — confirm intended.
        gr.Examples(
            examples=examples,
            inputs=[
                example_text,
                prompt,
                image_prompt1,
                image_prompt2,
                image_prompt3,
                seed,
            ],
            # cache_examples='lazy',
            outputs=[output_image, download_btn],
            fn=pipeline.gradio_generate,
        )

    return demo
218
+
219
+
220
+ if __name__ == "__main__":
221
+ from typing import Literal
222
+
223
+ from transformers import HfArgumentParser
224
+
225
+ @dataclasses.dataclass
226
+ class AppArgs:
227
+ name: Literal["flux-dev", "flux-dev-fp8", "flux-schnell", "flux-krea-dev"] = "flux-dev"
228
+ device: Literal["cuda", "cpu"] = "cuda" if torch.cuda.is_available() else "cpu"
229
+ offload: bool = dataclasses.field(
230
+ default=False,
231
+ metadata={
232
+ "help": "If True, sequantial offload the models(ae, dit, text encoder) to CPU if not used."
233
+ },
234
+ )
235
+ port: int = 7860
236
+
237
+ parser = HfArgumentParser([AppArgs])
238
+ args_tuple = parser.parse_args_into_dataclasses() # type: tuple[AppArgs]
239
+ args = args_tuple[0]
240
+
241
+ demo = create_demo(args.name, args.device, args.offload)
242
+ demo.launch(server_port=args.port)
assets/gradio_examples/1subject/config.json ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ {
2
+ "prompt": "Wool felt style, a clock in the jungle.",
3
+ "seed": 3407,
4
+ "usage": "Subject-driven",
5
+ "image_ref1": "./ref.jpg"
6
+ }
assets/gradio_examples/1subject/ref.jpg ADDED

Git LFS Details

  • SHA256: 0e1eb6ca2c944f3bfaed3ace56f5f186ed073a477e0333e0237253d98f0c9267
  • Pointer size: 131 Bytes
  • Size of remote file: 139 kB
assets/gradio_examples/2identity/config.json ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ {
2
+ "prompt": "The girl is riding a bike in a street.",
3
+ "seed": 3407,
4
+ "usage": "Identity-driven",
5
+ "image_ref1": "./ref.webp"
6
+ }
assets/gradio_examples/2identity/ref.webp ADDED

Git LFS Details

  • SHA256: 4e97502bd7eebd6692604f891f836f25c7c30dcac8d15c4d42cc874efc51fcc5
  • Pointer size: 130 Bytes
  • Size of remote file: 85.8 kB
assets/gradio_examples/3identity/config.json ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ {
2
+ "prompt": "The man in flower shops carefully match bouquets, conveying beautiful emotions and blessings with flowers.",
3
+ "seed": 3407,
4
+ "usage": "Identity-driven",
5
+ "image_ref1": "./ref.jpg"
6
+ }
assets/gradio_examples/3identity/ref.jpg ADDED

Git LFS Details

  • SHA256: 2730103b6b9ebaf47b44ef9a9d7fbb722de7878a101af09f0b85f8dfadb4c8a4
  • Pointer size: 130 Bytes
  • Size of remote file: 30.6 kB
assets/gradio_examples/4identity/config.json ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ {
2
+ "prompt": "Transform the image into Ghibli style.",
3
+ "seed": 3407,
4
+ "usage": "Identity-driven",
5
+ "image_ref1": "./ref.webp"
6
+ }
assets/gradio_examples/4identity/ref.webp ADDED

Git LFS Details

  • SHA256: f8ed8aa1c0714c939392e2c033735d6266e53266079bb300cbf05a6824a49f9f
  • Pointer size: 130 Bytes
  • Size of remote file: 38.8 kB
assets/gradio_examples/5style/config.json ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ {
2
+ "prompt": "A cat sleeping on a chair.",
3
+ "seed": 3407,
4
+ "usage": "Style-driven",
5
+ "image_ref2": "./ref.webp"
6
+ }
assets/gradio_examples/5style/ref.webp ADDED

Git LFS Details

  • SHA256: 9ebf56d2d20ae5c49a582ff6bfef64b13022d0c624d9de25ed91047380fdfcfe
  • Pointer size: 130 Bytes
  • Size of remote file: 52.3 kB
assets/gradio_examples/6style/config.json ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ {
2
+ "prompt": "A beautiful woman.",
3
+ "seed": 3407,
4
+ "usage": "Style-driven",
5
+ "image_ref2": "./ref.webp"
6
+ }
assets/gradio_examples/6style/ref.webp ADDED

Git LFS Details

  • SHA256: 40c013341c8708b53094e3eaa377b3dfccdc9e77e215ad15d2ac2e875b4c494a
  • Pointer size: 130 Bytes
  • Size of remote file: 58.2 kB
assets/gradio_examples/7style_subject/config.json ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ {
2
+ "prompt": "",
3
+ "seed": 321,
4
+ "usage": "Style-subject-driven (layout-preserved)",
5
+ "image_ref1": "./ref1.webp",
6
+ "image_ref2": "./ref2.webp"
7
+ }
assets/gradio_examples/7style_subject/ref1.webp ADDED

Git LFS Details

  • SHA256: f8ed8aa1c0714c939392e2c033735d6266e53266079bb300cbf05a6824a49f9f
  • Pointer size: 130 Bytes
  • Size of remote file: 38.8 kB
assets/gradio_examples/7style_subject/ref2.webp ADDED

Git LFS Details

  • SHA256: 175d6e5b975b4d494950250740c0fe371a7e9b2c93c59a3ae82b82be72ccc0f6
  • Pointer size: 130 Bytes
  • Size of remote file: 14.2 kB
assets/gradio_examples/8style_subject/config.json ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ {
2
+ "prompt": "The woman gave an impassioned speech on the podium.",
3
+ "seed": 321,
4
+ "usage": "Style-subject-driven (layout-shifted)",
5
+ "image_ref1": "./ref1.webp",
6
+ "image_ref2": "./ref2.webp"
7
+ }
assets/gradio_examples/8style_subject/ref1.webp ADDED

Git LFS Details

  • SHA256: f8ed8aa1c0714c939392e2c033735d6266e53266079bb300cbf05a6824a49f9f
  • Pointer size: 130 Bytes
  • Size of remote file: 38.8 kB
assets/gradio_examples/8style_subject/ref2.webp ADDED

Git LFS Details

  • SHA256: 0235262d9bd1070155536352ccf195f9875ead0d3379dee7285c0aaae79f6464
  • Pointer size: 130 Bytes
  • Size of remote file: 39.1 kB
assets/gradio_examples/9mix_style/config.json ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ {
2
+ "prompt": "A man.",
3
+ "seed": 321,
4
+ "usage": "Multi-style-driven",
5
+ "image_ref2": "./ref1.webp",
6
+ "image_ref3": "./ref2.webp"
7
+ }
assets/gradio_examples/9mix_style/ref1.webp ADDED

Git LFS Details

  • SHA256: a1d272a0ecb03126503446b00a2152deab2045f89ac2c01f948e1099589d2862
  • Pointer size: 131 Bytes
  • Size of remote file: 142 kB
assets/gradio_examples/9mix_style/ref2.webp ADDED

Git LFS Details

  • SHA256: b1ce04559726509672ce859d617a08d8dff8b2fe28f503fecbca7a5f66082882
  • Pointer size: 131 Bytes
  • Size of remote file: 290 kB
assets/gradio_examples/identity1.jpg ADDED

Git LFS Details

  • SHA256: 2730103b6b9ebaf47b44ef9a9d7fbb722de7878a101af09f0b85f8dfadb4c8a4
  • Pointer size: 130 Bytes
  • Size of remote file: 30.6 kB
assets/gradio_examples/identity1_result.png ADDED

Git LFS Details

  • SHA256: 7684256e44ce1bd4ada1e77a12674432eddd95b07fb388673899139afc56d864
  • Pointer size: 132 Bytes
  • Size of remote file: 1.54 MB
assets/gradio_examples/identity2.webp ADDED

Git LFS Details

  • SHA256: f8ed8aa1c0714c939392e2c033735d6266e53266079bb300cbf05a6824a49f9f
  • Pointer size: 130 Bytes
  • Size of remote file: 38.8 kB
assets/gradio_examples/identity2_style2_result.webp ADDED

Git LFS Details

  • SHA256: 8376b6dc02d304616c09ecf09c7dbabb16c7c9142fb4db21f576a15a1ec24062
  • Pointer size: 130 Bytes
  • Size of remote file: 43.9 kB
assets/gradio_examples/style1.webp ADDED

Git LFS Details

  • SHA256: 9ebf56d2d20ae5c49a582ff6bfef64b13022d0c624d9de25ed91047380fdfcfe
  • Pointer size: 130 Bytes
  • Size of remote file: 52.3 kB
assets/gradio_examples/style1_result.webp ADDED

Git LFS Details

  • SHA256: 16a4353dd83b1c48499e222d6f77904e1fda23c1649ea5f6cca6b00b0fca3069
  • Pointer size: 130 Bytes
  • Size of remote file: 61.1 kB
assets/gradio_examples/style2.webp ADDED

Git LFS Details

  • SHA256: 0235262d9bd1070155536352ccf195f9875ead0d3379dee7285c0aaae79f6464
  • Pointer size: 130 Bytes
  • Size of remote file: 39.1 kB
assets/gradio_examples/style3.webp ADDED

Git LFS Details

  • SHA256: a1d272a0ecb03126503446b00a2152deab2045f89ac2c01f948e1099589d2862
  • Pointer size: 131 Bytes
  • Size of remote file: 142 kB
assets/gradio_examples/style3_style4_result.webp ADDED

Git LFS Details

  • SHA256: d09a5e429cc1d059aecd041e061868cd8e5b59f4718bb0f926fd84364f3794b0
  • Pointer size: 131 Bytes
  • Size of remote file: 173 kB
assets/gradio_examples/style4.webp ADDED

Git LFS Details

  • SHA256: b1ce04559726509672ce859d617a08d8dff8b2fe28f503fecbca7a5f66082882
  • Pointer size: 131 Bytes
  • Size of remote file: 290 kB
assets/gradio_examples/z_mix_style/config.json ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ {
2
+ "prompt": "Boat on water.",
3
+ "seed": 321,
4
+ "usage": "Multi-style-driven",
5
+ "image_ref2": "./ref1.png",
6
+ "image_ref3": "./ref2.png"
7
+ }
assets/gradio_examples/z_mix_style/ref1.png ADDED

Git LFS Details

  • SHA256: 5c31ba662c85f4032abf079dfeb9cba08d797b7b63f1d661c5270b373b00d095
  • Pointer size: 130 Bytes
  • Size of remote file: 26.1 kB
assets/gradio_examples/z_mix_style/ref2.png ADDED

Git LFS Details

  • SHA256: c47d23d5ffdbf30b4a8f6c1bc5d07a730825eaac8363c13bdac8e3bb8c330aed
  • Pointer size: 130 Bytes
  • Size of remote file: 14.7 kB
assets/gradio_examples/zz_t2i/config.json ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ {
2
+ "prompt": "A beautiful woman.",
3
+ "seed": -1,
4
+ "usage": "Text-to-image"
5
+ }
assets/teaser.webp ADDED

Git LFS Details

  • SHA256: 543c724f6b929303046ae481672567fe4a9620f0af5ca1dfff215dc7a2cbff5f
  • Pointer size: 132 Bytes
  • Size of remote file: 1.67 MB
assets/uso.webp ADDED

Git LFS Details

  • SHA256: 772957e867da33550437fa547202d0f995011353ef9a24036d23596dae1a1632
  • Pointer size: 130 Bytes
  • Size of remote file: 58.2 kB
assets/uso_logo.svg ADDED
assets/uso_text.svg ADDED
requirements.txt ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ accelerate==1.1.1
2
+ deepspeed==0.14.4
3
+ einops==0.8.0
4
+ transformers==4.43.3
5
+ huggingface-hub
6
+ diffusers==0.30.1
7
+ sentencepiece==0.2.0
8
+ gradio==5.22.0
9
+ opencv-python
10
+ matplotlib
11
+ safetensors==0.4.5
12
+ scipy==1.10.1
13
+ numpy==1.24.4
14
+ onnxruntime-gpu
15
+ # httpx==0.23.3
16
+ git+https://github.com/openai/CLIP.git
17
+ --extra-index-url https://download.pytorch.org/whl/cu124
18
+ torch==2.4.0
19
+ torchvision==0.19.0
uso/flux/math.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. All rights reserved.
2
+ # Copyright (c) 2024 Black Forest Labs and The XLabs-AI Team. All rights reserved.
3
+
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import torch
17
+ from einops import rearrange
18
+ from torch import Tensor
19
+
20
+
21
def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
    """Rotary-embedded scaled dot-product attention.

    Applies RoPE rotations ``pe`` to the queries and keys, runs fused SDPA,
    and merges the head dimension back into the feature dimension.

    Args:
        q, k, v: attention inputs of shape ``(B, H, L, D)``.
        pe: precomputed rotary rotation matrices (see ``rope``).

    Returns:
        Attention output of shape ``(B, L, H * D)``.
    """
    q, k = apply_rope(q, k, pe)

    out = torch.nn.functional.scaled_dot_product_attention(q, k, v)
    # Merge heads: (B, H, L, D) -> (B, L, H * D).
    batch, heads, seq_len, head_dim = out.shape
    merged = out.transpose(1, 2).reshape(batch, seq_len, heads * head_dim)

    return merged
28
+
29
+
30
def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
    """Build per-position 2x2 RoPE rotation matrices.

    Args:
        pos: position indices, shape ``(..., n)``.
        dim: embedding dimension to rotate; must be even (pairs of channels
            share one rotation angle).
        theta: frequency base of the inverse-power schedule.

    Returns:
        float32 tensor of shape ``(..., n, dim // 2, 2, 2)`` holding the
        rotation matrix ``[[cos, -sin], [sin, cos]]`` for each channel pair.
    """
    assert dim % 2 == 0
    # Per-pair frequencies: theta ** -(2i / dim), computed in float64 for
    # precision before the final float32 cast.
    exponents = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
    freqs = 1.0 / (theta**exponents)
    angles = torch.einsum("...n,d->...nd", pos, freqs)
    cos, sin = torch.cos(angles), torch.sin(angles)
    # Flat [cos, -sin, sin, cos] per angle, then split into a 2x2 matrix.
    rot = torch.stack([cos, -sin, sin, cos], dim=-1)
    rot = rot.unflatten(-1, (2, 2))
    return rot.float()
38
+
39
+
40
def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:
    """Rotate query and key features by precomputed RoPE matrices.

    Args:
        xq, xk: query/key tensors whose last dimension is split into
            (pair, 2) channel pairs.
        freqs_cis: rotation matrices from ``rope``, broadcastable to the
            paired view; last two axes are the 2x2 rotation.

    Returns:
        The rotated ``(xq, xk)``, reshaped back to the input shapes and cast
        back to the input dtypes.
    """

    def _rotate(x: Tensor) -> Tensor:
        # View the feature axis as (..., pairs, 1, 2) so the 2x2 rotation
        # broadcasts across pairs; compute in float32 for stability.
        pairs = x.float().reshape(*x.shape[:-1], -1, 1, 2)
        rotated = freqs_cis[..., 0] * pairs[..., 0] + freqs_cis[..., 1] * pairs[..., 1]
        return rotated.reshape(*x.shape).type_as(x)

    return _rotate(xq), _rotate(xk)
uso/flux/model.py ADDED
@@ -0,0 +1,258 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. All rights reserved.
2
+ # Copyright (c) 2024 Black Forest Labs and The XLabs-AI Team. All rights reserved.
3
+
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from dataclasses import dataclass
17
+
18
+ import torch
19
+ from torch import Tensor, nn
20
+
21
+ from .modules.layers import (
22
+ DoubleStreamBlock,
23
+ EmbedND,
24
+ LastLayer,
25
+ MLPEmbedder,
26
+ SingleStreamBlock,
27
+ timestep_embedding,
28
+ SigLIPMultiFeatProjModel,
29
+ )
30
+ import os
31
+
32
+
33
+ @dataclass
34
+ class FluxParams:
35
+ in_channels: int
36
+ vec_in_dim: int
37
+ context_in_dim: int
38
+ hidden_size: int
39
+ mlp_ratio: float
40
+ num_heads: int
41
+ depth: int
42
+ depth_single_blocks: int
43
+ axes_dim: list[int]
44
+ theta: int
45
+ qkv_bias: bool
46
+ guidance_embed: bool
47
+
48
+
49
+ class Flux(nn.Module):
50
+ """
51
+ Transformer model for flow matching on sequences.
52
+ """
53
+
54
+ _supports_gradient_checkpointing = True
55
+
56
+ def __init__(self, params: FluxParams):
57
+ super().__init__()
58
+
59
+ self.params = params
60
+ self.in_channels = params.in_channels
61
+ self.out_channels = self.in_channels
62
+ if params.hidden_size % params.num_heads != 0:
63
+ raise ValueError(
64
+ f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
65
+ )
66
+ pe_dim = params.hidden_size // params.num_heads
67
+ if sum(params.axes_dim) != pe_dim:
68
+ raise ValueError(
69
+ f"Got {params.axes_dim} but expected positional dim {pe_dim}"
70
+ )
71
+ self.hidden_size = params.hidden_size
72
+ self.num_heads = params.num_heads
73
+ self.pe_embedder = EmbedND(
74
+ dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim
75
+ )
76
+ self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
77
+ self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
78
+ self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size)
79
+ self.guidance_in = (
80
+ MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
81
+ if params.guidance_embed
82
+ else nn.Identity()
83
+ )
84
+ self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size)
85
+
86
+ self.double_blocks = nn.ModuleList(
87
+ [
88
+ DoubleStreamBlock(
89
+ self.hidden_size,
90
+ self.num_heads,
91
+ mlp_ratio=params.mlp_ratio,
92
+ qkv_bias=params.qkv_bias,
93
+ )
94
+ for _ in range(params.depth)
95
+ ]
96
+ )
97
+
98
+ self.single_blocks = nn.ModuleList(
99
+ [
100
+ SingleStreamBlock(
101
+ self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio
102
+ )
103
+ for _ in range(params.depth_single_blocks)
104
+ ]
105
+ )
106
+
107
+ self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)
108
+ self.gradient_checkpointing = False
109
+
110
+ # feature embedder for siglip multi-feat inputs
111
+ self.feature_embedder = SigLIPMultiFeatProjModel(
112
+ siglip_token_nums=729,
113
+ style_token_nums=64,
114
+ siglip_token_dims=1152,
115
+ hidden_size=self.hidden_size,
116
+ context_layer_norm=True,
117
+ )
118
+ print("use semantic encoder siglip multi-feat to encode style image")
119
+
120
+ self.vision_encoder = None
121
+
122
+ def _set_gradient_checkpointing(self, module, value=False):
123
+ if hasattr(module, "gradient_checkpointing"):
124
+ module.gradient_checkpointing = value
125
+
126
+ @property
127
+ def attn_processors(self):
128
+ # set recursively
129
+ processors = {} # type: dict[str, nn.Module]
130
+
131
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors):
132
+ if hasattr(module, "set_processor"):
133
+ processors[f"{name}.processor"] = module.processor
134
+
135
+ for sub_name, child in module.named_children():
136
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
137
+
138
+ return processors
139
+
140
+ for name, module in self.named_children():
141
+ fn_recursive_add_processors(name, module, processors)
142
+
143
+ return processors
144
+
145
+ def set_attn_processor(self, processor):
146
+ r"""
147
+ Sets the attention processor to use to compute attention.
148
+
149
+ Parameters:
150
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
151
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
152
+ for **all** `Attention` layers.
153
+
154
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
155
+ processor. This is strongly recommended when setting trainable attention processors.
156
+
157
+ """
158
+ count = len(self.attn_processors.keys())
159
+
160
+ if isinstance(processor, dict) and len(processor) != count:
161
+ raise ValueError(
162
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
163
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
164
+ )
165
+
166
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
167
+ if hasattr(module, "set_processor"):
168
+ if not isinstance(processor, dict):
169
+ module.set_processor(processor)
170
+ else:
171
+ module.set_processor(processor.pop(f"{name}.processor"))
172
+
173
+ for sub_name, child in module.named_children():
174
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
175
+
176
+ for name, module in self.named_children():
177
+ fn_recursive_attn_processor(name, module, processor)
178
+
179
+ def forward(
180
+ self,
181
+ img: Tensor,
182
+ img_ids: Tensor,
183
+ txt: Tensor,
184
+ txt_ids: Tensor,
185
+ timesteps: Tensor,
186
+ y: Tensor,
187
+ guidance: Tensor | None = None,
188
+ ref_img: Tensor | None = None,
189
+ ref_img_ids: Tensor | None = None,
190
+ siglip_inputs: list[Tensor] | None = None,
191
+ ) -> Tensor:
192
+ if img.ndim != 3 or txt.ndim != 3:
193
+ raise ValueError("Input img and txt tensors must have 3 dimensions.")
194
+
195
+ # running on sequences img
196
+ img = self.img_in(img)
197
+ vec = self.time_in(timestep_embedding(timesteps, 256))
198
+ if self.params.guidance_embed:
199
+ if guidance is None:
200
+ raise ValueError(
201
+ "Didn't get guidance strength for guidance distilled model."
202
+ )
203
+ vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
204
+ vec = vec + self.vector_in(y)
205
+ txt = self.txt_in(txt)
206
+ if self.feature_embedder is not None and siglip_inputs is not None and len(siglip_inputs) > 0 and self.vision_encoder is not None:
207
+ # processing style feat into textural hidden space
208
+ siglip_embedding = [self.vision_encoder(**emb, output_hidden_states=True) for emb in siglip_inputs]
209
+ # siglip_embedding = [self.vision_encoder(**(emb.to(torch.bfloat16)), output_hidden_states=True) for emb in siglip_inputs]
210
+ siglip_embedding = torch.cat([self.feature_embedder(emb) for emb in siglip_embedding], dim=1)
211
+ txt = torch.cat((siglip_embedding, txt), dim=1)
212
+ siglip_embedding_ids = torch.zeros(
213
+ siglip_embedding.shape[0], siglip_embedding.shape[1], 3
214
+ ).to(txt_ids.device)
215
+ txt_ids = torch.cat((siglip_embedding_ids, txt_ids), dim=1)
216
+
217
+ ids = torch.cat((txt_ids, img_ids), dim=1)
218
+
219
+ # concat ref_img/img
220
+ img_end = img.shape[1]
221
+ if ref_img is not None:
222
+ if isinstance(ref_img, tuple) or isinstance(ref_img, list):
223
+ img_in = [img] + [self.img_in(ref) for ref in ref_img]
224
+ img_ids = [ids] + [ref_ids for ref_ids in ref_img_ids]
225
+ img = torch.cat(img_in, dim=1)
226
+ ids = torch.cat(img_ids, dim=1)
227
+ else:
228
+ img = torch.cat((img, self.img_in(ref_img)), dim=1)
229
+ ids = torch.cat((ids, ref_img_ids), dim=1)
230
+ pe = self.pe_embedder(ids)
231
+
232
+ for index_block, block in enumerate(self.double_blocks):
233
+ if self.training and self.gradient_checkpointing:
234
+ img, txt = torch.utils.checkpoint.checkpoint(
235
+ block,
236
+ img=img,
237
+ txt=txt,
238
+ vec=vec,
239
+ pe=pe,
240
+ use_reentrant=False,
241
+ )
242
+ else:
243
+ img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
244
+
245
+ img = torch.cat((txt, img), 1)
246
+ for block in self.single_blocks:
247
+ if self.training and self.gradient_checkpointing:
248
+ img = torch.utils.checkpoint.checkpoint(
249
+ block, img, vec=vec, pe=pe, use_reentrant=False
250
+ )
251
+ else:
252
+ img = block(img, vec=vec, pe=pe)
253
+ img = img[:, txt.shape[1] :, ...]
254
+ # index img
255
+ img = img[:, :img_end, ...]
256
+
257
+ img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
258
+ return img
uso/flux/modules/__pycache__/autoencoder.cpython-311.pyc ADDED
Binary file (18.9 kB). View file
 
uso/flux/modules/__pycache__/conditioner.cpython-311.pyc ADDED
Binary file (2.6 kB). View file
 
uso/flux/modules/__pycache__/layers.cpython-311.pyc ADDED
Binary file (37.3 kB). View file
 
uso/flux/modules/autoencoder.py ADDED
@@ -0,0 +1,327 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. All rights reserved.
2
+ # Copyright (c) 2024 Black Forest Labs and The XLabs-AI Team. All rights reserved.
3
+
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from dataclasses import dataclass
17
+
18
+ import torch
19
+ from einops import rearrange
20
+ from torch import Tensor, nn
21
+
22
+
23
@dataclass
class AutoEncoderParams:
    """Hyperparameters describing the VAE architecture and latent scaling."""

    resolution: int      # nominal input image resolution the model was configured for
    in_channels: int     # channels of the input image (e.g. 3 for RGB)
    ch: int              # base channel width; multiplied by ch_mult per level
    out_ch: int          # channels of the reconstructed output image
    ch_mult: list[int]   # per-resolution-level channel multipliers
    num_res_blocks: int  # ResnetBlocks per resolution level
    z_channels: int      # channels of the latent z (encoder emits 2x for mean/logvar)
    scale_factor: float  # latent scaling applied after encoding / undone before decoding
    shift_factor: float  # latent shift applied before scaling
34
+
35
+
36
def swish(x: Tensor) -> Tensor:
    """SiLU/swish activation: x * sigmoid(x)."""
    gate = torch.sigmoid(x)
    return gate * x
38
+
39
+
40
class AttnBlock(nn.Module):
    """Single-head self-attention over all spatial positions, with residual output."""

    def __init__(self, in_channels: int):
        super().__init__()
        self.in_channels = in_channels

        self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)

        self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1)

    def attention(self, h_: Tensor) -> Tensor:
        """Normalize, project to q/k/v, and attend over the flattened h*w axis."""
        normed = self.norm(h_)
        q = self.q(normed)
        k = self.k(normed)
        v = self.v(normed)

        b, c, h, w = q.shape

        def to_tokens(t: Tensor) -> Tensor:
            # (b, c, h, w) -> (b, 1, h*w, c): one head, tokens along the spatial axis
            return t.reshape(b, c, h * w).permute(0, 2, 1).unsqueeze(1).contiguous()

        attended = nn.functional.scaled_dot_product_attention(
            to_tokens(q), to_tokens(k), to_tokens(v)
        )
        # (b, 1, h*w, c) -> (b, c, h, w): restore feature-map layout
        return attended.squeeze(1).permute(0, 2, 1).reshape(b, c, h, w)

    def forward(self, x: Tensor) -> Tensor:
        return x + self.proj_out(self.attention(x))
68
+
69
+
70
class ResnetBlock(nn.Module):
    """Two GroupNorm -> swish -> 3x3 conv stages with a residual connection.

    A 1x1 shortcut conv maps the residual path when the channel counts differ.
    """

    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        if out_channels is None:
            out_channels = in_channels
        self.in_channels = in_channels
        self.out_channels = out_channels

        self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        if in_channels != out_channels:
            self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        residual = x

        out = self.conv1(swish(self.norm1(x)))
        out = self.conv2(swish(self.norm2(out)))

        # project the residual path only when the channel count changes
        if self.in_channels != self.out_channels:
            residual = self.nin_shortcut(residual)

        return residual + out
98
+
99
+
100
class Downsample(nn.Module):
    """Halve spatial resolution with a stride-2 3x3 conv.

    PyTorch convs cannot pad asymmetrically, so the (right, bottom) one-pixel
    padding is applied manually before the convolution.
    """

    def __init__(self, in_channels: int):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)

    def forward(self, x: Tensor):
        # pad one pixel on the right and bottom edges only
        padded = nn.functional.pad(x, (0, 1, 0, 1), mode="constant", value=0)
        return self.conv(padded)
111
+
112
+
113
class Upsample(nn.Module):
    """Double spatial resolution via nearest-neighbor upsampling plus a 3x3 conv."""

    def __init__(self, in_channels: int):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x: Tensor):
        upsampled = nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
        return self.conv(upsampled)
122
+
123
+
124
class Encoder(nn.Module):
    """Convolutional VAE encoder.

    Maps an image (B, in_channels, H, W) down through len(ch_mult) resolution
    levels to a (B, 2 * z_channels, H/f, W/f) moment tensor, where
    f = 2 ** (len(ch_mult) - 1). The doubled channel count carries mean and
    logvar, to be split later (see DiagonalGaussian).
    """

    def __init__(
        self,
        resolution: int,
        in_channels: int,
        ch: int,
        ch_mult: list[int],
        num_res_blocks: int,
        z_channels: int,
    ):
        super().__init__()
        self.ch = ch
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        # downsampling
        self.conv_in = nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)

        curr_res = resolution
        # prepend 1 so level i reads ch * in_ch_mult[i] and writes ch * ch_mult[i]
        in_ch_mult = (1,) + tuple(ch_mult)
        self.in_ch_mult = in_ch_mult
        self.down = nn.ModuleList()
        block_in = self.ch
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()  # stays empty: no attention at intermediate levels
            block_in = ch * in_ch_mult[i_level]
            block_out = ch * ch_mult[i_level]
            for _ in range(self.num_res_blocks):
                block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
                block_in = block_out
            down = nn.Module()
            down.block = block
            down.attn = attn
            # halve resolution at every level except the last (deepest) one
            if i_level != self.num_resolutions - 1:
                down.downsample = Downsample(block_in)
                curr_res = curr_res // 2
            self.down.append(down)

        # middle: resnet -> attention -> resnet at the lowest resolution
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
        self.mid.attn_1 = AttnBlock(block_in)
        self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)

        # end: final norm + projection to 2 * z_channels (mean and logvar)
        self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
        self.conv_out = nn.Conv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x: Tensor) -> Tensor:
        """Encode image x into the stacked (mean, logvar) moment tensor."""
        # downsampling; hs keeps every intermediate activation (last entry is used)
        hs = [self.conv_in(x)]
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks):
                h = self.down[i_level].block[i_block](hs[-1])
                if len(self.down[i_level].attn) > 0:
                    h = self.down[i_level].attn[i_block](h)
                hs.append(h)
            if i_level != self.num_resolutions - 1:
                hs.append(self.down[i_level].downsample(hs[-1]))

        # middle
        h = hs[-1]
        h = self.mid.block_1(h)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h)
        # end
        h = self.norm_out(h)
        h = swish(h)
        h = self.conv_out(h)
        return h
196
+
197
+
198
class Decoder(nn.Module):
    """Convolutional VAE decoder.

    Mirrors Encoder: maps a latent (B, z_channels, H/f, W/f) back up through
    len(ch_mult) resolution levels to an image (B, out_ch, H, W), where
    f = 2 ** (len(ch_mult) - 1) (exposed as self.ffactor).
    """

    def __init__(
        self,
        ch: int,
        out_ch: int,
        ch_mult: list[int],
        num_res_blocks: int,
        in_channels: int,
        resolution: int,
        z_channels: int,
    ):
        super().__init__()
        self.ch = ch
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        # spatial down/up-sampling factor between image and latent
        self.ffactor = 2 ** (self.num_resolutions - 1)

        # compute in_ch_mult, block_in and curr_res at lowest res
        block_in = ch * ch_mult[self.num_resolutions - 1]
        curr_res = resolution // 2 ** (self.num_resolutions - 1)
        self.z_shape = (1, z_channels, curr_res, curr_res)

        # z to block_in
        self.conv_in = nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)

        # middle: resnet -> attention -> resnet at the lowest resolution
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
        self.mid.attn_1 = AttnBlock(block_in)
        self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)

        # upsampling; built deepest-first, then prepended so self.up is
        # indexed shallow-to-deep like Encoder.down
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()  # stays empty: no attention at intermediate levels
            block_out = ch * ch_mult[i_level]
            # one extra resnet block per level compared to the encoder
            for _ in range(self.num_res_blocks + 1):
                block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
                block_in = block_out
            up = nn.Module()
            up.block = block
            up.attn = attn
            if i_level != 0:
                up.upsample = Upsample(block_in)
                curr_res = curr_res * 2
            self.up.insert(0, up)  # prepend to get consistent order

        # end: final norm + projection to output image channels
        self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
        self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)

    def forward(self, z: Tensor) -> Tensor:
        """Decode latent z into an image tensor."""
        # z to block_in
        h = self.conv_in(z)

        # middle
        h = self.mid.block_1(h)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h)

        # upsampling, deepest level first
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks + 1):
                h = self.up[i_level].block[i_block](h)
                if len(self.up[i_level].attn) > 0:
                    h = self.up[i_level].attn[i_block](h)
            if i_level != 0:
                h = self.up[i_level].upsample(h)

        # end
        h = self.norm_out(h)
        h = swish(h)
        h = self.conv_out(h)
        return h
275
+
276
+
277
class DiagonalGaussian(nn.Module):
    """Split a tensor into (mean, logvar) halves along ``chunk_dim``.

    Returns a reparameterized sample ``mean + std * eps`` when ``sample`` is
    True, otherwise the mean itself.
    """

    def __init__(self, sample: bool = True, chunk_dim: int = 1):
        super().__init__()
        self.sample = sample
        self.chunk_dim = chunk_dim

    def forward(self, z: Tensor) -> Tensor:
        mean, logvar = torch.chunk(z, 2, dim=self.chunk_dim)
        if not self.sample:
            return mean
        std = torch.exp(0.5 * logvar)
        return mean + std * torch.randn_like(mean)
290
+
291
+
292
class AutoEncoder(nn.Module):
    """Full VAE: Encoder -> DiagonalGaussian -> Decoder, with latent scaling."""

    def __init__(self, params: AutoEncoderParams):
        super().__init__()
        self.encoder = Encoder(
            resolution=params.resolution,
            in_channels=params.in_channels,
            ch=params.ch,
            ch_mult=params.ch_mult,
            num_res_blocks=params.num_res_blocks,
            z_channels=params.z_channels,
        )
        self.decoder = Decoder(
            resolution=params.resolution,
            in_channels=params.in_channels,
            ch=params.ch,
            out_ch=params.out_ch,
            ch_mult=params.ch_mult,
            num_res_blocks=params.num_res_blocks,
            z_channels=params.z_channels,
        )
        self.reg = DiagonalGaussian()

        self.scale_factor = params.scale_factor
        self.shift_factor = params.shift_factor

    def encode(self, x: Tensor) -> Tensor:
        """Encode image x into a shifted/scaled latent sample."""
        latent = self.reg(self.encoder(x))
        return (latent - self.shift_factor) * self.scale_factor

    def decode(self, z: Tensor) -> Tensor:
        """Undo the latent scaling/shift, then decode to an image."""
        return self.decoder(z / self.scale_factor + self.shift_factor)

    def forward(self, x: Tensor) -> Tensor:
        """Round-trip: encode then decode (reconstruction)."""
        return self.decode(self.encode(x))
uso/flux/modules/conditioner.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. All rights reserved.
2
+ # Copyright (c) 2024 Black Forest Labs and The XLabs-AI Team. All rights reserved.
3
+
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from torch import Tensor, nn
17
+ from transformers import (CLIPTextModel, CLIPTokenizer, T5EncoderModel,
18
+ T5Tokenizer)
19
+
20
+
21
class HFEmbedder(nn.Module):
    """Frozen HuggingFace text encoder.

    For CLIP checkpoints the pooled sentence embedding is returned; for T5
    the per-token last hidden states are returned. The underlying module is
    put in eval mode with gradients disabled.
    """

    def __init__(self, version: str, max_length: int, **hf_kwargs):
        super().__init__()
        self.is_clip = "clip" in version.lower()
        self.max_length = max_length
        self.output_key = "pooler_output" if self.is_clip else "last_hidden_state"

        if self.is_clip:
            tokenizer_cls, model_cls = CLIPTokenizer, CLIPTextModel
        else:
            tokenizer_cls, model_cls = T5Tokenizer, T5EncoderModel
        self.tokenizer = tokenizer_cls.from_pretrained(version, max_length=max_length)
        self.hf_module = model_cls.from_pretrained(version, **hf_kwargs)

        self.hf_module = self.hf_module.eval().requires_grad_(False)

    def forward(self, text: list[str]) -> Tensor:
        tokens = self.tokenizer(
            text,
            truncation=True,
            max_length=self.max_length,
            return_length=False,
            return_overflowing_tokens=False,
            padding="max_length",
            return_tensors="pt",
        )

        encoded = self.hf_module(
            input_ids=tokens["input_ids"].to(self.hf_module.device),
            attention_mask=None,
            output_hidden_states=False,
        )
        return encoded[self.output_key]