diff --git a/.gitattributes b/.gitattributes
index a6344aac8c09253b3b630fb776ae94478aa0275b..602834a23fe7bea552229ebbcb1d6679363bc38a 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
+*.png filter=lfs diff=lfs merge=lfs -text
+*.jpg filter=lfs diff=lfs merge=lfs -text
+*.webp filter=lfs diff=lfs merge=lfs -text
\ No newline at end of file
diff --git a/README copy.md b/README copy.md
new file mode 100644
index 0000000000000000000000000000000000000000..027815ccc2e9b171f1f0bca4d818d0eba2278414
--- /dev/null
+++ b/README copy.md
@@ -0,0 +1,17 @@
+---
+title: USO
+emoji: 💻
+colorFrom: indigo
+colorTo: purple
+sdk: gradio
+sdk_version: 5.23.3
+app_file: app.py
+pinned: false
+license: apache-2.0
+short_description: Freely Combining Any Subjects with Any Styles Across All Scenarios.
+models:
+ - black-forest-labs/FLUX.1-dev
+ - bytedance-research/UNO
+---
+
+Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..2256462b7c3f01d992faea87f6a5932f8f014148
--- /dev/null
+++ b/app.py
@@ -0,0 +1,242 @@
+# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import dataclasses
+import json
+import os
+from pathlib import Path
+
+import gradio as gr
+import torch
+import spaces
+
+from uso.flux.pipeline import USOPipeline
+from transformers import SiglipVisionModel, SiglipImageProcessor
+
+
+with open("assets/uso_text.svg", "r", encoding="utf-8") as svg_file:
+ text_content = svg_file.read()
+
+with open("assets/uso_logo.svg", "r", encoding="utf-8") as svg_file:
+ logo_content = svg_file.read()
+
+title = f"""
+
+ {text_content}
+ by UXO Team
+ {logo_content}
+
+""".strip()
+
+badges_text = r"""
+
+
+
+
+
+
+""".strip()
+
+tips = """
+ 📌 **What is USO?**
+USO is a unified style-subject optimized customization model and the latest addition to the UXO family (USO and UNO).
+It can freely combine arbitrary subjects with arbitrary styles in any scenario.
+
+ 💡 **How to use?**
+We provide step-by-step instructions in our GitHub repo.
+Additionally, try the examples provided below the demo to quickly get familiar with USO and spark your creativity!
+
+ ⚡️ The model is trained at 1024x1024 resolution and supports 3 types of usage:
+* **Content img only**: supports the following modes:
+    * Subject/Identity-driven: natural prompts work well (e.g., *A clock on the table.*, *The woman near the sea.*); excels at producing **photorealistic portraits**.
+    * Style edit (layout-preserved): *Transform the image into Ghibli style/Pixel style/Retro comic style/Watercolor painting style...*
+    * Style edit (layout-shifted): *Ghibli style, the man on the beach.*
+* **Style img only**: reference the input style and generate anything the prompt describes. USO excels at this and also supports multiple style references (beta).
+* **Content img + style img**: place the content into the desired style.
+    * Layout-preserved: leave the prompt **empty**.
+    * Layout-shifted: use a natural prompt."""
+
+star = r"""
+If USO is helpful, please ⭐ our GitHub repo. Thanks a lot!"""
+
+def get_examples(examples_dir: str = "assets/examples") -> list:
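+    # Each example directory is expected to contain a config.json with "usage",
+    # "prompt", "seed" and optional "image_ref1"/"image_ref2"/"image_ref3" keys
+    # pointing at reference images stored alongside it.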
+ examples = Path(examples_dir)
+ ans = []
+ for example in examples.iterdir():
+ if not example.is_dir() or len(os.listdir(example)) == 0:
+ continue
+ with open(example / "config.json") as f:
+ example_dict = json.load(f)
+
+
+ example_list = []
+
+        example_list.append(example_dict["usage"])  # usage case, shown as "Case For:" in the UI
+ example_list.append(example_dict["prompt"]) # prompt
+
+ for key in ["image_ref1", "image_ref2", "image_ref3"]:
+ if key in example_dict:
+ example_list.append(str(example / example_dict[key]))
+ else:
+ example_list.append(None)
+
+ example_list.append(example_dict["seed"])
+ ans.append(example_list)
+ return ans
+
+
+def create_demo(
+ model_type: str,
+ device: str = "cuda" if torch.cuda.is_available() else "cpu",
+ offload: bool = False,
+):
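+    # Load the USO weights as a rank-128 LoRA on top of the chosen FLUX base model;
+    # hf_download=True fetches missing checkpoints from the Hugging Face Hub.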
+ pipeline = USOPipeline(
+ model_type, device, offload, only_lora=True, lora_rank=128, hf_download=True
+ )
+ print("USOPipeline loaded successfully")
+
+ siglip_processor = SiglipImageProcessor.from_pretrained(
+ "google/siglip-so400m-patch14-384"
+ )
+ siglip_model = SiglipVisionModel.from_pretrained(
+ "google/siglip-so400m-patch14-384"
+ )
+ siglip_model.eval()
+ siglip_model.to(device)
+ pipeline.model.vision_encoder = siglip_model
+ pipeline.model.vision_encoder_processor = siglip_processor
+ print("SigLIP model loaded successfully")
+
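+    # On ZeroGPU Spaces, spaces.GPU requests a GPU allocation of up to 120 s per generation call.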
+ pipeline.gradio_generate = spaces.GPU(duration=120)(pipeline.gradio_generate)
+ with gr.Blocks() as demo:
+ gr.Markdown(title)
+ gr.Markdown(badges_text)
+ gr.Markdown(tips)
+ with gr.Row():
+ with gr.Column():
+ prompt = gr.Textbox(label="Prompt", value="A beautiful woman.")
+ with gr.Row():
+ image_prompt1 = gr.Image(
+ label="Content Reference Img", visible=True, interactive=True, type="pil"
+ )
+ image_prompt2 = gr.Image(
+ label="Style Reference Img", visible=True, interactive=True, type="pil"
+ )
+ image_prompt3 = gr.Image(
+ label="Extra Style Reference Img (Beta)", visible=True, interactive=True, type="pil"
+ )
+
+ with gr.Row():
+ with gr.Row():
+ width = gr.Slider(
+ 512, 1536, 1024, step=16, label="Generation Width"
+ )
+ height = gr.Slider(
+ 512, 1536, 1024, step=16, label="Generation Height"
+ )
+ with gr.Row():
+ with gr.Row():
+ keep_size = gr.Checkbox(
+ label="Keep input size",
+ value=False,
+ interactive=True
+ )
+ with gr.Column():
+ gr.Markdown("Set it to True if you only need style editing or want to keep the layout.")
+
+ with gr.Accordion("Advanced Options", open=True):
+ with gr.Row():
+ num_steps = gr.Slider(
+ 1, 50, 25, step=1, label="Number of steps"
+ )
+ guidance = gr.Slider(
+ 1.0, 5.0, 4.0, step=0.1, label="Guidance", interactive=True
+ )
+ content_long_size = gr.Slider(
+ 0, 1024, 512, step=16, label="Content reference size"
+ )
+ seed = gr.Number(-1, label="Seed (-1 for random)")
+
+ generate_btn = gr.Button("Generate")
+ gr.Markdown(star)
+
+ with gr.Column():
+ output_image = gr.Image(label="Generated Image")
+ download_btn = gr.File(
+ label="Download full-resolution", type="filepath", interactive=False
+ )
+
+ inputs = [
+ prompt,
+ image_prompt1,
+ image_prompt2,
+ image_prompt3,
+ seed,
+ width,
+ height,
+ guidance,
+ num_steps,
+ keep_size,
+ content_long_size,
+ ]
+ generate_btn.click(
+ fn=pipeline.gradio_generate,
+ inputs=inputs,
+ outputs=[output_image, download_btn],
+ )
+
+ example_text = gr.Text("", visible=False, label="Case For:")
+ examples = get_examples("./assets/gradio_examples")
+
+ gr.Examples(
+ examples=examples,
+ inputs=[
+ example_text,
+ prompt,
+ image_prompt1,
+ image_prompt2,
+ image_prompt3,
+ seed,
+ ],
+ # cache_examples='lazy',
+ outputs=[output_image, download_btn],
+ fn=pipeline.gradio_generate,
+ )
+
+ return demo
+
+
+if __name__ == "__main__":
+ from typing import Literal
+
+ from transformers import HfArgumentParser
+
+ @dataclasses.dataclass
+ class AppArgs:
+ name: Literal["flux-dev", "flux-dev-fp8", "flux-schnell", "flux-krea-dev"] = "flux-dev"
+ device: Literal["cuda", "cpu"] = "cuda" if torch.cuda.is_available() else "cpu"
+ offload: bool = dataclasses.field(
+ default=False,
+ metadata={
+                "help": "If True, sequentially offload the models (AE, DiT, text encoders) to CPU when not in use."
+ },
+ )
+ port: int = 7860
+
+ parser = HfArgumentParser([AppArgs])
+ args_tuple = parser.parse_args_into_dataclasses() # type: tuple[AppArgs]
+ args = args_tuple[0]
+
+ demo = create_demo(args.name, args.device, args.offload)
+ demo.launch(server_port=args.port)
diff --git a/assets/gradio_examples/1subject/config.json b/assets/gradio_examples/1subject/config.json
new file mode 100644
index 0000000000000000000000000000000000000000..2f470599cdd9382e12bded66467a33c6c4329197
--- /dev/null
+++ b/assets/gradio_examples/1subject/config.json
@@ -0,0 +1,6 @@
+{
+ "prompt": "Wool felt style, a clock in the jungle.",
+ "seed": 3407,
+ "usage": "Subject-driven",
+ "image_ref1": "./ref.jpg"
+}
\ No newline at end of file
diff --git a/assets/gradio_examples/1subject/ref.jpg b/assets/gradio_examples/1subject/ref.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..38615fba19fda54eec4df211606e95ec89cbae4b
--- /dev/null
+++ b/assets/gradio_examples/1subject/ref.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0e1eb6ca2c944f3bfaed3ace56f5f186ed073a477e0333e0237253d98f0c9267
+size 139451
diff --git a/assets/gradio_examples/2identity/config.json b/assets/gradio_examples/2identity/config.json
new file mode 100644
index 0000000000000000000000000000000000000000..be3789e3a29e7e4156440a397fa30a265fd5fc42
--- /dev/null
+++ b/assets/gradio_examples/2identity/config.json
@@ -0,0 +1,6 @@
+{
+ "prompt": "The girl is riding a bike in a street.",
+ "seed": 3407,
+ "usage": "Identity-driven",
+ "image_ref1": "./ref.webp"
+}
\ No newline at end of file
diff --git a/assets/gradio_examples/2identity/ref.webp b/assets/gradio_examples/2identity/ref.webp
new file mode 100644
index 0000000000000000000000000000000000000000..be2cb3ababd23173ad3c60a6336f00d685956f66
--- /dev/null
+++ b/assets/gradio_examples/2identity/ref.webp
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4e97502bd7eebd6692604f891f836f25c7c30dcac8d15c4d42cc874efc51fcc5
+size 85772
diff --git a/assets/gradio_examples/3identity/config.json b/assets/gradio_examples/3identity/config.json
new file mode 100644
index 0000000000000000000000000000000000000000..5816a5f2b6d590c851698d90aa9ba93da9fc87d6
--- /dev/null
+++ b/assets/gradio_examples/3identity/config.json
@@ -0,0 +1,6 @@
+{
+    "prompt": "The man in a flower shop carefully matches bouquets, conveying beautiful emotions and blessings with flowers.",
+ "seed": 3407,
+ "usage": "Identity-driven",
+ "image_ref1": "./ref.jpg"
+}
\ No newline at end of file
diff --git a/assets/gradio_examples/3identity/ref.jpg b/assets/gradio_examples/3identity/ref.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..2b8c787868211c1fc9d2678f917516479f871bdc
--- /dev/null
+++ b/assets/gradio_examples/3identity/ref.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2730103b6b9ebaf47b44ef9a9d7fbb722de7878a101af09f0b85f8dfadb4c8a4
+size 30572
diff --git a/assets/gradio_examples/4identity/config.json b/assets/gradio_examples/4identity/config.json
new file mode 100644
index 0000000000000000000000000000000000000000..94497918462a8c0e09bc80eb3d9549e89ba6e7d8
--- /dev/null
+++ b/assets/gradio_examples/4identity/config.json
@@ -0,0 +1,6 @@
+{
+ "prompt": "Transform the image into Ghibli style.",
+ "seed": 3407,
+ "usage": "Identity-driven",
+ "image_ref1": "./ref.webp"
+}
\ No newline at end of file
diff --git a/assets/gradio_examples/4identity/ref.webp b/assets/gradio_examples/4identity/ref.webp
new file mode 100644
index 0000000000000000000000000000000000000000..9f7cae34140b5d373c7ca51d4c2e8f67cf039319
--- /dev/null
+++ b/assets/gradio_examples/4identity/ref.webp
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f8ed8aa1c0714c939392e2c033735d6266e53266079bb300cbf05a6824a49f9f
+size 38764
diff --git a/assets/gradio_examples/5style/config.json b/assets/gradio_examples/5style/config.json
new file mode 100644
index 0000000000000000000000000000000000000000..9303fd6dfd485517ff34607b0058dea9cbec2922
--- /dev/null
+++ b/assets/gradio_examples/5style/config.json
@@ -0,0 +1,6 @@
+{
+ "prompt": "A cat sleeping on a chair.",
+ "seed": 3407,
+ "usage": "Style-driven",
+ "image_ref2": "./ref.webp"
+}
\ No newline at end of file
diff --git a/assets/gradio_examples/5style/ref.webp b/assets/gradio_examples/5style/ref.webp
new file mode 100644
index 0000000000000000000000000000000000000000..8c2f06b9cd2c331914842ca76d6ed58a80a50085
--- /dev/null
+++ b/assets/gradio_examples/5style/ref.webp
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9ebf56d2d20ae5c49a582ff6bfef64b13022d0c624d9de25ed91047380fdfcfe
+size 52340
diff --git a/assets/gradio_examples/6style/config.json b/assets/gradio_examples/6style/config.json
new file mode 100644
index 0000000000000000000000000000000000000000..172585f3bae2f7e73d25dea3393f46435123f95e
--- /dev/null
+++ b/assets/gradio_examples/6style/config.json
@@ -0,0 +1,6 @@
+{
+ "prompt": "A beautiful woman.",
+ "seed": 3407,
+ "usage": "Style-driven",
+ "image_ref2": "./ref.webp"
+}
\ No newline at end of file
diff --git a/assets/gradio_examples/6style/ref.webp b/assets/gradio_examples/6style/ref.webp
new file mode 100644
index 0000000000000000000000000000000000000000..5c1e9ddb0d4a1c7658ddf607019bc76b1a4b6ce0
--- /dev/null
+++ b/assets/gradio_examples/6style/ref.webp
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:40c013341c8708b53094e3eaa377b3dfccdc9e77e215ad15d2ac2e875b4c494a
+size 58202
diff --git a/assets/gradio_examples/7style_subject/config.json b/assets/gradio_examples/7style_subject/config.json
new file mode 100644
index 0000000000000000000000000000000000000000..6dc3da8233e6a6b5f5a9192fb65227352d2c5cdd
--- /dev/null
+++ b/assets/gradio_examples/7style_subject/config.json
@@ -0,0 +1,7 @@
+{
+ "prompt": "",
+ "seed": 321,
+ "usage": "Style-subject-driven (layout-preserved)",
+ "image_ref1": "./ref1.webp",
+ "image_ref2": "./ref2.webp"
+}
\ No newline at end of file
diff --git a/assets/gradio_examples/7style_subject/ref1.webp b/assets/gradio_examples/7style_subject/ref1.webp
new file mode 100644
index 0000000000000000000000000000000000000000..9f7cae34140b5d373c7ca51d4c2e8f67cf039319
--- /dev/null
+++ b/assets/gradio_examples/7style_subject/ref1.webp
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f8ed8aa1c0714c939392e2c033735d6266e53266079bb300cbf05a6824a49f9f
+size 38764
diff --git a/assets/gradio_examples/7style_subject/ref2.webp b/assets/gradio_examples/7style_subject/ref2.webp
new file mode 100644
index 0000000000000000000000000000000000000000..c06e99d265a2b42b48875ec086457a5414eea627
--- /dev/null
+++ b/assets/gradio_examples/7style_subject/ref2.webp
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:175d6e5b975b4d494950250740c0fe371a7e9b2c93c59a3ae82b82be72ccc0f6
+size 14168
diff --git a/assets/gradio_examples/8style_subject/config.json b/assets/gradio_examples/8style_subject/config.json
new file mode 100644
index 0000000000000000000000000000000000000000..ef542fde59e71687ee9e536649b8481a880d57a8
--- /dev/null
+++ b/assets/gradio_examples/8style_subject/config.json
@@ -0,0 +1,7 @@
+{
+ "prompt": "The woman gave an impassioned speech on the podium.",
+ "seed": 321,
+ "usage": "Style-subject-driven (layout-shifted)",
+ "image_ref1": "./ref1.webp",
+ "image_ref2": "./ref2.webp"
+}
\ No newline at end of file
diff --git a/assets/gradio_examples/8style_subject/ref1.webp b/assets/gradio_examples/8style_subject/ref1.webp
new file mode 100644
index 0000000000000000000000000000000000000000..9f7cae34140b5d373c7ca51d4c2e8f67cf039319
--- /dev/null
+++ b/assets/gradio_examples/8style_subject/ref1.webp
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f8ed8aa1c0714c939392e2c033735d6266e53266079bb300cbf05a6824a49f9f
+size 38764
diff --git a/assets/gradio_examples/8style_subject/ref2.webp b/assets/gradio_examples/8style_subject/ref2.webp
new file mode 100644
index 0000000000000000000000000000000000000000..5ab3db1106adb67090bd533e0ddf1094d7c07065
--- /dev/null
+++ b/assets/gradio_examples/8style_subject/ref2.webp
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0235262d9bd1070155536352ccf195f9875ead0d3379dee7285c0aaae79f6464
+size 39098
diff --git a/assets/gradio_examples/9mix_style/config.json b/assets/gradio_examples/9mix_style/config.json
new file mode 100644
index 0000000000000000000000000000000000000000..8c9aa236dc8c7bac8dd441aa9a86c786824a9d17
--- /dev/null
+++ b/assets/gradio_examples/9mix_style/config.json
@@ -0,0 +1,7 @@
+{
+ "prompt": "A man.",
+ "seed": 321,
+ "usage": "Multi-style-driven",
+ "image_ref2": "./ref1.webp",
+ "image_ref3": "./ref2.webp"
+}
\ No newline at end of file
diff --git a/assets/gradio_examples/9mix_style/ref1.webp b/assets/gradio_examples/9mix_style/ref1.webp
new file mode 100644
index 0000000000000000000000000000000000000000..1c02f7fe712a295f858a666f211d994cecaa7ac1
--- /dev/null
+++ b/assets/gradio_examples/9mix_style/ref1.webp
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a1d272a0ecb03126503446b00a2152deab2045f89ac2c01f948e1099589d2862
+size 141886
diff --git a/assets/gradio_examples/9mix_style/ref2.webp b/assets/gradio_examples/9mix_style/ref2.webp
new file mode 100644
index 0000000000000000000000000000000000000000..e99715757eb80c277f42a4f5295251c30f1af45f
--- /dev/null
+++ b/assets/gradio_examples/9mix_style/ref2.webp
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b1ce04559726509672ce859d617a08d8dff8b2fe28f503fecbca7a5f66082882
+size 290260
diff --git a/assets/gradio_examples/identity1.jpg b/assets/gradio_examples/identity1.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..2b8c787868211c1fc9d2678f917516479f871bdc
--- /dev/null
+++ b/assets/gradio_examples/identity1.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2730103b6b9ebaf47b44ef9a9d7fbb722de7878a101af09f0b85f8dfadb4c8a4
+size 30572
diff --git a/assets/gradio_examples/identity1_result.png b/assets/gradio_examples/identity1_result.png
new file mode 100644
index 0000000000000000000000000000000000000000..e4d895db3f534950cbd804e594f450bf55f6507c
--- /dev/null
+++ b/assets/gradio_examples/identity1_result.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7684256e44ce1bd4ada1e77a12674432eddd95b07fb388673899139afc56d864
+size 1538828
diff --git a/assets/gradio_examples/identity2.webp b/assets/gradio_examples/identity2.webp
new file mode 100644
index 0000000000000000000000000000000000000000..9f7cae34140b5d373c7ca51d4c2e8f67cf039319
--- /dev/null
+++ b/assets/gradio_examples/identity2.webp
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f8ed8aa1c0714c939392e2c033735d6266e53266079bb300cbf05a6824a49f9f
+size 38764
diff --git a/assets/gradio_examples/identity2_style2_result.webp b/assets/gradio_examples/identity2_style2_result.webp
new file mode 100644
index 0000000000000000000000000000000000000000..2d28e2e08181b1862fb30521bec67989ee17eb30
--- /dev/null
+++ b/assets/gradio_examples/identity2_style2_result.webp
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8376b6dc02d304616c09ecf09c7dbabb16c7c9142fb4db21f576a15a1ec24062
+size 43892
diff --git a/assets/gradio_examples/style1.webp b/assets/gradio_examples/style1.webp
new file mode 100644
index 0000000000000000000000000000000000000000..8c2f06b9cd2c331914842ca76d6ed58a80a50085
--- /dev/null
+++ b/assets/gradio_examples/style1.webp
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9ebf56d2d20ae5c49a582ff6bfef64b13022d0c624d9de25ed91047380fdfcfe
+size 52340
diff --git a/assets/gradio_examples/style1_result.webp b/assets/gradio_examples/style1_result.webp
new file mode 100644
index 0000000000000000000000000000000000000000..0b23eed9991c7ffab948d60af5dcab89bcf03aae
--- /dev/null
+++ b/assets/gradio_examples/style1_result.webp
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:16a4353dd83b1c48499e222d6f77904e1fda23c1649ea5f6cca6b00b0fca3069
+size 61062
diff --git a/assets/gradio_examples/style2.webp b/assets/gradio_examples/style2.webp
new file mode 100644
index 0000000000000000000000000000000000000000..5ab3db1106adb67090bd533e0ddf1094d7c07065
--- /dev/null
+++ b/assets/gradio_examples/style2.webp
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0235262d9bd1070155536352ccf195f9875ead0d3379dee7285c0aaae79f6464
+size 39098
diff --git a/assets/gradio_examples/style3.webp b/assets/gradio_examples/style3.webp
new file mode 100644
index 0000000000000000000000000000000000000000..1c02f7fe712a295f858a666f211d994cecaa7ac1
--- /dev/null
+++ b/assets/gradio_examples/style3.webp
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a1d272a0ecb03126503446b00a2152deab2045f89ac2c01f948e1099589d2862
+size 141886
diff --git a/assets/gradio_examples/style3_style4_result.webp b/assets/gradio_examples/style3_style4_result.webp
new file mode 100644
index 0000000000000000000000000000000000000000..2bc1bfc2258d5193a300c560563b3b21eaa434d4
--- /dev/null
+++ b/assets/gradio_examples/style3_style4_result.webp
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d09a5e429cc1d059aecd041e061868cd8e5b59f4718bb0f926fd84364f3794b0
+size 172716
diff --git a/assets/gradio_examples/style4.webp b/assets/gradio_examples/style4.webp
new file mode 100644
index 0000000000000000000000000000000000000000..e99715757eb80c277f42a4f5295251c30f1af45f
--- /dev/null
+++ b/assets/gradio_examples/style4.webp
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b1ce04559726509672ce859d617a08d8dff8b2fe28f503fecbca7a5f66082882
+size 290260
diff --git a/assets/gradio_examples/z_mix_style/config.json b/assets/gradio_examples/z_mix_style/config.json
new file mode 100644
index 0000000000000000000000000000000000000000..db3579ae23a361a3e50cd4d767ef5f9426952bb6
--- /dev/null
+++ b/assets/gradio_examples/z_mix_style/config.json
@@ -0,0 +1,7 @@
+{
+ "prompt": "Boat on water.",
+ "seed": 321,
+ "usage": "Multi-style-driven",
+ "image_ref2": "./ref1.png",
+ "image_ref3": "./ref2.png"
+}
\ No newline at end of file
diff --git a/assets/gradio_examples/z_mix_style/ref1.png b/assets/gradio_examples/z_mix_style/ref1.png
new file mode 100644
index 0000000000000000000000000000000000000000..e923d0ee6951cf100075015da04a686d50bd3698
--- /dev/null
+++ b/assets/gradio_examples/z_mix_style/ref1.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5c31ba662c85f4032abf079dfeb9cba08d797b7b63f1d661c5270b373b00d095
+size 26149
diff --git a/assets/gradio_examples/z_mix_style/ref2.png b/assets/gradio_examples/z_mix_style/ref2.png
new file mode 100644
index 0000000000000000000000000000000000000000..8f850831089436bfc214349373a808cbf83cf54d
--- /dev/null
+++ b/assets/gradio_examples/z_mix_style/ref2.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c47d23d5ffdbf30b4a8f6c1bc5d07a730825eaac8363c13bdac8e3bb8c330aed
+size 14666
diff --git a/assets/gradio_examples/zz_t2i/config.json b/assets/gradio_examples/zz_t2i/config.json
new file mode 100644
index 0000000000000000000000000000000000000000..e475d1043a4e13638495ac1db4118d7a5865a00c
--- /dev/null
+++ b/assets/gradio_examples/zz_t2i/config.json
@@ -0,0 +1,5 @@
+{
+ "prompt": "A beautiful woman.",
+ "seed": -1,
+ "usage": "Text-to-image"
+}
\ No newline at end of file
diff --git a/assets/teaser.webp b/assets/teaser.webp
new file mode 100644
index 0000000000000000000000000000000000000000..d6916918b54307276b3fc8a5bc269c62dd7f2989
--- /dev/null
+++ b/assets/teaser.webp
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:543c724f6b929303046ae481672567fe4a9620f0af5ca1dfff215dc7a2cbff5f
+size 1674736
diff --git a/assets/uso.webp b/assets/uso.webp
new file mode 100644
index 0000000000000000000000000000000000000000..0caec5b7d5824b89670f7cbeec120f872d406830
--- /dev/null
+++ b/assets/uso.webp
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:772957e867da33550437fa547202d0f995011353ef9a24036d23596dae1a1632
+size 58234
diff --git a/assets/uso_logo.svg b/assets/uso_logo.svg
new file mode 100644
index 0000000000000000000000000000000000000000..dc91b85132492bc6c054b5184f24c236656a7569
--- /dev/null
+++ b/assets/uso_logo.svg
@@ -0,0 +1,880 @@
+
+
diff --git a/assets/uso_text.svg b/assets/uso_text.svg
new file mode 100644
index 0000000000000000000000000000000000000000..86153554bc20737c394cac6d50ee58d9f277572e
--- /dev/null
+++ b/assets/uso_text.svg
@@ -0,0 +1 @@
+
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..61694b57e446afb1def8734413b85d3c4abf836a
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,19 @@
+accelerate==1.1.1
+deepspeed==0.14.4
+einops==0.8.0
+transformers==4.43.3
+huggingface-hub
+diffusers==0.30.1
+sentencepiece==0.2.0
+gradio==5.22.0
+opencv-python
+matplotlib
+safetensors==0.4.5
+scipy==1.10.1
+numpy==1.24.4
+onnxruntime-gpu
+# httpx==0.23.3
+git+https://github.com/openai/CLIP.git
+--extra-index-url https://download.pytorch.org/whl/cu124
+torch==2.4.0
+torchvision==0.19.0
diff --git a/uso/flux/math.py b/uso/flux/math.py
new file mode 100644
index 0000000000000000000000000000000000000000..2461437371d22a60eab7df4b5f5cb371dd692fe9
--- /dev/null
+++ b/uso/flux/math.py
@@ -0,0 +1,45 @@
+# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. All rights reserved.
+# Copyright (c) 2024 Black Forest Labs and The XLabs-AI Team. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+from einops import rearrange
+from torch import Tensor
+
+
+def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
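+    # Apply rotary position embeddings to q/k, run scaled dot-product attention,
+    # then merge the heads back into the channel dimension.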
+ q, k = apply_rope(q, k, pe)
+
+ x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
+ x = rearrange(x, "B H L D -> B L (H D)")
+
+ return x
+
+
+def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
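+    # Standard RoPE construction: per-position 2x2 rotation matrices
+    # [[cos, -sin], [sin, cos]] over dim // 2 frequency bands.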
+ assert dim % 2 == 0
+ scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
+ omega = 1.0 / (theta**scale)
+ out = torch.einsum("...n,d->...nd", pos, omega)
+ out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
+ out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
+ return out.float()
+
+
+def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:
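+    # Treat consecutive channel pairs of q/k as 2-vectors and rotate them by freqs_cis.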
+ xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
+ xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
+ xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
+ xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
+ return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
diff --git a/uso/flux/model.py b/uso/flux/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..43d4c80227b80a50585df4afcd21d1449ca3f61d
--- /dev/null
+++ b/uso/flux/model.py
@@ -0,0 +1,258 @@
+# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. All rights reserved.
+# Copyright (c) 2024 Black Forest Labs and The XLabs-AI Team. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass
+
+import torch
+from torch import Tensor, nn
+
+from .modules.layers import (
+ DoubleStreamBlock,
+ EmbedND,
+ LastLayer,
+ MLPEmbedder,
+ SingleStreamBlock,
+ timestep_embedding,
+ SigLIPMultiFeatProjModel,
+)
+import os
+
+
+@dataclass
+class FluxParams:
+ in_channels: int
+ vec_in_dim: int
+ context_in_dim: int
+ hidden_size: int
+ mlp_ratio: float
+ num_heads: int
+ depth: int
+ depth_single_blocks: int
+ axes_dim: list[int]
+ theta: int
+ qkv_bias: bool
+ guidance_embed: bool
+
+
+class Flux(nn.Module):
+ """
+ Transformer model for flow matching on sequences.
+ """
+
+ _supports_gradient_checkpointing = True
+
+ def __init__(self, params: FluxParams):
+ super().__init__()
+
+ self.params = params
+ self.in_channels = params.in_channels
+ self.out_channels = self.in_channels
+ if params.hidden_size % params.num_heads != 0:
+ raise ValueError(
+ f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
+ )
+ pe_dim = params.hidden_size // params.num_heads
+ if sum(params.axes_dim) != pe_dim:
+ raise ValueError(
+ f"Got {params.axes_dim} but expected positional dim {pe_dim}"
+ )
+ self.hidden_size = params.hidden_size
+ self.num_heads = params.num_heads
+ self.pe_embedder = EmbedND(
+ dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim
+ )
+ self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
+ self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
+ self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size)
+ self.guidance_in = (
+ MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
+ if params.guidance_embed
+ else nn.Identity()
+ )
+ self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size)
+
+ self.double_blocks = nn.ModuleList(
+ [
+ DoubleStreamBlock(
+ self.hidden_size,
+ self.num_heads,
+ mlp_ratio=params.mlp_ratio,
+ qkv_bias=params.qkv_bias,
+ )
+ for _ in range(params.depth)
+ ]
+ )
+
+ self.single_blocks = nn.ModuleList(
+ [
+ SingleStreamBlock(
+ self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio
+ )
+ for _ in range(params.depth_single_blocks)
+ ]
+ )
+
+ self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)
+ self.gradient_checkpointing = False
+
+ # feature embedder for siglip multi-feat inputs
+ self.feature_embedder = SigLIPMultiFeatProjModel(
+ siglip_token_nums=729,
+ style_token_nums=64,
+ siglip_token_dims=1152,
+ hidden_size=self.hidden_size,
+ context_layer_norm=True,
+ )
+        print("Using the SigLIP multi-feature semantic encoder to encode style images")
+
+ self.vision_encoder = None
+
+ def _set_gradient_checkpointing(self, module, value=False):
+ if hasattr(module, "gradient_checkpointing"):
+ module.gradient_checkpointing = value
+
+ @property
+ def attn_processors(self):
+        # collect processors recursively
+ processors = {} # type: dict[str, nn.Module]
+
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors):
+ if hasattr(module, "set_processor"):
+ processors[f"{name}.processor"] = module.processor
+
+ for sub_name, child in module.named_children():
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+ return processors
+
+ for name, module in self.named_children():
+ fn_recursive_add_processors(name, module, processors)
+
+ return processors
+
+ def set_attn_processor(self, processor):
+ r"""
+ Sets the attention processor to use to compute attention.
+
+ Parameters:
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
+ for **all** `Attention` layers.
+
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
+ processor. This is strongly recommended when setting trainable attention processors.
+
+ """
+ count = len(self.attn_processors.keys())
+
+ if isinstance(processor, dict) and len(processor) != count:
+ raise ValueError(
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+ )
+
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+ if hasattr(module, "set_processor"):
+ if not isinstance(processor, dict):
+ module.set_processor(processor)
+ else:
+ module.set_processor(processor.pop(f"{name}.processor"))
+
+ for sub_name, child in module.named_children():
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+ for name, module in self.named_children():
+ fn_recursive_attn_processor(name, module, processor)
+
+ def forward(
+ self,
+ img: Tensor,
+ img_ids: Tensor,
+ txt: Tensor,
+ txt_ids: Tensor,
+ timesteps: Tensor,
+ y: Tensor,
+ guidance: Tensor | None = None,
+ ref_img: Tensor | None = None,
+ ref_img_ids: Tensor | None = None,
+ siglip_inputs: list[Tensor] | None = None,
+ ) -> Tensor:
+ if img.ndim != 3 or txt.ndim != 3:
+ raise ValueError("Input img and txt tensors must have 3 dimensions.")
+
+ # running on sequences img
+ img = self.img_in(img)
+ vec = self.time_in(timestep_embedding(timesteps, 256))
+ if self.params.guidance_embed:
+ if guidance is None:
+ raise ValueError(
+ "Didn't get guidance strength for guidance distilled model."
+ )
+ vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
+ vec = vec + self.vector_in(y)
+ txt = self.txt_in(txt)
+ if self.feature_embedder is not None and siglip_inputs is not None and len(siglip_inputs) > 0 and self.vision_encoder is not None:
+            # project style features into the textual hidden space
+ siglip_embedding = [self.vision_encoder(**emb, output_hidden_states=True) for emb in siglip_inputs]
+ # siglip_embedding = [self.vision_encoder(**(emb.to(torch.bfloat16)), output_hidden_states=True) for emb in siglip_inputs]
+ siglip_embedding = torch.cat([self.feature_embedder(emb) for emb in siglip_embedding], dim=1)
+ txt = torch.cat((siglip_embedding, txt), dim=1)
+ siglip_embedding_ids = torch.zeros(
+ siglip_embedding.shape[0], siglip_embedding.shape[1], 3
+ ).to(txt_ids.device)
+ txt_ids = torch.cat((siglip_embedding_ids, txt_ids), dim=1)
+
+ ids = torch.cat((txt_ids, img_ids), dim=1)
+
+ # concat ref_img/img
+ img_end = img.shape[1]
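+        # Reference latents are appended along the sequence dimension; only the first
+        # img_end tokens (the target image) are passed to the final layer below.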
+ if ref_img is not None:
+ if isinstance(ref_img, tuple) or isinstance(ref_img, list):
+ img_in = [img] + [self.img_in(ref) for ref in ref_img]
+ img_ids = [ids] + [ref_ids for ref_ids in ref_img_ids]
+ img = torch.cat(img_in, dim=1)
+ ids = torch.cat(img_ids, dim=1)
+ else:
+ img = torch.cat((img, self.img_in(ref_img)), dim=1)
+ ids = torch.cat((ids, ref_img_ids), dim=1)
+ pe = self.pe_embedder(ids)
+
+ for index_block, block in enumerate(self.double_blocks):
+ if self.training and self.gradient_checkpointing:
+ img, txt = torch.utils.checkpoint.checkpoint(
+ block,
+ img=img,
+ txt=txt,
+ vec=vec,
+ pe=pe,
+ use_reentrant=False,
+ )
+ else:
+ img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
+
+ img = torch.cat((txt, img), 1)
+ for block in self.single_blocks:
+ if self.training and self.gradient_checkpointing:
+ img = torch.utils.checkpoint.checkpoint(
+ block, img, vec=vec, pe=pe, use_reentrant=False
+ )
+ else:
+ img = block(img, vec=vec, pe=pe)
+ img = img[:, txt.shape[1] :, ...]
+ # index img
+ img = img[:, :img_end, ...]
+
+ img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
+ return img
diff --git a/uso/flux/modules/__pycache__/autoencoder.cpython-311.pyc b/uso/flux/modules/__pycache__/autoencoder.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d858acdb8edb03ae0e3ea8cc271feea557f51b1f
Binary files /dev/null and b/uso/flux/modules/__pycache__/autoencoder.cpython-311.pyc differ
diff --git a/uso/flux/modules/__pycache__/conditioner.cpython-311.pyc b/uso/flux/modules/__pycache__/conditioner.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a7334c90551d0028a5c6e0404b4f64715b12a357
Binary files /dev/null and b/uso/flux/modules/__pycache__/conditioner.cpython-311.pyc differ
diff --git a/uso/flux/modules/__pycache__/layers.cpython-311.pyc b/uso/flux/modules/__pycache__/layers.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..89cc4900e468c9e2e16a625e80cfa53e0d1c5f3d
Binary files /dev/null and b/uso/flux/modules/__pycache__/layers.cpython-311.pyc differ
diff --git a/uso/flux/modules/autoencoder.py b/uso/flux/modules/autoencoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..2543bdf4240e1db5b5dc958e148ac0cb12d9e9e3
--- /dev/null
+++ b/uso/flux/modules/autoencoder.py
@@ -0,0 +1,327 @@
+# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. All rights reserved.
+# Copyright (c) 2024 Black Forest Labs and The XLabs-AI Team. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass
+
+import torch
+from einops import rearrange
+from torch import Tensor, nn
+
+
+@dataclass
+class AutoEncoderParams:
+ resolution: int
+ in_channels: int
+ ch: int
+ out_ch: int
+ ch_mult: list[int]
+ num_res_blocks: int
+ z_channels: int
+ scale_factor: float
+ shift_factor: float
+
+
+def swish(x: Tensor) -> Tensor:
+ return x * torch.sigmoid(x)
+
+
+class AttnBlock(nn.Module):
+ def __init__(self, in_channels: int):
+ super().__init__()
+ self.in_channels = in_channels
+
+ self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
+
+ self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1)
+ self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1)
+ self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1)
+ self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1)
+
+ def attention(self, h_: Tensor) -> Tensor:
+ h_ = self.norm(h_)
+ q = self.q(h_)
+ k = self.k(h_)
+ v = self.v(h_)
+
+ b, c, h, w = q.shape
+ q = rearrange(q, "b c h w -> b 1 (h w) c").contiguous()
+ k = rearrange(k, "b c h w -> b 1 (h w) c").contiguous()
+ v = rearrange(v, "b c h w -> b 1 (h w) c").contiguous()
+ h_ = nn.functional.scaled_dot_product_attention(q, k, v)
+
+ return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b)
+
+ def forward(self, x: Tensor) -> Tensor:
+ return x + self.proj_out(self.attention(x))
+
+
+class ResnetBlock(nn.Module):
+ def __init__(self, in_channels: int, out_channels: int):
+ super().__init__()
+ self.in_channels = in_channels
+ out_channels = in_channels if out_channels is None else out_channels
+ self.out_channels = out_channels
+
+ self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
+ self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
+ self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True)
+ self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
+ if self.in_channels != self.out_channels:
+ self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
+
+ def forward(self, x):
+ h = x
+ h = self.norm1(h)
+ h = swish(h)
+ h = self.conv1(h)
+
+ h = self.norm2(h)
+ h = swish(h)
+ h = self.conv2(h)
+
+ if self.in_channels != self.out_channels:
+ x = self.nin_shortcut(x)
+
+ return x + h
+
+
+class Downsample(nn.Module):
+ def __init__(self, in_channels: int):
+ super().__init__()
+ # no asymmetric padding in torch conv, must do it ourselves
+ self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
+
+ def forward(self, x: Tensor):
+ pad = (0, 1, 0, 1)
+ x = nn.functional.pad(x, pad, mode="constant", value=0)
+ x = self.conv(x)
+ return x
+
+
+class Upsample(nn.Module):
+ def __init__(self, in_channels: int):
+ super().__init__()
+ self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
+
+ def forward(self, x: Tensor):
+ x = nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
+ x = self.conv(x)
+ return x
+
+
+class Encoder(nn.Module):
+ def __init__(
+ self,
+ resolution: int,
+ in_channels: int,
+ ch: int,
+ ch_mult: list[int],
+ num_res_blocks: int,
+ z_channels: int,
+ ):
+ super().__init__()
+ self.ch = ch
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ self.resolution = resolution
+ self.in_channels = in_channels
+ # downsampling
+ self.conv_in = nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)
+
+ curr_res = resolution
+ in_ch_mult = (1,) + tuple(ch_mult)
+ self.in_ch_mult = in_ch_mult
+ self.down = nn.ModuleList()
+ block_in = self.ch
+ for i_level in range(self.num_resolutions):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_in = ch * in_ch_mult[i_level]
+ block_out = ch * ch_mult[i_level]
+ for _ in range(self.num_res_blocks):
+ block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
+ block_in = block_out
+ down = nn.Module()
+ down.block = block
+ down.attn = attn
+ if i_level != self.num_resolutions - 1:
+ down.downsample = Downsample(block_in)
+ curr_res = curr_res // 2
+ self.down.append(down)
+
+ # middle
+ self.mid = nn.Module()
+ self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
+ self.mid.attn_1 = AttnBlock(block_in)
+ self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
+
+ # end
+ self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
+ self.conv_out = nn.Conv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1)
+
+ def forward(self, x: Tensor) -> Tensor:
+ # downsampling
+ hs = [self.conv_in(x)]
+ for i_level in range(self.num_resolutions):
+ for i_block in range(self.num_res_blocks):
+ h = self.down[i_level].block[i_block](hs[-1])
+ if len(self.down[i_level].attn) > 0:
+ h = self.down[i_level].attn[i_block](h)
+ hs.append(h)
+ if i_level != self.num_resolutions - 1:
+ hs.append(self.down[i_level].downsample(hs[-1]))
+
+ # middle
+ h = hs[-1]
+ h = self.mid.block_1(h)
+ h = self.mid.attn_1(h)
+ h = self.mid.block_2(h)
+ # end
+ h = self.norm_out(h)
+ h = swish(h)
+ h = self.conv_out(h)
+ return h
+
+
+class Decoder(nn.Module):
+ def __init__(
+ self,
+ ch: int,
+ out_ch: int,
+ ch_mult: list[int],
+ num_res_blocks: int,
+ in_channels: int,
+ resolution: int,
+ z_channels: int,
+ ):
+ super().__init__()
+ self.ch = ch
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ self.resolution = resolution
+ self.in_channels = in_channels
+ self.ffactor = 2 ** (self.num_resolutions - 1)
+
+ # compute in_ch_mult, block_in and curr_res at lowest res
+ block_in = ch * ch_mult[self.num_resolutions - 1]
+ curr_res = resolution // 2 ** (self.num_resolutions - 1)
+ self.z_shape = (1, z_channels, curr_res, curr_res)
+
+ # z to block_in
+ self.conv_in = nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)
+
+ # middle
+ self.mid = nn.Module()
+ self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
+ self.mid.attn_1 = AttnBlock(block_in)
+ self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
+
+ # upsampling
+ self.up = nn.ModuleList()
+ for i_level in reversed(range(self.num_resolutions)):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_out = ch * ch_mult[i_level]
+ for _ in range(self.num_res_blocks + 1):
+ block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
+ block_in = block_out
+ up = nn.Module()
+ up.block = block
+ up.attn = attn
+ if i_level != 0:
+ up.upsample = Upsample(block_in)
+ curr_res = curr_res * 2
+ self.up.insert(0, up) # prepend to get consistent order
+
+ # end
+ self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
+ self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)
+
+ def forward(self, z: Tensor) -> Tensor:
+ # z to block_in
+ h = self.conv_in(z)
+
+ # middle
+ h = self.mid.block_1(h)
+ h = self.mid.attn_1(h)
+ h = self.mid.block_2(h)
+
+ # upsampling
+ for i_level in reversed(range(self.num_resolutions)):
+ for i_block in range(self.num_res_blocks + 1):
+ h = self.up[i_level].block[i_block](h)
+ if len(self.up[i_level].attn) > 0:
+ h = self.up[i_level].attn[i_block](h)
+ if i_level != 0:
+ h = self.up[i_level].upsample(h)
+
+ # end
+ h = self.norm_out(h)
+ h = swish(h)
+ h = self.conv_out(h)
+ return h
+
+
+class DiagonalGaussian(nn.Module):
+ def __init__(self, sample: bool = True, chunk_dim: int = 1):
+ super().__init__()
+ self.sample = sample
+ self.chunk_dim = chunk_dim
+
+ def forward(self, z: Tensor) -> Tensor:
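+        # Split channels into mean / log-variance; optionally draw a reparameterized sample.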
+ mean, logvar = torch.chunk(z, 2, dim=self.chunk_dim)
+ if self.sample:
+ std = torch.exp(0.5 * logvar)
+ return mean + std * torch.randn_like(mean)
+ else:
+ return mean
+
+
+class AutoEncoder(nn.Module):
+ def __init__(self, params: AutoEncoderParams):
+ super().__init__()
+ self.encoder = Encoder(
+ resolution=params.resolution,
+ in_channels=params.in_channels,
+ ch=params.ch,
+ ch_mult=params.ch_mult,
+ num_res_blocks=params.num_res_blocks,
+ z_channels=params.z_channels,
+ )
+ self.decoder = Decoder(
+ resolution=params.resolution,
+ in_channels=params.in_channels,
+ ch=params.ch,
+ out_ch=params.out_ch,
+ ch_mult=params.ch_mult,
+ num_res_blocks=params.num_res_blocks,
+ z_channels=params.z_channels,
+ )
+ self.reg = DiagonalGaussian()
+
+ self.scale_factor = params.scale_factor
+ self.shift_factor = params.shift_factor
+
+ def encode(self, x: Tensor) -> Tensor:
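+        # Encode to a sampled latent and normalize it with the configured shift/scale factors.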
+ z = self.reg(self.encoder(x))
+ z = self.scale_factor * (z - self.shift_factor)
+ return z
+
+ def decode(self, z: Tensor) -> Tensor:
+ z = z / self.scale_factor + self.shift_factor
+ return self.decoder(z)
+
+ def forward(self, x: Tensor) -> Tensor:
+ return self.decode(self.encode(x))
diff --git a/uso/flux/modules/conditioner.py b/uso/flux/modules/conditioner.py
new file mode 100644
index 0000000000000000000000000000000000000000..047950827b18f3577c7c43247392c6e0d9295f1f
--- /dev/null
+++ b/uso/flux/modules/conditioner.py
@@ -0,0 +1,53 @@
+# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. All rights reserved.
+# Copyright (c) 2024 Black Forest Labs and The XLabs-AI Team. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from torch import Tensor, nn
+from transformers import (CLIPTextModel, CLIPTokenizer, T5EncoderModel,
+ T5Tokenizer)
+
+
+class HFEmbedder(nn.Module):
+ def __init__(self, version: str, max_length: int, **hf_kwargs):
+ super().__init__()
+ self.is_clip = "clip" in version.lower()
+ self.max_length = max_length
+ self.output_key = "pooler_output" if self.is_clip else "last_hidden_state"
+
+ if self.is_clip:
+ self.tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(version, max_length=max_length)
+ self.hf_module: CLIPTextModel = CLIPTextModel.from_pretrained(version, **hf_kwargs)
+ else:
+ self.tokenizer: T5Tokenizer = T5Tokenizer.from_pretrained(version, max_length=max_length)
+ self.hf_module: T5EncoderModel = T5EncoderModel.from_pretrained(version, **hf_kwargs)
+
+ self.hf_module = self.hf_module.eval().requires_grad_(False)
+
+ def forward(self, text: list[str]) -> Tensor:
+ batch_encoding = self.tokenizer(
+ text,
+ truncation=True,
+ max_length=self.max_length,
+ return_length=False,
+ return_overflowing_tokens=False,
+ padding="max_length",
+ return_tensors="pt",
+ )
+
+ outputs = self.hf_module(
+ input_ids=batch_encoding["input_ids"].to(self.hf_module.device),
+ attention_mask=None,
+ output_hidden_states=False,
+ )
+ return outputs[self.output_key]
diff --git a/uso/flux/modules/layers.py b/uso/flux/modules/layers.py
new file mode 100644
index 0000000000000000000000000000000000000000..3bb8139ab78a88863ae822bac111815caf70c7e8
--- /dev/null
+++ b/uso/flux/modules/layers.py
@@ -0,0 +1,631 @@
+# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. All rights reserved.
+# Copyright (c) 2024 Black Forest Labs and The XLabs-AI Team. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from dataclasses import dataclass
+
+import torch
+from einops import rearrange, repeat
+from torch import Tensor, nn
+
+from ..math import attention, rope
+
+
+class EmbedND(nn.Module):
+ def __init__(self, dim: int, theta: int, axes_dim: list[int]):
+ super().__init__()
+ self.dim = dim
+ self.theta = theta
+ self.axes_dim = axes_dim
+
+ def forward(self, ids: Tensor) -> Tensor:
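+        # Build rotary embeddings per id axis, concatenate them, and add a head dimension.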
+ n_axes = ids.shape[-1]
+ emb = torch.cat(
+ [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
+ dim=-3,
+ )
+
+ return emb.unsqueeze(1)
+
+
+def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0):
+ """
+ Create sinusoidal timestep embeddings.
+ :param t: a 1-D Tensor of N indices, one per batch element.
+ These may be fractional.
+ :param dim: the dimension of the output.
+ :param max_period: controls the minimum frequency of the embeddings.
+ :return: an (N, D) Tensor of positional embeddings.
+ """
+ t = time_factor * t
+ half = dim // 2
+ freqs = torch.exp(
+ -math.log(max_period)
+ * torch.arange(start=0, end=half, dtype=torch.float32)
+ / half
+ ).to(t.device)
+
+ args = t[:, None].float() * freqs[None]
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+ if dim % 2:
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
+ if torch.is_floating_point(t):
+ embedding = embedding.to(t)
+ return embedding
+
+
+class MLPEmbedder(nn.Module):
+ def __init__(self, in_dim: int, hidden_dim: int):
+ super().__init__()
+ self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True)
+ self.silu = nn.SiLU()
+ self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True)
+
+ def forward(self, x: Tensor) -> Tensor:
+ return self.out_layer(self.silu(self.in_layer(x)))
+
+
+class RMSNorm(torch.nn.Module):
+ def __init__(self, dim: int):
+ super().__init__()
+ self.scale = nn.Parameter(torch.ones(dim))
+
+ def forward(self, x: Tensor):
+ x_dtype = x.dtype
+ x = x.float()
+ rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6)
+ return ((x * rrms) * self.scale.float()).to(dtype=x_dtype)
+
+
+class QKNorm(torch.nn.Module):
+ def __init__(self, dim: int):
+ super().__init__()
+ self.query_norm = RMSNorm(dim)
+ self.key_norm = RMSNorm(dim)
+
+ def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]:
+ q = self.query_norm(q)
+ k = self.key_norm(k)
+ return q.to(v), k.to(v)
+
+
+class LoRALinearLayer(nn.Module):
+ def __init__(
+ self,
+ in_features,
+ out_features,
+ rank=4,
+ network_alpha=None,
+ device=None,
+ dtype=None,
+ ):
+ super().__init__()
+
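+        # Low-rank adapter: the down projection maps to `rank` dims and the up projection
+        # is zero-initialized, so the LoRA delta starts at zero.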
+ self.down = nn.Linear(in_features, rank, bias=False, device=device, dtype=dtype)
+ self.up = nn.Linear(rank, out_features, bias=False, device=device, dtype=dtype)
+ # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
+ # See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
+ self.network_alpha = network_alpha
+ self.rank = rank
+
+ nn.init.normal_(self.down.weight, std=1 / rank)
+ nn.init.zeros_(self.up.weight)
+
+ def forward(self, hidden_states):
+ orig_dtype = hidden_states.dtype
+ dtype = self.down.weight.dtype
+
+ down_hidden_states = self.down(hidden_states.to(dtype))
+ up_hidden_states = self.up(down_hidden_states)
+
+ if self.network_alpha is not None:
+ up_hidden_states *= self.network_alpha / self.rank
+
+ return up_hidden_states.to(orig_dtype)
+
+
+class FLuxSelfAttnProcessor:
+ def __call__(self, attn, x, pe, **attention_kwargs):
+ qkv = attn.qkv(x)
+        q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads)
+ q, k = attn.norm(q, k, v)
+ x = attention(q, k, v, pe=pe)
+ x = attn.proj(x)
+ return x
+
+
+class LoraFluxAttnProcessor(nn.Module):
+
+ def __init__(self, dim: int, rank=4, network_alpha=None, lora_weight=1):
+ super().__init__()
+ self.qkv_lora = LoRALinearLayer(dim, dim * 3, rank, network_alpha)
+ self.proj_lora = LoRALinearLayer(dim, dim, rank, network_alpha)
+ self.lora_weight = lora_weight
+
+ def __call__(self, attn, x, pe, **attention_kwargs):
+ qkv = attn.qkv(x) + self.qkv_lora(x) * self.lora_weight
+        q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads)
+ q, k = attn.norm(q, k, v)
+ x = attention(q, k, v, pe=pe)
+ x = attn.proj(x) + self.proj_lora(x) * self.lora_weight
+ return x
+
+
+class SelfAttention(nn.Module):
+ def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False):
+ super().__init__()
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.norm = QKNorm(head_dim)
+ self.proj = nn.Linear(dim, dim)
+
+    def forward(self, *args, **kwargs):
+        # Attention is dispatched through the block-level processors; this module only
+        # holds the qkv / norm / proj parameters.
+        pass
+
+
+@dataclass
+class ModulationOut:
+ shift: Tensor
+ scale: Tensor
+ gate: Tensor
+
+
+class Modulation(nn.Module):
+ def __init__(self, dim: int, double: bool):
+ super().__init__()
+ self.is_double = double
+ self.multiplier = 6 if double else 3
+ self.lin = nn.Linear(dim, self.multiplier * dim, bias=True)
+
+ def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut | None]:
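+        # Project the conditioning vector into 3 (single) or 6 (double) chunks,
+        # yielding a shift/scale/gate triple per modulated branch.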
+ out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(
+ self.multiplier, dim=-1
+ )
+
+ return (
+ ModulationOut(*out[:3]),
+ ModulationOut(*out[3:]) if self.is_double else None,
+ )
+
+
+class DoubleStreamBlockLoraProcessor(nn.Module):
+ def __init__(self, dim: int, rank=4, network_alpha=None, lora_weight=1):
+ super().__init__()
+ self.qkv_lora1 = LoRALinearLayer(dim, dim * 3, rank, network_alpha)
+ self.proj_lora1 = LoRALinearLayer(dim, dim, rank, network_alpha)
+ self.qkv_lora2 = LoRALinearLayer(dim, dim * 3, rank, network_alpha)
+ self.proj_lora2 = LoRALinearLayer(dim, dim, rank, network_alpha)
+ self.lora_weight = lora_weight
+
+ def forward(self, attn, img, txt, vec, pe, **attention_kwargs):
+ img_mod1, img_mod2 = attn.img_mod(vec)
+ txt_mod1, txt_mod2 = attn.txt_mod(vec)
+
+ # prepare image for attention
+ img_modulated = attn.img_norm1(img)
+ img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
+ img_qkv = (
+ attn.img_attn.qkv(img_modulated)
+ + self.qkv_lora1(img_modulated) * self.lora_weight
+ )
+ img_q, img_k, img_v = rearrange(
+ img_qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads
+ )
+ img_q, img_k = attn.img_attn.norm(img_q, img_k, img_v)
+
+ # prepare txt for attention
+ txt_modulated = attn.txt_norm1(txt)
+ txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
+ txt_qkv = (
+ attn.txt_attn.qkv(txt_modulated)
+ + self.qkv_lora2(txt_modulated) * self.lora_weight
+ )
+ txt_q, txt_k, txt_v = rearrange(
+ txt_qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads
+ )
+ txt_q, txt_k = attn.txt_attn.norm(txt_q, txt_k, txt_v)
+
+ # run actual attention
+ q = torch.cat((txt_q, img_q), dim=2)
+ k = torch.cat((txt_k, img_k), dim=2)
+ v = torch.cat((txt_v, img_v), dim=2)
+
+ attn1 = attention(q, k, v, pe=pe)
+ txt_attn, img_attn = attn1[:, : txt.shape[1]], attn1[:, txt.shape[1] :]
+
+        # calculate the img blocks
+ img = img + img_mod1.gate * (
+ attn.img_attn.proj(img_attn) + self.proj_lora1(img_attn) * self.lora_weight
+ )
+ img = img + img_mod2.gate * attn.img_mlp(
+ (1 + img_mod2.scale) * attn.img_norm2(img) + img_mod2.shift
+ )
+
+        # calculate the txt blocks
+ txt = txt + txt_mod1.gate * (
+ attn.txt_attn.proj(txt_attn) + self.proj_lora2(txt_attn) * self.lora_weight
+ )
+ txt = txt + txt_mod2.gate * attn.txt_mlp(
+ (1 + txt_mod2.scale) * attn.txt_norm2(txt) + txt_mod2.shift
+ )
+ return img, txt
+
+
+class DoubleStreamBlockProcessor:
+ def __call__(self, attn, img, txt, vec, pe, **attention_kwargs):
+ img_mod1, img_mod2 = attn.img_mod(vec)
+ txt_mod1, txt_mod2 = attn.txt_mod(vec)
+
+ # prepare image for attention
+ img_modulated = attn.img_norm1(img)
+ img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
+ img_qkv = attn.img_attn.qkv(img_modulated)
+ img_q, img_k, img_v = rearrange(
+ img_qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads, D=attn.head_dim
+ )
+ img_q, img_k = attn.img_attn.norm(img_q, img_k, img_v)
+
+ # prepare txt for attention
+ txt_modulated = attn.txt_norm1(txt)
+ txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
+ txt_qkv = attn.txt_attn.qkv(txt_modulated)
+ txt_q, txt_k, txt_v = rearrange(
+ txt_qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads, D=attn.head_dim
+ )
+ txt_q, txt_k = attn.txt_attn.norm(txt_q, txt_k, txt_v)
+
+ # run actual attention
+ q = torch.cat((txt_q, img_q), dim=2)
+ k = torch.cat((txt_k, img_k), dim=2)
+ v = torch.cat((txt_v, img_v), dim=2)
+
+ attn1 = attention(q, k, v, pe=pe)
+ txt_attn, img_attn = attn1[:, : txt.shape[1]], attn1[:, txt.shape[1] :]
+
+        # calculate the img blocks
+ img = img + img_mod1.gate * attn.img_attn.proj(img_attn)
+ img = img + img_mod2.gate * attn.img_mlp(
+ (1 + img_mod2.scale) * attn.img_norm2(img) + img_mod2.shift
+ )
+
+        # calculate the txt blocks
+ txt = txt + txt_mod1.gate * attn.txt_attn.proj(txt_attn)
+ txt = txt + txt_mod2.gate * attn.txt_mlp(
+ (1 + txt_mod2.scale) * attn.txt_norm2(txt) + txt_mod2.shift
+ )
+ return img, txt
+
+
+class DoubleStreamBlock(nn.Module):
+ def __init__(
+ self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False
+ ):
+ super().__init__()
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
+ self.num_heads = num_heads
+ self.hidden_size = hidden_size
+ self.head_dim = hidden_size // num_heads
+
+ self.img_mod = Modulation(hidden_size, double=True)
+ self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+ self.img_attn = SelfAttention(
+ dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias
+ )
+
+ self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+ self.img_mlp = nn.Sequential(
+ nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
+ nn.GELU(approximate="tanh"),
+ nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
+ )
+
+ self.txt_mod = Modulation(hidden_size, double=True)
+ self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+ self.txt_attn = SelfAttention(
+ dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias
+ )
+
+ self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+ self.txt_mlp = nn.Sequential(
+ nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
+ nn.GELU(approximate="tanh"),
+ nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
+ )
+ processor = DoubleStreamBlockProcessor()
+ self.set_processor(processor)
+
+ def set_processor(self, processor) -> None:
+ self.processor = processor
+
+ def get_processor(self):
+ return self.processor
+
+ def forward(
+ self,
+ img: Tensor,
+ txt: Tensor,
+ vec: Tensor,
+ pe: Tensor,
+        image_proj: Tensor | None = None,
+ ip_scale: float = 1.0,
+ ) -> tuple[Tensor, Tensor]:
+ if image_proj is None:
+ return self.processor(self, img, txt, vec, pe)
+ else:
+ return self.processor(self, img, txt, vec, pe, image_proj, ip_scale)
+
+
+class SingleStreamBlockLoraProcessor(nn.Module):
+ def __init__(
+ self, dim: int, rank: int = 4, network_alpha=None, lora_weight: float = 1
+ ):
+ super().__init__()
+ self.qkv_lora = LoRALinearLayer(dim, dim * 3, rank, network_alpha)
+        # 15360 = hidden_size + mlp_hidden_dim in FLUX single blocks (3072 + 4 * 3072),
+        # i.e. the width of the concatenated [attention, mlp] features fed to linear2.
+        self.proj_lora = LoRALinearLayer(15360, dim, rank, network_alpha)
+ self.lora_weight = lora_weight
+
+ def forward(self, attn: nn.Module, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor:
+
+ mod, _ = attn.modulation(vec)
+ x_mod = (1 + mod.scale) * attn.pre_norm(x) + mod.shift
+ qkv, mlp = torch.split(
+ attn.linear1(x_mod), [3 * attn.hidden_size, attn.mlp_hidden_dim], dim=-1
+ )
+ qkv = qkv + self.qkv_lora(x_mod) * self.lora_weight
+
+ q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads)
+ q, k = attn.norm(q, k, v)
+
+ # compute attention
+ attn_1 = attention(q, k, v, pe=pe)
+
+ # compute activation in mlp stream, cat again and run second linear layer
+ output = attn.linear2(torch.cat((attn_1, attn.mlp_act(mlp)), 2))
+ output = (
+ output
+ + self.proj_lora(torch.cat((attn_1, attn.mlp_act(mlp)), 2))
+ * self.lora_weight
+ )
+ output = x + mod.gate * output
+ return output
+
+
+class SingleStreamBlockProcessor:
+ def __call__(
+ self, attn: nn.Module, x: Tensor, vec: Tensor, pe: Tensor, **attention_kwargs
+ ) -> Tensor:
+
+ mod, _ = attn.modulation(vec)
+ x_mod = (1 + mod.scale) * attn.pre_norm(x) + mod.shift
+ qkv, mlp = torch.split(
+ attn.linear1(x_mod), [3 * attn.hidden_size, attn.mlp_hidden_dim], dim=-1
+ )
+
+ q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads)
+ q, k = attn.norm(q, k, v)
+
+ # compute attention
+ attn_1 = attention(q, k, v, pe=pe)
+
+ # compute activation in mlp stream, cat again and run second linear layer
+ output = attn.linear2(torch.cat((attn_1, attn.mlp_act(mlp)), 2))
+ output = x + mod.gate * output
+ return output
+
+
+class SingleStreamBlock(nn.Module):
+ """
+ A DiT block with parallel linear layers as described in
+ https://arxiv.org/abs/2302.05442 and adapted modulation interface.
+ """
+
+ def __init__(
+ self,
+ hidden_size: int,
+ num_heads: int,
+ mlp_ratio: float = 4.0,
+ qk_scale: float | None = None,
+ ):
+ super().__init__()
+ self.hidden_dim = hidden_size
+ self.num_heads = num_heads
+ self.head_dim = hidden_size // num_heads
+ self.scale = qk_scale or self.head_dim**-0.5
+
+ self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
+ # qkv and mlp_in
+ self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim)
+ # proj and mlp_out
+ self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size)
+
+ self.norm = QKNorm(self.head_dim)
+
+ self.hidden_size = hidden_size
+ self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+
+ self.mlp_act = nn.GELU(approximate="tanh")
+ self.modulation = Modulation(hidden_size, double=False)
+
+ processor = SingleStreamBlockProcessor()
+ self.set_processor(processor)
+
+ def set_processor(self, processor) -> None:
+ self.processor = processor
+
+ def get_processor(self):
+ return self.processor
+
+ def forward(
+ self,
+ x: Tensor,
+ vec: Tensor,
+ pe: Tensor,
+ image_proj: Tensor | None = None,
+ ip_scale: float = 1.0,
+ ) -> Tensor:
+ if image_proj is None:
+ return self.processor(self, x, vec, pe)
+ else:
+ return self.processor(self, x, vec, pe, image_proj, ip_scale)
+
+
+class LastLayer(nn.Module):
+ def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
+ super().__init__()
+ self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+ self.linear = nn.Linear(
+ hidden_size, patch_size * patch_size * out_channels, bias=True
+ )
+ self.adaLN_modulation = nn.Sequential(
+ nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True)
+ )
+
+ def forward(self, x: Tensor, vec: Tensor) -> Tensor:
+ shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1)
+ x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
+ x = self.linear(x)
+ return x
+
+
+class SigLIPMultiFeatProjModel(torch.nn.Module):
+ """
+ SigLIP Multi-Feature Projection Model for processing style features from different layers
+ and projecting them into a unified hidden space.
+
+ Args:
+ siglip_token_nums (int): Number of SigLIP tokens, default 257
+ style_token_nums (int): Number of style tokens, default 256
+ siglip_token_dims (int): Dimension of SigLIP tokens, default 1536
+ hidden_size (int): Hidden layer size, default 3072
+ context_layer_norm (bool): Whether to use context layer normalization, default False
+ """
+
+ def __init__(
+ self,
+ siglip_token_nums: int = 257,
+ style_token_nums: int = 256,
+ siglip_token_dims: int = 1536,
+ hidden_size: int = 3072,
+ context_layer_norm: bool = False,
+ ):
+ super().__init__()
+
+ # High-level feature processing (layer -2)
+ self.high_embedding_linear = nn.Sequential(
+ nn.Linear(siglip_token_nums, style_token_nums),
+ nn.SiLU()
+ )
+ self.high_layer_norm = (
+ nn.LayerNorm(siglip_token_dims) if context_layer_norm else nn.Identity()
+ )
+ self.high_projection = nn.Linear(siglip_token_dims, hidden_size, bias=True)
+
+ # Mid-level feature processing (layer -11)
+ self.mid_embedding_linear = nn.Sequential(
+ nn.Linear(siglip_token_nums, style_token_nums),
+ nn.SiLU()
+ )
+ self.mid_layer_norm = (
+ nn.LayerNorm(siglip_token_dims) if context_layer_norm else nn.Identity()
+ )
+ self.mid_projection = nn.Linear(siglip_token_dims, hidden_size, bias=True)
+
+ # Low-level feature processing (layer -20)
+ self.low_embedding_linear = nn.Sequential(
+ nn.Linear(siglip_token_nums, style_token_nums),
+ nn.SiLU()
+ )
+ self.low_layer_norm = (
+ nn.LayerNorm(siglip_token_dims) if context_layer_norm else nn.Identity()
+ )
+ self.low_projection = nn.Linear(siglip_token_dims, hidden_size, bias=True)
+
+ def forward(self, siglip_outputs):
+ """
+ Forward pass function
+
+ Args:
+ siglip_outputs: Output from SigLIP model, containing hidden_states
+
+ Returns:
+ torch.Tensor: Concatenated multi-layer features with shape [bs, 3*style_token_nums, hidden_size]
+ """
+ dtype = next(self.high_embedding_linear.parameters()).dtype
+
+ # Process high-level features (layer -2)
+ high_embedding = self._process_layer_features(
+ siglip_outputs.hidden_states[-2],
+ self.high_embedding_linear,
+ self.high_layer_norm,
+ self.high_projection,
+ dtype
+ )
+
+ # Process mid-level features (layer -11)
+ mid_embedding = self._process_layer_features(
+ siglip_outputs.hidden_states[-11],
+ self.mid_embedding_linear,
+ self.mid_layer_norm,
+ self.mid_projection,
+ dtype
+ )
+
+ # Process low-level features (layer -20)
+ low_embedding = self._process_layer_features(
+ siglip_outputs.hidden_states[-20],
+ self.low_embedding_linear,
+ self.low_layer_norm,
+ self.low_projection,
+ dtype
+ )
+
+ # Concatenate features from all layers
+ return torch.cat((high_embedding, mid_embedding, low_embedding), dim=1)
+
+ def _process_layer_features(
+ self,
+ hidden_states: torch.Tensor,
+ embedding_linear: nn.Module,
+ layer_norm: nn.Module,
+ projection: nn.Module,
+ dtype: torch.dtype
+ ) -> torch.Tensor:
+ """
+ Helper function to process features from a single layer
+
+ Args:
+ hidden_states: Input hidden states [bs, seq_len, dim]
+ embedding_linear: Embedding linear layer
+ layer_norm: Layer normalization
+ projection: Projection layer
+ dtype: Target data type
+
+ Returns:
+ torch.Tensor: Processed features [bs, style_token_nums, hidden_size]
+ """
+ # Transform dimensions: [bs, seq_len, dim] -> [bs, dim, seq_len] -> [bs, dim, style_token_nums] -> [bs, style_token_nums, dim]
+ embedding = embedding_linear(
+ hidden_states.to(dtype).transpose(1, 2)
+ ).transpose(1, 2)
+
+ # Apply layer normalization
+ embedding = layer_norm(embedding)
+
+ # Project to target hidden space
+ embedding = projection(embedding)
+
+ return embedding
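+
+
+# Shape sketch (illustrative, assuming the default constructor arguments above):
+#   siglip_outputs.hidden_states[-2] / [-11] / [-20]: [bs, 257, 1536]
+#   after embedding_linear (applied across the token axis): [bs, 256, 1536]
+#   after projection: [bs, 256, 3072]
+#   forward() returns the three levels concatenated: [bs, 768, 3072]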
diff --git a/uso/flux/pipeline.py b/uso/flux/pipeline.py
new file mode 100644
index 0000000000000000000000000000000000000000..15be80cb624cbc9403e520d1c3fed798d995602c
--- /dev/null
+++ b/uso/flux/pipeline.py
@@ -0,0 +1,390 @@
+# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. All rights reserved.
+# Copyright (c) 2024 Black Forest Labs and The XLabs-AI Team. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import math
+from typing import Literal
+from torch import Tensor
+
+import torch
+from einops import rearrange
+from PIL import ExifTags, Image
+import torchvision.transforms.functional as TVF
+
+from uso.flux.modules.layers import (
+ DoubleStreamBlockLoraProcessor,
+ DoubleStreamBlockProcessor,
+ SingleStreamBlockLoraProcessor,
+ SingleStreamBlockProcessor,
+)
+from uso.flux.sampling import denoise, get_noise, get_schedule, prepare_multi_ip, unpack
+from uso.flux.util import (
+ get_lora_rank,
+ load_ae,
+ load_checkpoint,
+ load_clip,
+ load_flow_model,
+ load_flow_model_only_lora,
+ load_t5,
+)
+
+
+def find_nearest_scale(image_h, image_w, predefined_scales):
+    """
+    Find the predefined scale closest to the given image height and width.
+
+    :param image_h: image height
+    :param image_w: image width
+    :param predefined_scales: list of predefined scales [(h1, w1), (h2, w2), ...]
+    :return: the closest predefined scale (h, w)
+    """
+    # compute the aspect ratio of the input image
+    image_ratio = image_h / image_w
+
+    # track the smallest difference and the closest scale seen so far
+    min_diff = float("inf")
+    nearest_scale = None
+
+    # iterate over all predefined scales and pick the one with the closest aspect ratio
+ for scale_h, scale_w in predefined_scales:
+ predefined_ratio = scale_h / scale_w
+ diff = abs(predefined_ratio - image_ratio)
+
+ if diff < min_diff:
+ min_diff = diff
+ nearest_scale = (scale_h, scale_w)
+
+ return nearest_scale
+
+
+def preprocess_ref(raw_image: Image.Image, long_size: int = 512, scale_ratio: int = 1):
+    # get the width and height of the original image
+    image_w, image_h = raw_image.size
+    if image_w == image_h and image_w == 16:
+        return raw_image
+
+    # compute the new size so that the long side equals long_size
+    if image_w >= image_h:
+        new_w = long_size
+        new_h = int((long_size / image_w) * image_h)
+    else:
+        new_h = long_size
+        new_w = int((long_size / image_h) * image_w)
+
+    # resize proportionally to the new width and height
+    raw_image = raw_image.resize((new_w, new_h), resample=Image.LANCZOS)
+
+    # round the target size down to a multiple of 16 * scale_ratio so a matching canny image can be scaled
+    scale_ratio = int(scale_ratio)
+    target_w = new_w // (16 * scale_ratio) * (16 * scale_ratio)
+    target_h = new_h // (16 * scale_ratio) * (16 * scale_ratio)
+
+    # compute the crop box for a center crop
+    left = (new_w - target_w) // 2
+    top = (new_h - target_h) // 2
+    right = left + target_w
+    bottom = top + target_h
+
+    # center-crop
+    raw_image = raw_image.crop((left, top, right, bottom))
+
+    # convert to RGB
+ raw_image = raw_image.convert("RGB")
+ return raw_image
+
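+# Example (illustrative): a 640x427 input with long_size=512 and scale_ratio=1 is resized to
+# 512x341, center-cropped to 512x336 (the nearest multiples of 16), and converted to RGB.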
+
+def resize_and_centercrop_image(image, target_height_ref1, target_width_ref1):
+ target_height_ref1 = int(target_height_ref1 // 64 * 64)
+ target_width_ref1 = int(target_width_ref1 // 64 * 64)
+ h, w = image.shape[-2:]
+ if h < target_height_ref1 or w < target_width_ref1:
+        # compute the aspect ratio
+ aspect_ratio = w / h
+ if h < target_height_ref1:
+ new_h = target_height_ref1
+ new_w = new_h * aspect_ratio
+ if new_w < target_width_ref1:
+ new_w = target_width_ref1
+ new_h = new_w / aspect_ratio
+ else:
+ new_w = target_width_ref1
+ new_h = new_w / aspect_ratio
+ if new_h < target_height_ref1:
+ new_h = target_height_ref1
+ new_w = new_h * aspect_ratio
+ else:
+ aspect_ratio = w / h
+ tgt_aspect_ratio = target_width_ref1 / target_height_ref1
+ if aspect_ratio > tgt_aspect_ratio:
+ new_h = target_height_ref1
+ new_w = new_h * aspect_ratio
+ else:
+ new_w = target_width_ref1
+ new_h = new_w / aspect_ratio
+    # resize the image with TVF.resize
+    image = TVF.resize(image, (math.ceil(new_h), math.ceil(new_w)))
+    # compute the center-crop offsets
+    top = (image.shape[-2] - target_height_ref1) // 2
+    left = (image.shape[-1] - target_width_ref1) // 2
+    # center-crop with TVF.crop
+ image = TVF.crop(image, top, left, target_height_ref1, target_width_ref1)
+ return image
+
+
+class USOPipeline:
+ def __init__(
+ self,
+ model_type: str,
+ device: torch.device,
+ offload: bool = False,
+ only_lora: bool = False,
+ lora_rank: int = 16,
+ hf_download: bool = True,
+ ):
+ self.device = device
+ self.offload = offload
+ self.model_type = model_type
+
+ self.clip = load_clip(self.device)
+ self.t5 = load_t5(self.device, max_length=512)
+ self.ae = load_ae(model_type, device="cpu" if offload else self.device)
+ self.use_fp8 = "fp8" in model_type
+ if only_lora:
+ self.model = load_flow_model_only_lora(
+ model_type,
+ device="cpu" if offload else self.device,
+ lora_rank=lora_rank,
+ use_fp8=self.use_fp8,
+ hf_download=hf_download,
+ )
+ else:
+ self.model = load_flow_model(
+ model_type, device="cpu" if offload else self.device
+ )
+
+ def load_ckpt(self, ckpt_path):
+ if ckpt_path is not None:
+ from safetensors.torch import load_file as load_sft
+
+ print("Loading checkpoint to replace old keys")
+ # load_sft doesn't support torch.device
+ if ckpt_path.endswith("safetensors"):
+ sd = load_sft(ckpt_path, device="cpu")
+ missing, unexpected = self.model.load_state_dict(
+ sd, strict=False, assign=True
+ )
+ else:
+ dit_state = torch.load(ckpt_path, map_location="cpu")
+ sd = {}
+ for k in dit_state.keys():
+ sd[k.replace("module.", "")] = dit_state[k]
+ missing, unexpected = self.model.load_state_dict(
+ sd, strict=False, assign=True
+ )
+ self.model.to(str(self.device))
+            print(f"missing keys: {missing}\nunexpected keys: {unexpected}")
+
+ def set_lora(
+ self,
+        local_path: str | None = None,
+        repo_id: str | None = None,
+        name: str | None = None,
+        lora_weight: float = 0.7,
+ ):
+ checkpoint = load_checkpoint(local_path, repo_id, name)
+ self.update_model_with_lora(checkpoint, lora_weight)
+
+ def set_lora_from_collection(
+        self, lora_type: str = "realism", lora_weight: float = 0.7
+ ):
+ checkpoint = load_checkpoint(
+ None, self.hf_lora_collection, self.lora_types_to_names[lora_type]
+ )
+ self.update_model_with_lora(checkpoint, lora_weight)
+
+ def update_model_with_lora(self, checkpoint, lora_weight):
+ rank = get_lora_rank(checkpoint)
+ lora_attn_procs = {}
+
+ for name, _ in self.model.attn_processors.items():
+ lora_state_dict = {}
+ for k in checkpoint.keys():
+ if name in k:
+ lora_state_dict[k[len(name) + 1 :]] = checkpoint[k] * lora_weight
+
+ if len(lora_state_dict):
+ if name.startswith("single_blocks"):
+ lora_attn_procs[name] = SingleStreamBlockLoraProcessor(
+ dim=3072, rank=rank
+ )
+ else:
+ lora_attn_procs[name] = DoubleStreamBlockLoraProcessor(
+ dim=3072, rank=rank
+ )
+ lora_attn_procs[name].load_state_dict(lora_state_dict)
+ lora_attn_procs[name].to(self.device)
+ else:
+ if name.startswith("single_blocks"):
+ lora_attn_procs[name] = SingleStreamBlockProcessor()
+ else:
+ lora_attn_procs[name] = DoubleStreamBlockProcessor()
+
+ self.model.set_attn_processor(lora_attn_procs)
+
+ def __call__(
+ self,
+ prompt: str,
+ width: int = 512,
+ height: int = 512,
+ guidance: float = 4,
+ num_steps: int = 50,
+ seed: int = 123456789,
+ **kwargs,
+ ):
+ width = 16 * (width // 16)
+ height = 16 * (height // 16)
+
+ device_type = self.device if isinstance(self.device, str) else self.device.type
+ with torch.autocast(
+ enabled=self.use_fp8, device_type=device_type, dtype=torch.bfloat16
+ ):
+ return self.forward(
+ prompt, width, height, guidance, num_steps, seed, **kwargs
+ )
+
+ @torch.inference_mode()
+ def gradio_generate(
+ self,
+ prompt: str,
+ image_prompt1: Image.Image,
+ image_prompt2: Image.Image,
+ image_prompt3: Image.Image,
+ seed: int,
+ width: int = 1024,
+ height: int = 1024,
+ guidance: float = 4,
+ num_steps: int = 25,
+ keep_size: bool = False,
+ content_long_size: int = 512,
+ ):
+ ref_content_imgs = [image_prompt1]
+ ref_content_imgs = [img for img in ref_content_imgs if isinstance(img, Image.Image)]
+ ref_content_imgs = [preprocess_ref(img, content_long_size) for img in ref_content_imgs]
+
+ ref_style_imgs = [image_prompt2, image_prompt3]
+ ref_style_imgs = [img for img in ref_style_imgs if isinstance(img, Image.Image)]
+ ref_style_imgs = [self.model.vision_encoder_processor(img, return_tensors="pt").to(self.device) for img in ref_style_imgs]
+
+ seed = seed if seed != -1 else torch.randint(0, 10**8, (1,)).item()
+
+        # optionally keep the input content image's size (rescaled to the 1024 training resolution)
+        if keep_size and len(ref_content_imgs) > 0:
+ width, height = ref_content_imgs[0].size
+ width, height = int(width * (1024 / content_long_size)), int(height * (1024 / content_long_size))
+ img = self(
+ prompt=prompt,
+ width=width,
+ height=height,
+ guidance=guidance,
+ num_steps=num_steps,
+ seed=seed,
+ ref_imgs=ref_content_imgs,
+ siglip_inputs=ref_style_imgs,
+ )
+
+ filename = f"output/gradio/{seed}_{prompt[:20]}.png"
+ os.makedirs(os.path.dirname(filename), exist_ok=True)
+ exif_data = Image.Exif()
+ exif_data[ExifTags.Base.Make] = "USO"
+ exif_data[ExifTags.Base.Model] = self.model_type
+ info = f"{prompt=}, {seed=}, {width=}, {height=}, {guidance=}, {num_steps=}"
+ exif_data[ExifTags.Base.ImageDescription] = info
+ img.save(filename, format="png", exif=exif_data)
+ return img, filename
+
+    @torch.inference_mode()
+ def forward(
+ self,
+ prompt: str,
+ width: int,
+ height: int,
+ guidance: float,
+ num_steps: int,
+ seed: int,
+ ref_imgs: list[Image.Image] | None = None,
+ pe: Literal["d", "h", "w", "o"] = "d",
+ siglip_inputs: list[Tensor] | None = None,
+ ):
+ x = get_noise(
+ 1, height, width, device=self.device, dtype=torch.bfloat16, seed=seed
+ )
+ timesteps = get_schedule(
+ num_steps,
+ (width // 8) * (height // 8) // (16 * 16),
+ shift=True,
+ )
+ if self.offload:
+ self.ae.encoder = self.ae.encoder.to(self.device)
+ x_1_refs = [
+ self.ae.encode(
+ (TVF.to_tensor(ref_img) * 2.0 - 1.0)
+ .unsqueeze(0)
+ .to(self.device, torch.float32)
+ ).to(torch.bfloat16)
+ for ref_img in ref_imgs
+ ]
+
+ if self.offload:
+ self.offload_model_to_cpu(self.ae.encoder)
+ self.t5, self.clip = self.t5.to(self.device), self.clip.to(self.device)
+ inp_cond = prepare_multi_ip(
+ t5=self.t5,
+ clip=self.clip,
+ img=x,
+ prompt=prompt,
+ ref_imgs=x_1_refs,
+ pe=pe,
+ )
+
+ if self.offload:
+ self.offload_model_to_cpu(self.t5, self.clip)
+ self.model = self.model.to(self.device)
+
+ x = denoise(
+ self.model,
+ **inp_cond,
+ timesteps=timesteps,
+ guidance=guidance,
+ siglip_inputs=siglip_inputs,
+ )
+
+ if self.offload:
+ self.offload_model_to_cpu(self.model)
+ self.ae.decoder.to(x.device)
+ x = unpack(x.float(), height, width)
+ x = self.ae.decode(x)
+ self.offload_model_to_cpu(self.ae.decoder)
+
+ x1 = x.clamp(-1, 1)
+ x1 = rearrange(x1[-1], "c h w -> h w c")
+ output_img = Image.fromarray((127.5 * (x1 + 1.0)).cpu().byte().numpy())
+ return output_img
+
+ def offload_model_to_cpu(self, *models):
+ if not self.offload:
+ return
+ for model in models:
+ model.cpu()
+ torch.cuda.empty_cache()
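+
+
+# Usage sketch (illustrative, not part of the module; assumes a CUDA device and access to the
+# FLUX.1-dev weights plus the USO LoRA/projector checkpoints):
+#   pipe = USOPipeline("flux-dev", device=torch.device("cuda"), offload=True, only_lora=True)
+#   image = pipe("a cat in watercolor style", width=1024, height=1024, num_steps=25,
+#                seed=42, ref_imgs=[], siglip_inputs=[])
+#   image.save("cat_watercolor.png")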
diff --git a/uso/flux/sampling.py b/uso/flux/sampling.py
new file mode 100644
index 0000000000000000000000000000000000000000..a4611fa8bef1c446d760a05a61f6d086dc5a8ad0
--- /dev/null
+++ b/uso/flux/sampling.py
@@ -0,0 +1,274 @@
+# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. All rights reserved.
+# Copyright (c) 2024 Black Forest Labs and The XLabs-AI Team. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from typing import Literal
+
+import torch
+from einops import rearrange, repeat
+from torch import Tensor
+from tqdm import tqdm
+
+from .model import Flux
+from .modules.conditioner import HFEmbedder
+
+
+def get_noise(
+ num_samples: int,
+ height: int,
+ width: int,
+ device: torch.device,
+ dtype: torch.dtype,
+ seed: int,
+):
+ return torch.randn(
+ num_samples,
+ 16,
+ # allow for packing
+ 2 * math.ceil(height / 16),
+ 2 * math.ceil(width / 16),
+ device=device,
+ dtype=dtype,
+ generator=torch.Generator(device=device).manual_seed(seed),
+ )
+
+
+def prepare(
+ t5: HFEmbedder,
+ clip: HFEmbedder,
+ img: Tensor,
+ prompt: str | list[str],
+ ref_img: None | Tensor = None,
+ pe: Literal["d", "h", "w", "o"] = "d",
+) -> dict[str, Tensor]:
+ assert pe in ["d", "h", "w", "o"]
+ bs, c, h, w = img.shape
+ if bs == 1 and not isinstance(prompt, str):
+ bs = len(prompt)
+
+ img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
+ if img.shape[0] == 1 and bs > 1:
+ img = repeat(img, "1 ... -> bs ...", bs=bs)
+
+ img_ids = torch.zeros(h // 2, w // 2, 3)
+ img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
+ img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
+ img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
+
+ if ref_img is not None:
+ _, _, ref_h, ref_w = ref_img.shape
+ ref_img = rearrange(
+ ref_img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2
+ )
+ if ref_img.shape[0] == 1 and bs > 1:
+ ref_img = repeat(ref_img, "1 ... -> bs ...", bs=bs)
+ ref_img_ids = torch.zeros(ref_h // 2, ref_w // 2, 3)
+        # offset the reference image ids in h/w by the target image's maxima (per the pe mode)
+ h_offset = h // 2 if pe in {"d", "h"} else 0
+ w_offset = w // 2 if pe in {"d", "w"} else 0
+ ref_img_ids[..., 1] = (
+ ref_img_ids[..., 1] + torch.arange(ref_h // 2)[:, None] + h_offset
+ )
+ ref_img_ids[..., 2] = (
+ ref_img_ids[..., 2] + torch.arange(ref_w // 2)[None, :] + w_offset
+ )
+ ref_img_ids = repeat(ref_img_ids, "h w c -> b (h w) c", b=bs)
+
+ if isinstance(prompt, str):
+ prompt = [prompt]
+ txt = t5(prompt)
+ if txt.shape[0] == 1 and bs > 1:
+ txt = repeat(txt, "1 ... -> bs ...", bs=bs)
+ txt_ids = torch.zeros(bs, txt.shape[1], 3)
+
+ vec = clip(prompt)
+ if vec.shape[0] == 1 and bs > 1:
+ vec = repeat(vec, "1 ... -> bs ...", bs=bs)
+
+ if ref_img is not None:
+ return {
+ "img": img,
+ "img_ids": img_ids.to(img.device),
+ "ref_img": ref_img,
+ "ref_img_ids": ref_img_ids.to(img.device),
+ "txt": txt.to(img.device),
+ "txt_ids": txt_ids.to(img.device),
+ "vec": vec.to(img.device),
+ }
+ else:
+ return {
+ "img": img,
+ "img_ids": img_ids.to(img.device),
+ "txt": txt.to(img.device),
+ "txt_ids": txt_ids.to(img.device),
+ "vec": vec.to(img.device),
+ }
+
+
+def prepare_multi_ip(
+ t5: HFEmbedder,
+ clip: HFEmbedder,
+ img: Tensor,
+ prompt: str | list[str],
+ ref_imgs: list[Tensor] | None = None,
+ pe: Literal["d", "h", "w", "o"] = "d",
+) -> dict[str, Tensor]:
+ assert pe in ["d", "h", "w", "o"]
+ bs, c, h, w = img.shape
+ if bs == 1 and not isinstance(prompt, str):
+ bs = len(prompt)
+
+ # tgt img
+ img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
+ if img.shape[0] == 1 and bs > 1:
+ img = repeat(img, "1 ... -> bs ...", bs=bs)
+
+ img_ids = torch.zeros(h // 2, w // 2, 3)
+ img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
+ img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
+ img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
+
+ ref_img_ids = []
+ ref_imgs_list = []
+
+ pe_shift_w, pe_shift_h = w // 2, h // 2
+ for ref_img in ref_imgs:
+ _, _, ref_h1, ref_w1 = ref_img.shape
+ ref_img = rearrange(
+ ref_img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2
+ )
+ if ref_img.shape[0] == 1 and bs > 1:
+ ref_img = repeat(ref_img, "1 ... -> bs ...", bs=bs)
+ ref_img_ids1 = torch.zeros(ref_h1 // 2, ref_w1 // 2, 3)
+        # offset the reference image ids in h/w by the running pe shift
+ h_offset = pe_shift_h if pe in {"d", "h"} else 0
+ w_offset = pe_shift_w if pe in {"d", "w"} else 0
+ ref_img_ids1[..., 1] = (
+ ref_img_ids1[..., 1] + torch.arange(ref_h1 // 2)[:, None] + h_offset
+ )
+ ref_img_ids1[..., 2] = (
+ ref_img_ids1[..., 2] + torch.arange(ref_w1 // 2)[None, :] + w_offset
+ )
+ ref_img_ids1 = repeat(ref_img_ids1, "h w c -> b (h w) c", b=bs)
+ ref_img_ids.append(ref_img_ids1)
+ ref_imgs_list.append(ref_img)
+
+        # update the pe shifts for the next reference image
+ pe_shift_h += ref_h1 // 2
+ pe_shift_w += ref_w1 // 2
+
+ if isinstance(prompt, str):
+ prompt = [prompt]
+ txt = t5(prompt)
+ if txt.shape[0] == 1 and bs > 1:
+ txt = repeat(txt, "1 ... -> bs ...", bs=bs)
+ txt_ids = torch.zeros(bs, txt.shape[1], 3)
+
+ vec = clip(prompt)
+ if vec.shape[0] == 1 and bs > 1:
+ vec = repeat(vec, "1 ... -> bs ...", bs=bs)
+
+ return {
+ "img": img,
+ "img_ids": img_ids.to(img.device),
+ "ref_img": tuple(ref_imgs_list),
+ "ref_img_ids": [ref_img_id.to(img.device) for ref_img_id in ref_img_ids],
+ "txt": txt.to(img.device),
+ "txt_ids": txt_ids.to(img.device),
+ "vec": vec.to(img.device),
+ }
+
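+# Illustrative note: each reference image receives its own RoPE position ids, offset by the
+# running (pe_shift_h, pe_shift_w) totals above; pe="d" offsets both axes, "h"/"w" only one,
+# and "o" applies no offset, so except for pe="o" target and reference tokens occupy disjoint
+# 2D positions.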
+
+def time_shift(mu: float, sigma: float, t: Tensor):
+ return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
+
+
+def get_lin_function(
+ x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15
+):
+ m = (y2 - y1) / (x2 - x1)
+ b = y1 - m * x1
+ return lambda x: m * x + b
+
+
+def get_schedule(
+ num_steps: int,
+ image_seq_len: int,
+ base_shift: float = 0.5,
+ max_shift: float = 1.15,
+ shift: bool = True,
+) -> list[float]:
+ # extra step for zero
+ timesteps = torch.linspace(1, 0, num_steps + 1)
+
+ # shifting the schedule to favor high timesteps for higher signal images
+ if shift:
+        # estimate mu via linear interpolation between the two reference points
+ mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)
+ timesteps = time_shift(mu, 1.0, timesteps)
+
+ return timesteps.tolist()
+
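+# Illustrative note: with shift=True, mu is a linear function of image_seq_len (0.5 at 256
+# tokens, 1.15 at 4096 tokens) and each timestep t is remapped to
+#     exp(mu) / (exp(mu) + (1 / t - 1))
+# which concentrates more of the schedule at high-noise timesteps for larger images.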
+
+def denoise(
+ model: Flux,
+ # model input
+ img: Tensor,
+ img_ids: Tensor,
+ txt: Tensor,
+ txt_ids: Tensor,
+ vec: Tensor,
+ # sampling parameters
+ timesteps: list[float],
+ guidance: float = 4.0,
+    ref_img: Tensor | tuple[Tensor, ...] | None = None,
+    ref_img_ids: Tensor | list[Tensor] | None = None,
+ siglip_inputs: list[Tensor] | None = None,
+):
+ guidance_vec = torch.full(
+ (img.shape[0],), guidance, device=img.device, dtype=img.dtype
+ )
+ for t_curr, t_prev in tqdm(
+ zip(timesteps[:-1], timesteps[1:]), total=len(timesteps) - 1
+ ):
+ t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
+ pred = model(
+ img=img,
+ img_ids=img_ids,
+ ref_img=ref_img,
+ ref_img_ids=ref_img_ids,
+ txt=txt,
+ txt_ids=txt_ids,
+ y=vec,
+ timesteps=t_vec,
+ guidance=guidance_vec,
+ siglip_inputs=siglip_inputs,
+ )
+ img = img + (t_prev - t_curr) * pred
+ return img
+
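+# Illustrative note: the loop above is a plain Euler integration of the (rectified-flow style)
+# probability-flow ODE,  x <- x + (t_prev - t_curr) * v_theta(x, t_curr),  with t decreasing
+# from 1 (pure noise) to 0 across `timesteps`.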
+
+def unpack(x: Tensor, height: int, width: int) -> Tensor:
+ return rearrange(
+ x,
+ "b (h w) (c ph pw) -> b c (h ph) (w pw)",
+ h=math.ceil(height / 16),
+ w=math.ceil(width / 16),
+ ph=2,
+ pw=2,
+ )
diff --git a/uso/flux/util.py b/uso/flux/util.py
new file mode 100644
index 0000000000000000000000000000000000000000..38e40b64885b8d3ab2c466c5cd19368463caf09a
--- /dev/null
+++ b/uso/flux/util.py
@@ -0,0 +1,511 @@
+# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. All rights reserved.
+# Copyright (c) 2024 Black Forest Labs and The XLabs-AI Team. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+from dataclasses import dataclass
+
+import torch
+import json
+import numpy as np
+from huggingface_hub import hf_hub_download
+from safetensors import safe_open
+from safetensors.torch import load_file as load_sft
+
+from .model import Flux, FluxParams
+from .modules.autoencoder import AutoEncoder, AutoEncoderParams
+from .modules.conditioner import HFEmbedder
+
+import re
+from uso.flux.modules.layers import (
+ DoubleStreamBlockLoraProcessor,
+ SingleStreamBlockLoraProcessor,
+)
+
+
+def load_model(ckpt, device="cpu"):
+ if ckpt.endswith("safetensors"):
+ from safetensors import safe_open
+
+ pl_sd = {}
+ with safe_open(ckpt, framework="pt", device=device) as f:
+ for k in f.keys():
+ pl_sd[k] = f.get_tensor(k)
+ else:
+ pl_sd = torch.load(ckpt, map_location=device)
+ return pl_sd
+
+
+def load_safetensors(path):
+ tensors = {}
+ with safe_open(path, framework="pt", device="cpu") as f:
+ for key in f.keys():
+ tensors[key] = f.get_tensor(key)
+ return tensors
+
+
+def get_lora_rank(checkpoint):
+ for k in checkpoint.keys():
+ if k.endswith(".down.weight"):
+ return checkpoint[k].shape[0]
+
+
+def load_checkpoint(local_path, repo_id, name):
+ if local_path is not None:
+ if ".safetensors" in local_path:
+ print(f"Loading .safetensors checkpoint from {local_path}")
+ checkpoint = load_safetensors(local_path)
+ else:
+ print(f"Loading checkpoint from {local_path}")
+ checkpoint = torch.load(local_path, map_location="cpu")
+ elif repo_id is not None and name is not None:
+ print(f"Loading checkpoint {name} from repo id {repo_id}")
+ checkpoint = load_from_repo_id(repo_id, name)
+ else:
+ raise ValueError(
+ "LOADING ERROR: you must specify local_path or repo_id with name in HF to download"
+ )
+ return checkpoint
+
+
+def c_crop(image):
+ width, height = image.size
+ new_size = min(width, height)
+ left = (width - new_size) / 2
+ top = (height - new_size) / 2
+ right = (width + new_size) / 2
+ bottom = (height + new_size) / 2
+ return image.crop((left, top, right, bottom))
+
+
+def pad64(x):
+ return int(np.ceil(float(x) / 64.0) * 64 - x)
+
+
+def HWC3(x):
+ assert x.dtype == np.uint8
+ if x.ndim == 2:
+ x = x[:, :, None]
+ assert x.ndim == 3
+ H, W, C = x.shape
+ assert C == 1 or C == 3 or C == 4
+ if C == 3:
+ return x
+ if C == 1:
+ return np.concatenate([x, x, x], axis=2)
+ if C == 4:
+ color = x[:, :, 0:3].astype(np.float32)
+ alpha = x[:, :, 3:4].astype(np.float32) / 255.0
+ y = color * alpha + 255.0 * (1.0 - alpha)
+ y = y.clip(0, 255).astype(np.uint8)
+ return y
+
+
+@dataclass
+class ModelSpec:
+ params: FluxParams
+ ae_params: AutoEncoderParams
+ ckpt_path: str | None
+ ae_path: str | None
+ repo_id: str | None
+ repo_flow: str | None
+ repo_ae: str | None
+ repo_id_ae: str | None
+
+
+configs = {
+ "flux-dev": ModelSpec(
+ repo_id="black-forest-labs/FLUX.1-dev",
+ repo_id_ae="black-forest-labs/FLUX.1-dev",
+ repo_flow="flux1-dev.safetensors",
+ repo_ae="ae.safetensors",
+ ckpt_path=os.getenv("FLUX_DEV"),
+ params=FluxParams(
+ in_channels=64,
+ vec_in_dim=768,
+ context_in_dim=4096,
+ hidden_size=3072,
+ mlp_ratio=4.0,
+ num_heads=24,
+ depth=19,
+ depth_single_blocks=38,
+ axes_dim=[16, 56, 56],
+ theta=10_000,
+ qkv_bias=True,
+ guidance_embed=True,
+ ),
+ ae_path=os.getenv("AE"),
+ ae_params=AutoEncoderParams(
+ resolution=256,
+ in_channels=3,
+ ch=128,
+ out_ch=3,
+ ch_mult=[1, 2, 4, 4],
+ num_res_blocks=2,
+ z_channels=16,
+ scale_factor=0.3611,
+ shift_factor=0.1159,
+ ),
+ ),
+ "flux-dev-fp8": ModelSpec(
+ repo_id="black-forest-labs/FLUX.1-dev",
+ repo_id_ae="black-forest-labs/FLUX.1-dev",
+ repo_flow="flux1-dev.safetensors",
+ repo_ae="ae.safetensors",
+ ckpt_path=os.getenv("FLUX_DEV_FP8"),
+ params=FluxParams(
+ in_channels=64,
+ vec_in_dim=768,
+ context_in_dim=4096,
+ hidden_size=3072,
+ mlp_ratio=4.0,
+ num_heads=24,
+ depth=19,
+ depth_single_blocks=38,
+ axes_dim=[16, 56, 56],
+ theta=10_000,
+ qkv_bias=True,
+ guidance_embed=True,
+ ),
+ ae_path=os.getenv("AE"),
+ ae_params=AutoEncoderParams(
+ resolution=256,
+ in_channels=3,
+ ch=128,
+ out_ch=3,
+ ch_mult=[1, 2, 4, 4],
+ num_res_blocks=2,
+ z_channels=16,
+ scale_factor=0.3611,
+ shift_factor=0.1159,
+ ),
+ ),
+ "flux-krea-dev": ModelSpec(
+ repo_id="black-forest-labs/FLUX.1-Krea-dev",
+ repo_id_ae="black-forest-labs/FLUX.1-Krea-dev",
+ repo_flow="flux1-krea-dev.safetensors",
+ repo_ae="ae.safetensors",
+ ckpt_path=os.getenv("FLUX_KREA_DEV"),
+ params=FluxParams(
+ in_channels=64,
+ vec_in_dim=768,
+ context_in_dim=4096,
+ hidden_size=3072,
+ mlp_ratio=4.0,
+ num_heads=24,
+ depth=19,
+ depth_single_blocks=38,
+ axes_dim=[16, 56, 56],
+ theta=10_000,
+ qkv_bias=True,
+ guidance_embed=True,
+ ),
+ ae_path=os.getenv("AE"),
+ ae_params=AutoEncoderParams(
+ resolution=256,
+ in_channels=3,
+ ch=128,
+ out_ch=3,
+ ch_mult=[1, 2, 4, 4],
+ num_res_blocks=2,
+ z_channels=16,
+ scale_factor=0.3611,
+ shift_factor=0.1159,
+ ),
+ ),
+ "flux-schnell": ModelSpec(
+ repo_id="black-forest-labs/FLUX.1-schnell",
+ repo_id_ae="black-forest-labs/FLUX.1-dev",
+ repo_flow="flux1-schnell.safetensors",
+ repo_ae="ae.safetensors",
+ ckpt_path=os.getenv("FLUX_SCHNELL"),
+ params=FluxParams(
+ in_channels=64,
+ vec_in_dim=768,
+ context_in_dim=4096,
+ hidden_size=3072,
+ mlp_ratio=4.0,
+ num_heads=24,
+ depth=19,
+ depth_single_blocks=38,
+ axes_dim=[16, 56, 56],
+ theta=10_000,
+ qkv_bias=True,
+ guidance_embed=False,
+ ),
+ ae_path=os.getenv("AE"),
+ ae_params=AutoEncoderParams(
+ resolution=256,
+ in_channels=3,
+ ch=128,
+ out_ch=3,
+ ch_mult=[1, 2, 4, 4],
+ num_res_blocks=2,
+ z_channels=16,
+ scale_factor=0.3611,
+ shift_factor=0.1159,
+ ),
+ ),
+}
+
+
+def print_load_warning(missing: list[str], unexpected: list[str]) -> None:
+ if len(missing) > 0 and len(unexpected) > 0:
+ print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing))
+ print("\n" + "-" * 79 + "\n")
+ print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected))
+ elif len(missing) > 0:
+ print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing))
+ elif len(unexpected) > 0:
+ print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected))
+
+
+def load_from_repo_id(repo_id, checkpoint_name):
+ ckpt_path = hf_hub_download(repo_id, checkpoint_name)
+ sd = load_sft(ckpt_path, device="cpu")
+ return sd
+
+
+def load_flow_model(
+ name: str, device: str | torch.device = "cuda", hf_download: bool = True
+):
+ # Loading Flux
+ print("Init model")
+ ckpt_path = configs[name].ckpt_path
+ if (
+ ckpt_path is None
+ and configs[name].repo_id is not None
+ and configs[name].repo_flow is not None
+ ):
+ ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_flow)
+
+ # with torch.device("meta" if ckpt_path is not None else device):
+ with torch.device(device):
+ model = Flux(configs[name].params).to(torch.bfloat16)
+
+ if ckpt_path is not None:
+ print("Loading main checkpoint")
+ # load_sft doesn't support torch.device
+ sd = load_model(ckpt_path, device="cpu")
+ missing, unexpected = model.load_state_dict(sd, strict=False, assign=True)
+ print_load_warning(missing, unexpected)
+ return model.to(str(device))
+
+
+def load_flow_model_only_lora(
+ name: str,
+ device: str | torch.device = "cuda",
+ hf_download: bool = True,
+ lora_rank: int = 16,
+ use_fp8: bool = False,
+):
+ # Loading Flux
+ ckpt_path = configs[name].ckpt_path
+ if (
+ ckpt_path is None
+ and configs[name].repo_id is not None
+ and configs[name].repo_flow is not None
+ ):
+ ckpt_path = hf_hub_download(
+ configs[name].repo_id, configs[name].repo_flow.replace("sft", "safetensors")
+ )
+
+ if hf_download:
+ try:
+ lora_ckpt_path = hf_hub_download(
+ "bytedance-research/USO", "uso_flux_v1.0/dit_lora.safetensors"
+ )
+ except Exception as e:
+ print(f"Failed to download lora checkpoint: {e}")
+ print("Trying to load lora from local")
+ lora_ckpt_path = os.environ.get("LORA", None)
+ try:
+ proj_ckpt_path = hf_hub_download(
+ "bytedance-research/USO", "uso_flux_v1.0/projector.safetensors"
+ )
+ except Exception as e:
+ print(f"Failed to download projection_model checkpoint: {e}")
+ print("Trying to load projection_model from local")
+ proj_ckpt_path = os.environ.get("PROJECTION_MODEL", None)
+ else:
+ lora_ckpt_path = os.environ.get("LORA", None)
+ proj_ckpt_path = os.environ.get("PROJECTION_MODEL", None)
+ with torch.device("meta" if ckpt_path is not None else device):
+ model = Flux(configs[name].params)
+
+ model = set_lora(
+ model, lora_rank, device="meta" if lora_ckpt_path is not None else device
+ )
+
+ if ckpt_path is not None:
+ print(f"Loading lora from {lora_ckpt_path}")
+ lora_sd = (
+ load_sft(lora_ckpt_path, device=str(device))
+ if lora_ckpt_path.endswith("safetensors")
+ else torch.load(lora_ckpt_path, map_location="cpu")
+ )
+ proj_sd = (
+ load_sft(proj_ckpt_path, device=str(device))
+ if proj_ckpt_path.endswith("safetensors")
+ else torch.load(proj_ckpt_path, map_location="cpu")
+ )
+ lora_sd.update(proj_sd)
+
+ print("Loading main checkpoint")
+ # load_sft doesn't support torch.device
+
+ if ckpt_path.endswith("safetensors"):
+ if use_fp8:
+ print(
+ "####\n"
+                    "We are in fp8 mode right now, since the fp8 checkpoint of XLabs-AI/flux-dev-fp8 seems broken.\n"
+                    "We convert the checkpoint to fp8 on the fly from the bf16 checkpoint instead.\n"
+                    "If your storage is constrained, you can save the fp8 checkpoint "
+                    "and replace the bf16 checkpoint yourself.\n"
+ )
+ sd = load_sft(ckpt_path, device="cpu")
+ sd = {
+ k: v.to(dtype=torch.float8_e4m3fn, device=device)
+ for k, v in sd.items()
+ }
+ else:
+ sd = load_sft(ckpt_path, device=str(device))
+
+ sd.update(lora_sd)
+ missing, unexpected = model.load_state_dict(sd, strict=False, assign=True)
+ else:
+ dit_state = torch.load(ckpt_path, map_location="cpu")
+ sd = {}
+ for k in dit_state.keys():
+ sd[k.replace("module.", "")] = dit_state[k]
+ sd.update(lora_sd)
+ missing, unexpected = model.load_state_dict(sd, strict=False, assign=True)
+ model.to(str(device))
+ print_load_warning(missing, unexpected)
+ return model
+
+
+def set_lora(
+ model: Flux,
+ lora_rank: int,
+ double_blocks_indices: list[int] | None = None,
+ single_blocks_indices: list[int] | None = None,
+ device: str | torch.device = "cpu",
+) -> Flux:
+ double_blocks_indices = (
+ list(range(model.params.depth))
+ if double_blocks_indices is None
+ else double_blocks_indices
+ )
+ single_blocks_indices = (
+ list(range(model.params.depth_single_blocks))
+ if single_blocks_indices is None
+ else single_blocks_indices
+ )
+
+ lora_attn_procs = {}
+ with torch.device(device):
+ for name, attn_processor in model.attn_processors.items():
+ match = re.search(r"\.(\d+)\.", name)
+ if match:
+ layer_index = int(match.group(1))
+
+ if (
+ name.startswith("double_blocks")
+ and layer_index in double_blocks_indices
+ ):
+ lora_attn_procs[name] = DoubleStreamBlockLoraProcessor(
+ dim=model.params.hidden_size, rank=lora_rank
+ )
+ elif (
+ name.startswith("single_blocks")
+ and layer_index in single_blocks_indices
+ ):
+ lora_attn_procs[name] = SingleStreamBlockLoraProcessor(
+ dim=model.params.hidden_size, rank=lora_rank
+ )
+ else:
+ lora_attn_procs[name] = attn_processor
+ model.set_attn_processor(lora_attn_procs)
+ return model
+
+
+def load_flow_model_quintized(
+ name: str, device: str | torch.device = "cuda", hf_download: bool = True
+):
+ # Loading Flux
+ from optimum.quanto import requantize
+
+ print("Init model")
+ ckpt_path = configs[name].ckpt_path
+ if (
+ ckpt_path is None
+ and configs[name].repo_id is not None
+ and configs[name].repo_flow is not None
+ and hf_download
+ ):
+ ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_flow)
+ json_path = hf_hub_download(configs[name].repo_id, "flux_dev_quantization_map.json")
+
+ model = Flux(configs[name].params).to(torch.bfloat16)
+
+ print("Loading checkpoint")
+ # load_sft doesn't support torch.device
+ sd = load_sft(ckpt_path, device="cpu")
+ sd = {k: v.to(dtype=torch.float8_e4m3fn, device=device) for k, v in sd.items()}
+    model.load_state_dict(sd, assign=True)
+    # NOTE: the weights are cast to fp8 directly above and loaded as-is, so the
+    # optimum-quanto requantization path (flux_dev_quantization_map.json) is not used here.
+    return model
+
+
+def load_t5(device: str | torch.device = "cuda", max_length: int = 512) -> HFEmbedder:
+ # max length 64, 128, 256 and 512 should work (if your sequence is short enough)
+ version = os.environ.get("T5", "xlabs-ai/xflux_text_encoders")
+ return HFEmbedder(version, max_length=max_length, torch_dtype=torch.bfloat16).to(
+ device
+ )
+
+
+def load_clip(device: str | torch.device = "cuda") -> HFEmbedder:
+ version = os.environ.get("CLIP", "openai/clip-vit-large-patch14")
+ return HFEmbedder(version, max_length=77, torch_dtype=torch.bfloat16).to(device)
+
+
+def load_ae(
+ name: str, device: str | torch.device = "cuda", hf_download: bool = True
+) -> AutoEncoder:
+ ckpt_path = configs[name].ae_path
+ if (
+ ckpt_path is None
+ and configs[name].repo_id is not None
+ and configs[name].repo_ae is not None
+ and hf_download
+ ):
+ ckpt_path = hf_hub_download(configs[name].repo_id_ae, configs[name].repo_ae)
+
+ # Loading the autoencoder
+ print("Init AE")
+ with torch.device("meta" if ckpt_path is not None else device):
+ ae = AutoEncoder(configs[name].ae_params)
+
+ if ckpt_path is not None:
+ sd = load_sft(ckpt_path, device=str(device))
+ missing, unexpected = ae.load_state_dict(sd, strict=False, assign=True)
+ print_load_warning(missing, unexpected)
+ return ae