diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..602834a23fe7bea552229ebbcb1d6679363bc38a 100644 --- a/.gitattributes +++ b/.gitattributes @@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +*.png filter=lfs diff=lfs merge=lfs -text +*.jpg filter=lfs diff=lfs merge=lfs -text +*.webp filter=lfs diff=lfs merge=lfs -text \ No newline at end of file diff --git a/README copy.md b/README copy.md new file mode 100644 index 0000000000000000000000000000000000000000..027815ccc2e9b171f1f0bca4d818d0eba2278414 --- /dev/null +++ b/README copy.md @@ -0,0 +1,17 @@ +--- +title: USO +emoji: 💻 +colorFrom: indigo +colorTo: purple +sdk: gradio +sdk_version: 5.23.3 +app_file: app.py +pinned: false +license: apache-2.0 +short_description: Freely Combining Any Subjects with Any Styles Across All Scenarios. +models: + - black-forest-labs/FLUX.1-dev + - bytedance-research/UNO +--- + +Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..2256462b7c3f01d992faea87f6a5932f8f014148 --- /dev/null +++ b/app.py @@ -0,0 +1,242 @@ +# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. All rights reserved. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import dataclasses +import json +import os +from pathlib import Path + +import gradio as gr +import torch +import spaces + +from uso.flux.pipeline import USOPipeline +from transformers import SiglipVisionModel, SiglipImageProcessor + + +with open("assets/uso_text.svg", "r", encoding="utf-8") as svg_file: + text_content = svg_file.read() + +with open("assets/uso_logo.svg", "r", encoding="utf-8") as svg_file: + logo_content = svg_file.read() + +title = f""" +
+ {text_content} + by UXO Team + {logo_content} +
+""".strip() + +badges_text = r""" +
+ +Build +Build + +
+""".strip() + +tips = """ + 📌 **What is USO?** +USO is a unified style-subject optimized customization model and the latest addition to the UXO family ( USO and UNO). +It can freely combine arbitrary subjects with arbitrary styles in any scenarios. + + 💡 **How to use?** +We provide step-by-step instructions in our Github Repo. +Additionally, try the examples provided below the demo to quickly get familiar with USO and spark your creativity! + + ⚡️ The model is trained on 1024x1024 resolution and supports 3 types of usage: +* **Only content img**: support following types: + * Subject/Identity-driven (supports natural prompt, e.g., *A clock on the table.* *The woman near the sea.*, excels in producing **photorealistic portraits**) + * Style edit (layout-preserved): *Transform the image into Ghibli style/Pixel style/Retro comic style/Watercolor painting style...*. + * Style edit (layout-shift): *Ghibli style, the man on the beach.*. +* **Only style img**: Reference input style and generate anything following prompt. Excelling in this and further support multiple style references (in beta). +* **Content img + style img**: Place the content into the desired style. + * Layout-preserved: set prompt to **empty**. + * Layout-shift: using natural prompt.""" + +star = r""" +If USO is helpful, please help to ⭐ our Github Repo. Thanks a lot!""" + +def get_examples(examples_dir: str = "assets/examples") -> list: + examples = Path(examples_dir) + ans = [] + for example in examples.iterdir(): + if not example.is_dir() or len(os.listdir(example)) == 0: + continue + with open(example / "config.json") as f: + example_dict = json.load(f) + + + example_list = [] + + example_list.append(example_dict["usage"]) # case for + example_list.append(example_dict["prompt"]) # prompt + + for key in ["image_ref1", "image_ref2", "image_ref3"]: + if key in example_dict: + example_list.append(str(example / example_dict[key])) + else: + example_list.append(None) + + example_list.append(example_dict["seed"]) + ans.append(example_list) + return ans + + +def create_demo( + model_type: str, + device: str = "cuda" if torch.cuda.is_available() else "cpu", + offload: bool = False, +): + pipeline = USOPipeline( + model_type, device, offload, only_lora=True, lora_rank=128, hf_download=True + ) + print("USOPipeline loaded successfully") + + siglip_processor = SiglipImageProcessor.from_pretrained( + "google/siglip-so400m-patch14-384" + ) + siglip_model = SiglipVisionModel.from_pretrained( + "google/siglip-so400m-patch14-384" + ) + siglip_model.eval() + siglip_model.to(device) + pipeline.model.vision_encoder = siglip_model + pipeline.model.vision_encoder_processor = siglip_processor + print("SigLIP model loaded successfully") + + pipeline.gradio_generate = spaces.GPU(duration=120)(pipeline.gradio_generate) + with gr.Blocks() as demo: + gr.Markdown(title) + gr.Markdown(badges_text) + gr.Markdown(tips) + with gr.Row(): + with gr.Column(): + prompt = gr.Textbox(label="Prompt", value="A beautiful woman.") + with gr.Row(): + image_prompt1 = gr.Image( + label="Content Reference Img", visible=True, interactive=True, type="pil" + ) + image_prompt2 = gr.Image( + label="Style Reference Img", visible=True, interactive=True, type="pil" + ) + image_prompt3 = gr.Image( + label="Extra Style Reference Img (Beta)", visible=True, interactive=True, type="pil" + ) + + with gr.Row(): + with gr.Row(): + width = gr.Slider( + 512, 1536, 1024, step=16, label="Generation Width" + ) + height = gr.Slider( + 512, 1536, 1024, step=16, label="Generation Height" + ) + 
with gr.Row(): + with gr.Row(): + keep_size = gr.Checkbox( + label="Keep input size", + value=False, + interactive=True + ) + with gr.Column(): + gr.Markdown("Set it to True if you only need style editing or want to keep the layout.") + + with gr.Accordion("Advanced Options", open=True): + with gr.Row(): + num_steps = gr.Slider( + 1, 50, 25, step=1, label="Number of steps" + ) + guidance = gr.Slider( + 1.0, 5.0, 4.0, step=0.1, label="Guidance", interactive=True + ) + content_long_size = gr.Slider( + 0, 1024, 512, step=16, label="Content reference size" + ) + seed = gr.Number(-1, label="Seed (-1 for random)") + + generate_btn = gr.Button("Generate") + gr.Markdown(star) + + with gr.Column(): + output_image = gr.Image(label="Generated Image") + download_btn = gr.File( + label="Download full-resolution", type="filepath", interactive=False + ) + + inputs = [ + prompt, + image_prompt1, + image_prompt2, + image_prompt3, + seed, + width, + height, + guidance, + num_steps, + keep_size, + content_long_size, + ] + generate_btn.click( + fn=pipeline.gradio_generate, + inputs=inputs, + outputs=[output_image, download_btn], + ) + + example_text = gr.Text("", visible=False, label="Case For:") + examples = get_examples("./assets/gradio_examples") + + gr.Examples( + examples=examples, + inputs=[ + example_text, + prompt, + image_prompt1, + image_prompt2, + image_prompt3, + seed, + ], + # cache_examples='lazy', + outputs=[output_image, download_btn], + fn=pipeline.gradio_generate, + ) + + return demo + + +if __name__ == "__main__": + from typing import Literal + + from transformers import HfArgumentParser + + @dataclasses.dataclass + class AppArgs: + name: Literal["flux-dev", "flux-dev-fp8", "flux-schnell", "flux-krea-dev"] = "flux-dev" + device: Literal["cuda", "cpu"] = "cuda" if torch.cuda.is_available() else "cpu" + offload: bool = dataclasses.field( + default=False, + metadata={ + "help": "If True, sequantial offload the models(ae, dit, text encoder) to CPU if not used." 
+ }, + ) + port: int = 7860 + + parser = HfArgumentParser([AppArgs]) + args_tuple = parser.parse_args_into_dataclasses() # type: tuple[AppArgs] + args = args_tuple[0] + + demo = create_demo(args.name, args.device, args.offload) + demo.launch(server_port=args.port) diff --git a/assets/gradio_examples/1subject/config.json b/assets/gradio_examples/1subject/config.json new file mode 100644 index 0000000000000000000000000000000000000000..2f470599cdd9382e12bded66467a33c6c4329197 --- /dev/null +++ b/assets/gradio_examples/1subject/config.json @@ -0,0 +1,6 @@ +{ + "prompt": "Wool felt style, a clock in the jungle.", + "seed": 3407, + "usage": "Subject-driven", + "image_ref1": "./ref.jpg" +} \ No newline at end of file diff --git a/assets/gradio_examples/1subject/ref.jpg b/assets/gradio_examples/1subject/ref.jpg new file mode 100644 index 0000000000000000000000000000000000000000..38615fba19fda54eec4df211606e95ec89cbae4b --- /dev/null +++ b/assets/gradio_examples/1subject/ref.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0e1eb6ca2c944f3bfaed3ace56f5f186ed073a477e0333e0237253d98f0c9267 +size 139451 diff --git a/assets/gradio_examples/2identity/config.json b/assets/gradio_examples/2identity/config.json new file mode 100644 index 0000000000000000000000000000000000000000..be3789e3a29e7e4156440a397fa30a265fd5fc42 --- /dev/null +++ b/assets/gradio_examples/2identity/config.json @@ -0,0 +1,6 @@ +{ + "prompt": "The girl is riding a bike in a street.", + "seed": 3407, + "usage": "Identity-driven", + "image_ref1": "./ref.webp" +} \ No newline at end of file diff --git a/assets/gradio_examples/2identity/ref.webp b/assets/gradio_examples/2identity/ref.webp new file mode 100644 index 0000000000000000000000000000000000000000..be2cb3ababd23173ad3c60a6336f00d685956f66 --- /dev/null +++ b/assets/gradio_examples/2identity/ref.webp @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4e97502bd7eebd6692604f891f836f25c7c30dcac8d15c4d42cc874efc51fcc5 +size 85772 diff --git a/assets/gradio_examples/3identity/config.json b/assets/gradio_examples/3identity/config.json new file mode 100644 index 0000000000000000000000000000000000000000..5816a5f2b6d590c851698d90aa9ba93da9fc87d6 --- /dev/null +++ b/assets/gradio_examples/3identity/config.json @@ -0,0 +1,6 @@ +{ + "prompt": "The man in flower shops carefully match bouquets, conveying beautiful emotions and blessings with flowers.", + "seed": 3407, + "usage": "Identity-driven", + "image_ref1": "./ref.jpg" +} \ No newline at end of file diff --git a/assets/gradio_examples/3identity/ref.jpg b/assets/gradio_examples/3identity/ref.jpg new file mode 100644 index 0000000000000000000000000000000000000000..2b8c787868211c1fc9d2678f917516479f871bdc --- /dev/null +++ b/assets/gradio_examples/3identity/ref.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2730103b6b9ebaf47b44ef9a9d7fbb722de7878a101af09f0b85f8dfadb4c8a4 +size 30572 diff --git a/assets/gradio_examples/4identity/config.json b/assets/gradio_examples/4identity/config.json new file mode 100644 index 0000000000000000000000000000000000000000..94497918462a8c0e09bc80eb3d9549e89ba6e7d8 --- /dev/null +++ b/assets/gradio_examples/4identity/config.json @@ -0,0 +1,6 @@ +{ + "prompt": "Transform the image into Ghibli style.", + "seed": 3407, + "usage": "Identity-driven", + "image_ref1": "./ref.webp" +} \ No newline at end of file diff --git a/assets/gradio_examples/4identity/ref.webp b/assets/gradio_examples/4identity/ref.webp new file mode 100644 index 
0000000000000000000000000000000000000000..9f7cae34140b5d373c7ca51d4c2e8f67cf039319 --- /dev/null +++ b/assets/gradio_examples/4identity/ref.webp @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f8ed8aa1c0714c939392e2c033735d6266e53266079bb300cbf05a6824a49f9f +size 38764 diff --git a/assets/gradio_examples/5style/config.json b/assets/gradio_examples/5style/config.json new file mode 100644 index 0000000000000000000000000000000000000000..9303fd6dfd485517ff34607b0058dea9cbec2922 --- /dev/null +++ b/assets/gradio_examples/5style/config.json @@ -0,0 +1,6 @@ +{ + "prompt": "A cat sleeping on a chair.", + "seed": 3407, + "usage": "Style-driven", + "image_ref2": "./ref.webp" +} \ No newline at end of file diff --git a/assets/gradio_examples/5style/ref.webp b/assets/gradio_examples/5style/ref.webp new file mode 100644 index 0000000000000000000000000000000000000000..8c2f06b9cd2c331914842ca76d6ed58a80a50085 --- /dev/null +++ b/assets/gradio_examples/5style/ref.webp @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9ebf56d2d20ae5c49a582ff6bfef64b13022d0c624d9de25ed91047380fdfcfe +size 52340 diff --git a/assets/gradio_examples/6style/config.json b/assets/gradio_examples/6style/config.json new file mode 100644 index 0000000000000000000000000000000000000000..172585f3bae2f7e73d25dea3393f46435123f95e --- /dev/null +++ b/assets/gradio_examples/6style/config.json @@ -0,0 +1,6 @@ +{ + "prompt": "A beautiful woman.", + "seed": 3407, + "usage": "Style-driven", + "image_ref2": "./ref.webp" +} \ No newline at end of file diff --git a/assets/gradio_examples/6style/ref.webp b/assets/gradio_examples/6style/ref.webp new file mode 100644 index 0000000000000000000000000000000000000000..5c1e9ddb0d4a1c7658ddf607019bc76b1a4b6ce0 --- /dev/null +++ b/assets/gradio_examples/6style/ref.webp @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:40c013341c8708b53094e3eaa377b3dfccdc9e77e215ad15d2ac2e875b4c494a +size 58202 diff --git a/assets/gradio_examples/7style_subject/config.json b/assets/gradio_examples/7style_subject/config.json new file mode 100644 index 0000000000000000000000000000000000000000..6dc3da8233e6a6b5f5a9192fb65227352d2c5cdd --- /dev/null +++ b/assets/gradio_examples/7style_subject/config.json @@ -0,0 +1,7 @@ +{ + "prompt": "", + "seed": 321, + "usage": "Style-subject-driven (layout-preserved)", + "image_ref1": "./ref1.webp", + "image_ref2": "./ref2.webp" +} \ No newline at end of file diff --git a/assets/gradio_examples/7style_subject/ref1.webp b/assets/gradio_examples/7style_subject/ref1.webp new file mode 100644 index 0000000000000000000000000000000000000000..9f7cae34140b5d373c7ca51d4c2e8f67cf039319 --- /dev/null +++ b/assets/gradio_examples/7style_subject/ref1.webp @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f8ed8aa1c0714c939392e2c033735d6266e53266079bb300cbf05a6824a49f9f +size 38764 diff --git a/assets/gradio_examples/7style_subject/ref2.webp b/assets/gradio_examples/7style_subject/ref2.webp new file mode 100644 index 0000000000000000000000000000000000000000..c06e99d265a2b42b48875ec086457a5414eea627 --- /dev/null +++ b/assets/gradio_examples/7style_subject/ref2.webp @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:175d6e5b975b4d494950250740c0fe371a7e9b2c93c59a3ae82b82be72ccc0f6 +size 14168 diff --git a/assets/gradio_examples/8style_subject/config.json b/assets/gradio_examples/8style_subject/config.json new file mode 100644 index 
0000000000000000000000000000000000000000..ef542fde59e71687ee9e536649b8481a880d57a8 --- /dev/null +++ b/assets/gradio_examples/8style_subject/config.json @@ -0,0 +1,7 @@ +{ + "prompt": "The woman gave an impassioned speech on the podium.", + "seed": 321, + "usage": "Style-subject-driven (layout-shifted)", + "image_ref1": "./ref1.webp", + "image_ref2": "./ref2.webp" +} \ No newline at end of file diff --git a/assets/gradio_examples/8style_subject/ref1.webp b/assets/gradio_examples/8style_subject/ref1.webp new file mode 100644 index 0000000000000000000000000000000000000000..9f7cae34140b5d373c7ca51d4c2e8f67cf039319 --- /dev/null +++ b/assets/gradio_examples/8style_subject/ref1.webp @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f8ed8aa1c0714c939392e2c033735d6266e53266079bb300cbf05a6824a49f9f +size 38764 diff --git a/assets/gradio_examples/8style_subject/ref2.webp b/assets/gradio_examples/8style_subject/ref2.webp new file mode 100644 index 0000000000000000000000000000000000000000..5ab3db1106adb67090bd533e0ddf1094d7c07065 --- /dev/null +++ b/assets/gradio_examples/8style_subject/ref2.webp @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0235262d9bd1070155536352ccf195f9875ead0d3379dee7285c0aaae79f6464 +size 39098 diff --git a/assets/gradio_examples/9mix_style/config.json b/assets/gradio_examples/9mix_style/config.json new file mode 100644 index 0000000000000000000000000000000000000000..8c9aa236dc8c7bac8dd441aa9a86c786824a9d17 --- /dev/null +++ b/assets/gradio_examples/9mix_style/config.json @@ -0,0 +1,7 @@ +{ + "prompt": "A man.", + "seed": 321, + "usage": "Multi-style-driven", + "image_ref2": "./ref1.webp", + "image_ref3": "./ref2.webp" +} \ No newline at end of file diff --git a/assets/gradio_examples/9mix_style/ref1.webp b/assets/gradio_examples/9mix_style/ref1.webp new file mode 100644 index 0000000000000000000000000000000000000000..1c02f7fe712a295f858a666f211d994cecaa7ac1 --- /dev/null +++ b/assets/gradio_examples/9mix_style/ref1.webp @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a1d272a0ecb03126503446b00a2152deab2045f89ac2c01f948e1099589d2862 +size 141886 diff --git a/assets/gradio_examples/9mix_style/ref2.webp b/assets/gradio_examples/9mix_style/ref2.webp new file mode 100644 index 0000000000000000000000000000000000000000..e99715757eb80c277f42a4f5295251c30f1af45f --- /dev/null +++ b/assets/gradio_examples/9mix_style/ref2.webp @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b1ce04559726509672ce859d617a08d8dff8b2fe28f503fecbca7a5f66082882 +size 290260 diff --git a/assets/gradio_examples/identity1.jpg b/assets/gradio_examples/identity1.jpg new file mode 100644 index 0000000000000000000000000000000000000000..2b8c787868211c1fc9d2678f917516479f871bdc --- /dev/null +++ b/assets/gradio_examples/identity1.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2730103b6b9ebaf47b44ef9a9d7fbb722de7878a101af09f0b85f8dfadb4c8a4 +size 30572 diff --git a/assets/gradio_examples/identity1_result.png b/assets/gradio_examples/identity1_result.png new file mode 100644 index 0000000000000000000000000000000000000000..e4d895db3f534950cbd804e594f450bf55f6507c --- /dev/null +++ b/assets/gradio_examples/identity1_result.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7684256e44ce1bd4ada1e77a12674432eddd95b07fb388673899139afc56d864 +size 1538828 diff --git a/assets/gradio_examples/identity2.webp b/assets/gradio_examples/identity2.webp new file mode 100644 index 
0000000000000000000000000000000000000000..9f7cae34140b5d373c7ca51d4c2e8f67cf039319 --- /dev/null +++ b/assets/gradio_examples/identity2.webp @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f8ed8aa1c0714c939392e2c033735d6266e53266079bb300cbf05a6824a49f9f +size 38764 diff --git a/assets/gradio_examples/identity2_style2_result.webp b/assets/gradio_examples/identity2_style2_result.webp new file mode 100644 index 0000000000000000000000000000000000000000..2d28e2e08181b1862fb30521bec67989ee17eb30 --- /dev/null +++ b/assets/gradio_examples/identity2_style2_result.webp @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8376b6dc02d304616c09ecf09c7dbabb16c7c9142fb4db21f576a15a1ec24062 +size 43892 diff --git a/assets/gradio_examples/style1.webp b/assets/gradio_examples/style1.webp new file mode 100644 index 0000000000000000000000000000000000000000..8c2f06b9cd2c331914842ca76d6ed58a80a50085 --- /dev/null +++ b/assets/gradio_examples/style1.webp @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9ebf56d2d20ae5c49a582ff6bfef64b13022d0c624d9de25ed91047380fdfcfe +size 52340 diff --git a/assets/gradio_examples/style1_result.webp b/assets/gradio_examples/style1_result.webp new file mode 100644 index 0000000000000000000000000000000000000000..0b23eed9991c7ffab948d60af5dcab89bcf03aae --- /dev/null +++ b/assets/gradio_examples/style1_result.webp @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:16a4353dd83b1c48499e222d6f77904e1fda23c1649ea5f6cca6b00b0fca3069 +size 61062 diff --git a/assets/gradio_examples/style2.webp b/assets/gradio_examples/style2.webp new file mode 100644 index 0000000000000000000000000000000000000000..5ab3db1106adb67090bd533e0ddf1094d7c07065 --- /dev/null +++ b/assets/gradio_examples/style2.webp @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0235262d9bd1070155536352ccf195f9875ead0d3379dee7285c0aaae79f6464 +size 39098 diff --git a/assets/gradio_examples/style3.webp b/assets/gradio_examples/style3.webp new file mode 100644 index 0000000000000000000000000000000000000000..1c02f7fe712a295f858a666f211d994cecaa7ac1 --- /dev/null +++ b/assets/gradio_examples/style3.webp @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a1d272a0ecb03126503446b00a2152deab2045f89ac2c01f948e1099589d2862 +size 141886 diff --git a/assets/gradio_examples/style3_style4_result.webp b/assets/gradio_examples/style3_style4_result.webp new file mode 100644 index 0000000000000000000000000000000000000000..2bc1bfc2258d5193a300c560563b3b21eaa434d4 --- /dev/null +++ b/assets/gradio_examples/style3_style4_result.webp @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d09a5e429cc1d059aecd041e061868cd8e5b59f4718bb0f926fd84364f3794b0 +size 172716 diff --git a/assets/gradio_examples/style4.webp b/assets/gradio_examples/style4.webp new file mode 100644 index 0000000000000000000000000000000000000000..e99715757eb80c277f42a4f5295251c30f1af45f --- /dev/null +++ b/assets/gradio_examples/style4.webp @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b1ce04559726509672ce859d617a08d8dff8b2fe28f503fecbca7a5f66082882 +size 290260 diff --git a/assets/gradio_examples/z_mix_style/config.json b/assets/gradio_examples/z_mix_style/config.json new file mode 100644 index 0000000000000000000000000000000000000000..db3579ae23a361a3e50cd4d767ef5f9426952bb6 --- /dev/null +++ b/assets/gradio_examples/z_mix_style/config.json @@ -0,0 +1,7 @@ +{ + "prompt": "Boat on water.", + "seed": 321, + 
"usage": "Multi-style-driven", + "image_ref2": "./ref1.png", + "image_ref3": "./ref2.png" +} \ No newline at end of file diff --git a/assets/gradio_examples/z_mix_style/ref1.png b/assets/gradio_examples/z_mix_style/ref1.png new file mode 100644 index 0000000000000000000000000000000000000000..e923d0ee6951cf100075015da04a686d50bd3698 --- /dev/null +++ b/assets/gradio_examples/z_mix_style/ref1.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5c31ba662c85f4032abf079dfeb9cba08d797b7b63f1d661c5270b373b00d095 +size 26149 diff --git a/assets/gradio_examples/z_mix_style/ref2.png b/assets/gradio_examples/z_mix_style/ref2.png new file mode 100644 index 0000000000000000000000000000000000000000..8f850831089436bfc214349373a808cbf83cf54d --- /dev/null +++ b/assets/gradio_examples/z_mix_style/ref2.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c47d23d5ffdbf30b4a8f6c1bc5d07a730825eaac8363c13bdac8e3bb8c330aed +size 14666 diff --git a/assets/gradio_examples/zz_t2i/config.json b/assets/gradio_examples/zz_t2i/config.json new file mode 100644 index 0000000000000000000000000000000000000000..e475d1043a4e13638495ac1db4118d7a5865a00c --- /dev/null +++ b/assets/gradio_examples/zz_t2i/config.json @@ -0,0 +1,5 @@ +{ + "prompt": "A beautiful woman.", + "seed": -1, + "usage": "Text-to-image" +} \ No newline at end of file diff --git a/assets/teaser.webp b/assets/teaser.webp new file mode 100644 index 0000000000000000000000000000000000000000..d6916918b54307276b3fc8a5bc269c62dd7f2989 --- /dev/null +++ b/assets/teaser.webp @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:543c724f6b929303046ae481672567fe4a9620f0af5ca1dfff215dc7a2cbff5f +size 1674736 diff --git a/assets/uso.webp b/assets/uso.webp new file mode 100644 index 0000000000000000000000000000000000000000..0caec5b7d5824b89670f7cbeec120f872d406830 --- /dev/null +++ b/assets/uso.webp @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:772957e867da33550437fa547202d0f995011353ef9a24036d23596dae1a1632 +size 58234 diff --git a/assets/uso_logo.svg b/assets/uso_logo.svg new file mode 100644 index 0000000000000000000000000000000000000000..dc91b85132492bc6c054b5184f24c236656a7569 --- /dev/null +++ b/assets/uso_logo.svg @@ -0,0 +1,880 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + 
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/assets/uso_text.svg b/assets/uso_text.svg new file mode 100644 index 0000000000000000000000000000000000000000..86153554bc20737c394cac6d50ee58d9f277572e --- /dev/null +++ b/assets/uso_text.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..61694b57e446afb1def8734413b85d3c4abf836a --- /dev/null +++ b/requirements.txt @@ -0,0 +1,19 @@ +accelerate==1.1.1 +deepspeed==0.14.4 +einops==0.8.0 +transformers==4.43.3 +huggingface-hub +diffusers==0.30.1 +sentencepiece==0.2.0 +gradio==5.22.0 +opencv-python +matplotlib +safetensors==0.4.5 +scipy==1.10.1 +numpy==1.24.4 +onnxruntime-gpu +# httpx==0.23.3 +git+https://github.com/openai/CLIP.git +--extra-index-url https://download.pytorch.org/whl/cu124 +torch==2.4.0 +torchvision==0.19.0 diff --git a/uso/flux/math.py b/uso/flux/math.py new file mode 100644 index 0000000000000000000000000000000000000000..2461437371d22a60eab7df4b5f5cb371dd692fe9 --- /dev/null +++ b/uso/flux/math.py @@ -0,0 +1,45 @@ +# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. All rights reserved. +# Copyright (c) 2024 Black Forest Labs and The XLabs-AI Team. All rights reserved. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
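# Attention helpers shared by the FLUX blocks: `rope` builds per-axis 2x2 rotation
# matrices for rotary position embeddings, `apply_rope` rotates the query/key
# tensors with them, and `attention` runs scaled-dot-product attention on
# (batch, heads, seq_len, head_dim) tensors before folding heads back into the
# hidden dimension.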
+ +import torch +from einops import rearrange +from torch import Tensor + + +def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor: + q, k = apply_rope(q, k, pe) + + x = torch.nn.functional.scaled_dot_product_attention(q, k, v) + x = rearrange(x, "B H L D -> B L (H D)") + + return x + + +def rope(pos: Tensor, dim: int, theta: int) -> Tensor: + assert dim % 2 == 0 + scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim + omega = 1.0 / (theta**scale) + out = torch.einsum("...n,d->...nd", pos, omega) + out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1) + out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2) + return out.float() + + +def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]: + xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2) + xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2) + xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1] + xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1] + return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk) diff --git a/uso/flux/model.py b/uso/flux/model.py new file mode 100644 index 0000000000000000000000000000000000000000..43d4c80227b80a50585df4afcd21d1449ca3f61d --- /dev/null +++ b/uso/flux/model.py @@ -0,0 +1,258 @@ +# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. All rights reserved. +# Copyright (c) 2024 Black Forest Labs and The XLabs-AI Team. All rights reserved. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass + +import torch +from torch import Tensor, nn + +from .modules.layers import ( + DoubleStreamBlock, + EmbedND, + LastLayer, + MLPEmbedder, + SingleStreamBlock, + timestep_embedding, + SigLIPMultiFeatProjModel, +) +import os + + +@dataclass +class FluxParams: + in_channels: int + vec_in_dim: int + context_in_dim: int + hidden_size: int + mlp_ratio: float + num_heads: int + depth: int + depth_single_blocks: int + axes_dim: list[int] + theta: int + qkv_bias: bool + guidance_embed: bool + + +class Flux(nn.Module): + """ + Transformer model for flow matching on sequences. 
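    Image and text tokens are first processed in separate streams by the double
    blocks, then concatenated and refined by the single blocks. Reference-image
    tokens are appended to the image sequence, and SigLIP style features are
    prepended to the text sequence.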
+ """ + + _supports_gradient_checkpointing = True + + def __init__(self, params: FluxParams): + super().__init__() + + self.params = params + self.in_channels = params.in_channels + self.out_channels = self.in_channels + if params.hidden_size % params.num_heads != 0: + raise ValueError( + f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}" + ) + pe_dim = params.hidden_size // params.num_heads + if sum(params.axes_dim) != pe_dim: + raise ValueError( + f"Got {params.axes_dim} but expected positional dim {pe_dim}" + ) + self.hidden_size = params.hidden_size + self.num_heads = params.num_heads + self.pe_embedder = EmbedND( + dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim + ) + self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True) + self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) + self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size) + self.guidance_in = ( + MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) + if params.guidance_embed + else nn.Identity() + ) + self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size) + + self.double_blocks = nn.ModuleList( + [ + DoubleStreamBlock( + self.hidden_size, + self.num_heads, + mlp_ratio=params.mlp_ratio, + qkv_bias=params.qkv_bias, + ) + for _ in range(params.depth) + ] + ) + + self.single_blocks = nn.ModuleList( + [ + SingleStreamBlock( + self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio + ) + for _ in range(params.depth_single_blocks) + ] + ) + + self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels) + self.gradient_checkpointing = False + + # feature embedder for siglip multi-feat inputs + self.feature_embedder = SigLIPMultiFeatProjModel( + siglip_token_nums=729, + style_token_nums=64, + siglip_token_dims=1152, + hidden_size=self.hidden_size, + context_layer_norm=True, + ) + print("use semantic encoder siglip multi-feat to encode style image") + + self.vision_encoder = None + + def _set_gradient_checkpointing(self, module, value=False): + if hasattr(module, "gradient_checkpointing"): + module.gradient_checkpointing = value + + @property + def attn_processors(self): + # set recursively + processors = {} # type: dict[str, nn.Module] + + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors): + if hasattr(module, "set_processor"): + processors[f"{name}.processor"] = module.processor + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + def set_attn_processor(self, processor): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." 
+ ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + def forward( + self, + img: Tensor, + img_ids: Tensor, + txt: Tensor, + txt_ids: Tensor, + timesteps: Tensor, + y: Tensor, + guidance: Tensor | None = None, + ref_img: Tensor | None = None, + ref_img_ids: Tensor | None = None, + siglip_inputs: list[Tensor] | None = None, + ) -> Tensor: + if img.ndim != 3 or txt.ndim != 3: + raise ValueError("Input img and txt tensors must have 3 dimensions.") + + # running on sequences img + img = self.img_in(img) + vec = self.time_in(timestep_embedding(timesteps, 256)) + if self.params.guidance_embed: + if guidance is None: + raise ValueError( + "Didn't get guidance strength for guidance distilled model." + ) + vec = vec + self.guidance_in(timestep_embedding(guidance, 256)) + vec = vec + self.vector_in(y) + txt = self.txt_in(txt) + if self.feature_embedder is not None and siglip_inputs is not None and len(siglip_inputs) > 0 and self.vision_encoder is not None: + # processing style feat into textural hidden space + siglip_embedding = [self.vision_encoder(**emb, output_hidden_states=True) for emb in siglip_inputs] + # siglip_embedding = [self.vision_encoder(**(emb.to(torch.bfloat16)), output_hidden_states=True) for emb in siglip_inputs] + siglip_embedding = torch.cat([self.feature_embedder(emb) for emb in siglip_embedding], dim=1) + txt = torch.cat((siglip_embedding, txt), dim=1) + siglip_embedding_ids = torch.zeros( + siglip_embedding.shape[0], siglip_embedding.shape[1], 3 + ).to(txt_ids.device) + txt_ids = torch.cat((siglip_embedding_ids, txt_ids), dim=1) + + ids = torch.cat((txt_ids, img_ids), dim=1) + + # concat ref_img/img + img_end = img.shape[1] + if ref_img is not None: + if isinstance(ref_img, tuple) or isinstance(ref_img, list): + img_in = [img] + [self.img_in(ref) for ref in ref_img] + img_ids = [ids] + [ref_ids for ref_ids in ref_img_ids] + img = torch.cat(img_in, dim=1) + ids = torch.cat(img_ids, dim=1) + else: + img = torch.cat((img, self.img_in(ref_img)), dim=1) + ids = torch.cat((ids, ref_img_ids), dim=1) + pe = self.pe_embedder(ids) + + for index_block, block in enumerate(self.double_blocks): + if self.training and self.gradient_checkpointing: + img, txt = torch.utils.checkpoint.checkpoint( + block, + img=img, + txt=txt, + vec=vec, + pe=pe, + use_reentrant=False, + ) + else: + img, txt = block(img=img, txt=txt, vec=vec, pe=pe) + + img = torch.cat((txt, img), 1) + for block in self.single_blocks: + if self.training and self.gradient_checkpointing: + img = torch.utils.checkpoint.checkpoint( + block, img, vec=vec, pe=pe, use_reentrant=False + ) + else: + img = block(img, vec=vec, pe=pe) + img = img[:, txt.shape[1] :, ...] + # index img + img = img[:, :img_end, ...] 
+ + img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels) + return img diff --git a/uso/flux/modules/__pycache__/autoencoder.cpython-311.pyc b/uso/flux/modules/__pycache__/autoencoder.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d858acdb8edb03ae0e3ea8cc271feea557f51b1f Binary files /dev/null and b/uso/flux/modules/__pycache__/autoencoder.cpython-311.pyc differ diff --git a/uso/flux/modules/__pycache__/conditioner.cpython-311.pyc b/uso/flux/modules/__pycache__/conditioner.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a7334c90551d0028a5c6e0404b4f64715b12a357 Binary files /dev/null and b/uso/flux/modules/__pycache__/conditioner.cpython-311.pyc differ diff --git a/uso/flux/modules/__pycache__/layers.cpython-311.pyc b/uso/flux/modules/__pycache__/layers.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..89cc4900e468c9e2e16a625e80cfa53e0d1c5f3d Binary files /dev/null and b/uso/flux/modules/__pycache__/layers.cpython-311.pyc differ diff --git a/uso/flux/modules/autoencoder.py b/uso/flux/modules/autoencoder.py new file mode 100644 index 0000000000000000000000000000000000000000..2543bdf4240e1db5b5dc958e148ac0cb12d9e9e3 --- /dev/null +++ b/uso/flux/modules/autoencoder.py @@ -0,0 +1,327 @@ +# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. All rights reserved. +# Copyright (c) 2024 Black Forest Labs and The XLabs-AI Team. All rights reserved. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
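# FLUX autoencoder: a convolutional Encoder/Decoder pair with mid-block attention
# and a DiagonalGaussian sampler. Latents are normalized as
# z = scale_factor * (sample - shift_factor) on encode, and the same affine
# transform is inverted before decoding.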
+ +from dataclasses import dataclass + +import torch +from einops import rearrange +from torch import Tensor, nn + + +@dataclass +class AutoEncoderParams: + resolution: int + in_channels: int + ch: int + out_ch: int + ch_mult: list[int] + num_res_blocks: int + z_channels: int + scale_factor: float + shift_factor: float + + +def swish(x: Tensor) -> Tensor: + return x * torch.sigmoid(x) + + +class AttnBlock(nn.Module): + def __init__(self, in_channels: int): + super().__init__() + self.in_channels = in_channels + + self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + + self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1) + self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1) + self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1) + self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1) + + def attention(self, h_: Tensor) -> Tensor: + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + b, c, h, w = q.shape + q = rearrange(q, "b c h w -> b 1 (h w) c").contiguous() + k = rearrange(k, "b c h w -> b 1 (h w) c").contiguous() + v = rearrange(v, "b c h w -> b 1 (h w) c").contiguous() + h_ = nn.functional.scaled_dot_product_attention(q, k, v) + + return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b) + + def forward(self, x: Tensor) -> Tensor: + return x + self.proj_out(self.attention(x)) + + +class ResnetBlock(nn.Module): + def __init__(self, in_channels: int, out_channels: int): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + + self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) + self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True) + self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) + if self.in_channels != self.out_channels: + self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, x): + h = x + h = self.norm1(h) + h = swish(h) + h = self.conv1(h) + + h = self.norm2(h) + h = swish(h) + h = self.conv2(h) + + if self.in_channels != self.out_channels: + x = self.nin_shortcut(x) + + return x + h + + +class Downsample(nn.Module): + def __init__(self, in_channels: int): + super().__init__() + # no asymmetric padding in torch conv, must do it ourselves + self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0) + + def forward(self, x: Tensor): + pad = (0, 1, 0, 1) + x = nn.functional.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + return x + + +class Upsample(nn.Module): + def __init__(self, in_channels: int): + super().__init__() + self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1) + + def forward(self, x: Tensor): + x = nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") + x = self.conv(x) + return x + + +class Encoder(nn.Module): + def __init__( + self, + resolution: int, + in_channels: int, + ch: int, + ch_mult: list[int], + num_res_blocks: int, + z_channels: int, + ): + super().__init__() + self.ch = ch + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + # downsampling + self.conv_in = nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1) + + 
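        # Downsampling trunk: one level per entry in ch_mult, each with
        # num_res_blocks ResnetBlocks; every level except the last ends in a
        # strided Downsample conv that halves the spatial resolution.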
curr_res = resolution + in_ch_mult = (1,) + tuple(ch_mult) + self.in_ch_mult = in_ch_mult + self.down = nn.ModuleList() + block_in = self.ch + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch * in_ch_mult[i_level] + block_out = ch * ch_mult[i_level] + for _ in range(self.num_res_blocks): + block.append(ResnetBlock(in_channels=block_in, out_channels=block_out)) + block_in = block_out + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions - 1: + down.downsample = Downsample(block_in) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in) + self.mid.attn_1 = AttnBlock(block_in) + self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in) + + # end + self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True) + self.conv_out = nn.Conv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1) + + def forward(self, x: Tensor) -> Tensor: + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1]) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions - 1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h) + h = self.mid.attn_1(h) + h = self.mid.block_2(h) + # end + h = self.norm_out(h) + h = swish(h) + h = self.conv_out(h) + return h + + +class Decoder(nn.Module): + def __init__( + self, + ch: int, + out_ch: int, + ch_mult: list[int], + num_res_blocks: int, + in_channels: int, + resolution: int, + z_channels: int, + ): + super().__init__() + self.ch = ch + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + self.ffactor = 2 ** (self.num_resolutions - 1) + + # compute in_ch_mult, block_in and curr_res at lowest res + block_in = ch * ch_mult[self.num_resolutions - 1] + curr_res = resolution // 2 ** (self.num_resolutions - 1) + self.z_shape = (1, z_channels, curr_res, curr_res) + + # z to block_in + self.conv_in = nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in) + self.mid.attn_1 = AttnBlock(block_in) + self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch * ch_mult[i_level] + for _ in range(self.num_res_blocks + 1): + block.append(ResnetBlock(in_channels=block_in, out_channels=block_out)) + block_in = block_out + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True) + self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1) + + def forward(self, z: Tensor) -> Tensor: + # z to block_in + h = self.conv_in(z) + + # middle + h = self.mid.block_1(h) + h = self.mid.attn_1(h) + h = self.mid.block_2(h) + + # upsampling + for 
i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.up[i_level].block[i_block](h) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + h = self.norm_out(h) + h = swish(h) + h = self.conv_out(h) + return h + + +class DiagonalGaussian(nn.Module): + def __init__(self, sample: bool = True, chunk_dim: int = 1): + super().__init__() + self.sample = sample + self.chunk_dim = chunk_dim + + def forward(self, z: Tensor) -> Tensor: + mean, logvar = torch.chunk(z, 2, dim=self.chunk_dim) + if self.sample: + std = torch.exp(0.5 * logvar) + return mean + std * torch.randn_like(mean) + else: + return mean + + +class AutoEncoder(nn.Module): + def __init__(self, params: AutoEncoderParams): + super().__init__() + self.encoder = Encoder( + resolution=params.resolution, + in_channels=params.in_channels, + ch=params.ch, + ch_mult=params.ch_mult, + num_res_blocks=params.num_res_blocks, + z_channels=params.z_channels, + ) + self.decoder = Decoder( + resolution=params.resolution, + in_channels=params.in_channels, + ch=params.ch, + out_ch=params.out_ch, + ch_mult=params.ch_mult, + num_res_blocks=params.num_res_blocks, + z_channels=params.z_channels, + ) + self.reg = DiagonalGaussian() + + self.scale_factor = params.scale_factor + self.shift_factor = params.shift_factor + + def encode(self, x: Tensor) -> Tensor: + z = self.reg(self.encoder(x)) + z = self.scale_factor * (z - self.shift_factor) + return z + + def decode(self, z: Tensor) -> Tensor: + z = z / self.scale_factor + self.shift_factor + return self.decoder(z) + + def forward(self, x: Tensor) -> Tensor: + return self.decode(self.encode(x)) diff --git a/uso/flux/modules/conditioner.py b/uso/flux/modules/conditioner.py new file mode 100644 index 0000000000000000000000000000000000000000..047950827b18f3577c7c43247392c6e0d9295f1f --- /dev/null +++ b/uso/flux/modules/conditioner.py @@ -0,0 +1,53 @@ +# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. All rights reserved. +# Copyright (c) 2024 Black Forest Labs and The XLabs-AI Team. All rights reserved. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
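# HFEmbedder wraps either a CLIP text encoder (returning pooler_output) or a T5
# encoder (returning last_hidden_state). Prompts are tokenized to a fixed
# max_length with padding and truncation, and the encoder is kept frozen in
# eval mode.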
+ +from torch import Tensor, nn +from transformers import (CLIPTextModel, CLIPTokenizer, T5EncoderModel, + T5Tokenizer) + + +class HFEmbedder(nn.Module): + def __init__(self, version: str, max_length: int, **hf_kwargs): + super().__init__() + self.is_clip = "clip" in version.lower() + self.max_length = max_length + self.output_key = "pooler_output" if self.is_clip else "last_hidden_state" + + if self.is_clip: + self.tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(version, max_length=max_length) + self.hf_module: CLIPTextModel = CLIPTextModel.from_pretrained(version, **hf_kwargs) + else: + self.tokenizer: T5Tokenizer = T5Tokenizer.from_pretrained(version, max_length=max_length) + self.hf_module: T5EncoderModel = T5EncoderModel.from_pretrained(version, **hf_kwargs) + + self.hf_module = self.hf_module.eval().requires_grad_(False) + + def forward(self, text: list[str]) -> Tensor: + batch_encoding = self.tokenizer( + text, + truncation=True, + max_length=self.max_length, + return_length=False, + return_overflowing_tokens=False, + padding="max_length", + return_tensors="pt", + ) + + outputs = self.hf_module( + input_ids=batch_encoding["input_ids"].to(self.hf_module.device), + attention_mask=None, + output_hidden_states=False, + ) + return outputs[self.output_key] diff --git a/uso/flux/modules/layers.py b/uso/flux/modules/layers.py new file mode 100644 index 0000000000000000000000000000000000000000..3bb8139ab78a88863ae822bac111815caf70c7e8 --- /dev/null +++ b/uso/flux/modules/layers.py @@ -0,0 +1,631 @@ +# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. All rights reserved. +# Copyright (c) 2024 Black Forest Labs and The XLabs-AI Team. All rights reserved. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from dataclasses import dataclass + +import torch +from einops import rearrange, repeat +from torch import Tensor, nn + +from ..math import attention, rope + + +class EmbedND(nn.Module): + def __init__(self, dim: int, theta: int, axes_dim: list[int]): + super().__init__() + self.dim = dim + self.theta = theta + self.axes_dim = axes_dim + + def forward(self, ids: Tensor) -> Tensor: + n_axes = ids.shape[-1] + emb = torch.cat( + [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)], + dim=-3, + ) + + return emb.unsqueeze(1) + + +def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0): + """ + Create sinusoidal timestep embeddings. + :param t: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an (N, D) Tensor of positional embeddings. 
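    Example: `timestep_embedding(torch.tensor([0.25, 0.5]), dim=256)` returns a
    (2, 256) tensor; `t` is scaled by `time_factor` before the cosine and sine
    features are concatenated along the last dimension.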
+ """ + t = time_factor * t + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) + * torch.arange(start=0, end=half, dtype=torch.float32) + / half + ).to(t.device) + + args = t[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + if torch.is_floating_point(t): + embedding = embedding.to(t) + return embedding + + +class MLPEmbedder(nn.Module): + def __init__(self, in_dim: int, hidden_dim: int): + super().__init__() + self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True) + self.silu = nn.SiLU() + self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True) + + def forward(self, x: Tensor) -> Tensor: + return self.out_layer(self.silu(self.in_layer(x))) + + +class RMSNorm(torch.nn.Module): + def __init__(self, dim: int): + super().__init__() + self.scale = nn.Parameter(torch.ones(dim)) + + def forward(self, x: Tensor): + x_dtype = x.dtype + x = x.float() + rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6) + return ((x * rrms) * self.scale.float()).to(dtype=x_dtype) + + +class QKNorm(torch.nn.Module): + def __init__(self, dim: int): + super().__init__() + self.query_norm = RMSNorm(dim) + self.key_norm = RMSNorm(dim) + + def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]: + q = self.query_norm(q) + k = self.key_norm(k) + return q.to(v), k.to(v) + + +class LoRALinearLayer(nn.Module): + def __init__( + self, + in_features, + out_features, + rank=4, + network_alpha=None, + device=None, + dtype=None, + ): + super().__init__() + + self.down = nn.Linear(in_features, rank, bias=False, device=device, dtype=dtype) + self.up = nn.Linear(rank, out_features, bias=False, device=device, dtype=dtype) + # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script. 
+ # See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning + self.network_alpha = network_alpha + self.rank = rank + + nn.init.normal_(self.down.weight, std=1 / rank) + nn.init.zeros_(self.up.weight) + + def forward(self, hidden_states): + orig_dtype = hidden_states.dtype + dtype = self.down.weight.dtype + + down_hidden_states = self.down(hidden_states.to(dtype)) + up_hidden_states = self.up(down_hidden_states) + + if self.network_alpha is not None: + up_hidden_states *= self.network_alpha / self.rank + + return up_hidden_states.to(orig_dtype) + + +class FLuxSelfAttnProcessor: + def __call__(self, attn, x, pe, **attention_kwargs): + qkv = attn.qkv(x) + q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads) + q, k = attn.norm(q, k, v) + x = attention(q, k, v, pe=pe) + x = attn.proj(x) + return x + + +class LoraFluxAttnProcessor(nn.Module): + + def __init__(self, dim: int, rank=4, network_alpha=None, lora_weight=1): + super().__init__() + self.qkv_lora = LoRALinearLayer(dim, dim * 3, rank, network_alpha) + self.proj_lora = LoRALinearLayer(dim, dim, rank, network_alpha) + self.lora_weight = lora_weight + + def __call__(self, attn, x, pe, **attention_kwargs): + qkv = attn.qkv(x) + self.qkv_lora(x) * self.lora_weight + q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads) + q, k = attn.norm(q, k, v) + x = attention(q, k, v, pe=pe) + x = attn.proj(x) + self.proj_lora(x) * self.lora_weight + return x + + +class SelfAttention(nn.Module): + def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.norm = QKNorm(head_dim) + self.proj = nn.Linear(dim, dim) + + def forward(): + pass + + +@dataclass +class ModulationOut: + shift: Tensor + scale: Tensor + gate: Tensor + + +class Modulation(nn.Module): + def __init__(self, dim: int, double: bool): + super().__init__() + self.is_double = double + self.multiplier = 6 if double else 3 + self.lin = nn.Linear(dim, self.multiplier * dim, bias=True) + + def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut | None]: + out = self.lin(nn.functional.silu(vec))[:, None, :].chunk( + self.multiplier, dim=-1 + ) + + return ( + ModulationOut(*out[:3]), + ModulationOut(*out[3:]) if self.is_double else None, + ) + + +class DoubleStreamBlockLoraProcessor(nn.Module): + def __init__(self, dim: int, rank=4, network_alpha=None, lora_weight=1): + super().__init__() + self.qkv_lora1 = LoRALinearLayer(dim, dim * 3, rank, network_alpha) + self.proj_lora1 = LoRALinearLayer(dim, dim, rank, network_alpha) + self.qkv_lora2 = LoRALinearLayer(dim, dim * 3, rank, network_alpha) + self.proj_lora2 = LoRALinearLayer(dim, dim, rank, network_alpha) + self.lora_weight = lora_weight + + def forward(self, attn, img, txt, vec, pe, **attention_kwargs): + img_mod1, img_mod2 = attn.img_mod(vec) + txt_mod1, txt_mod2 = attn.txt_mod(vec) + + # prepare image for attention + img_modulated = attn.img_norm1(img) + img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift + img_qkv = ( + attn.img_attn.qkv(img_modulated) + + self.qkv_lora1(img_modulated) * self.lora_weight + ) + img_q, img_k, img_v = rearrange( + img_qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads + ) + img_q, img_k = attn.img_attn.norm(img_q, img_k, img_v) + + # prepare txt for attention + txt_modulated = attn.txt_norm1(txt) + txt_modulated = (1 + 
txt_mod1.scale) * txt_modulated + txt_mod1.shift + txt_qkv = ( + attn.txt_attn.qkv(txt_modulated) + + self.qkv_lora2(txt_modulated) * self.lora_weight + ) + txt_q, txt_k, txt_v = rearrange( + txt_qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads + ) + txt_q, txt_k = attn.txt_attn.norm(txt_q, txt_k, txt_v) + + # run actual attention + q = torch.cat((txt_q, img_q), dim=2) + k = torch.cat((txt_k, img_k), dim=2) + v = torch.cat((txt_v, img_v), dim=2) + + attn1 = attention(q, k, v, pe=pe) + txt_attn, img_attn = attn1[:, : txt.shape[1]], attn1[:, txt.shape[1] :] + + # calculate the img bloks + img = img + img_mod1.gate * ( + attn.img_attn.proj(img_attn) + self.proj_lora1(img_attn) * self.lora_weight + ) + img = img + img_mod2.gate * attn.img_mlp( + (1 + img_mod2.scale) * attn.img_norm2(img) + img_mod2.shift + ) + + # calculate the txt bloks + txt = txt + txt_mod1.gate * ( + attn.txt_attn.proj(txt_attn) + self.proj_lora2(txt_attn) * self.lora_weight + ) + txt = txt + txt_mod2.gate * attn.txt_mlp( + (1 + txt_mod2.scale) * attn.txt_norm2(txt) + txt_mod2.shift + ) + return img, txt + + +class DoubleStreamBlockProcessor: + def __call__(self, attn, img, txt, vec, pe, **attention_kwargs): + img_mod1, img_mod2 = attn.img_mod(vec) + txt_mod1, txt_mod2 = attn.txt_mod(vec) + + # prepare image for attention + img_modulated = attn.img_norm1(img) + img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift + img_qkv = attn.img_attn.qkv(img_modulated) + img_q, img_k, img_v = rearrange( + img_qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads, D=attn.head_dim + ) + img_q, img_k = attn.img_attn.norm(img_q, img_k, img_v) + + # prepare txt for attention + txt_modulated = attn.txt_norm1(txt) + txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift + txt_qkv = attn.txt_attn.qkv(txt_modulated) + txt_q, txt_k, txt_v = rearrange( + txt_qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads, D=attn.head_dim + ) + txt_q, txt_k = attn.txt_attn.norm(txt_q, txt_k, txt_v) + + # run actual attention + q = torch.cat((txt_q, img_q), dim=2) + k = torch.cat((txt_k, img_k), dim=2) + v = torch.cat((txt_v, img_v), dim=2) + + attn1 = attention(q, k, v, pe=pe) + txt_attn, img_attn = attn1[:, : txt.shape[1]], attn1[:, txt.shape[1] :] + + # calculate the img bloks + img = img + img_mod1.gate * attn.img_attn.proj(img_attn) + img = img + img_mod2.gate * attn.img_mlp( + (1 + img_mod2.scale) * attn.img_norm2(img) + img_mod2.shift + ) + + # calculate the txt bloks + txt = txt + txt_mod1.gate * attn.txt_attn.proj(txt_attn) + txt = txt + txt_mod2.gate * attn.txt_mlp( + (1 + txt_mod2.scale) * attn.txt_norm2(txt) + txt_mod2.shift + ) + return img, txt + + +class DoubleStreamBlock(nn.Module): + def __init__( + self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False + ): + super().__init__() + mlp_hidden_dim = int(hidden_size * mlp_ratio) + self.num_heads = num_heads + self.hidden_size = hidden_size + self.head_dim = hidden_size // num_heads + + self.img_mod = Modulation(hidden_size, double=True) + self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.img_attn = SelfAttention( + dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias + ) + + self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.img_mlp = nn.Sequential( + nn.Linear(hidden_size, mlp_hidden_dim, bias=True), + nn.GELU(approximate="tanh"), + nn.Linear(mlp_hidden_dim, hidden_size, bias=True), + ) + + self.txt_mod = Modulation(hidden_size, double=True) + 
self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.txt_attn = SelfAttention( + dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias + ) + + self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.txt_mlp = nn.Sequential( + nn.Linear(hidden_size, mlp_hidden_dim, bias=True), + nn.GELU(approximate="tanh"), + nn.Linear(mlp_hidden_dim, hidden_size, bias=True), + ) + processor = DoubleStreamBlockProcessor() + self.set_processor(processor) + + def set_processor(self, processor) -> None: + self.processor = processor + + def get_processor(self): + return self.processor + + def forward( + self, + img: Tensor, + txt: Tensor, + vec: Tensor, + pe: Tensor, + image_proj: Tensor = None, + ip_scale: float = 1.0, + ) -> tuple[Tensor, Tensor]: + if image_proj is None: + return self.processor(self, img, txt, vec, pe) + else: + return self.processor(self, img, txt, vec, pe, image_proj, ip_scale) + + +class SingleStreamBlockLoraProcessor(nn.Module): + def __init__( + self, dim: int, rank: int = 4, network_alpha=None, lora_weight: float = 1 + ): + super().__init__() + self.qkv_lora = LoRALinearLayer(dim, dim * 3, rank, network_alpha) + self.proj_lora = LoRALinearLayer(15360, dim, rank, network_alpha) + self.lora_weight = lora_weight + + def forward(self, attn: nn.Module, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor: + + mod, _ = attn.modulation(vec) + x_mod = (1 + mod.scale) * attn.pre_norm(x) + mod.shift + qkv, mlp = torch.split( + attn.linear1(x_mod), [3 * attn.hidden_size, attn.mlp_hidden_dim], dim=-1 + ) + qkv = qkv + self.qkv_lora(x_mod) * self.lora_weight + + q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads) + q, k = attn.norm(q, k, v) + + # compute attention + attn_1 = attention(q, k, v, pe=pe) + + # compute activation in mlp stream, cat again and run second linear layer + output = attn.linear2(torch.cat((attn_1, attn.mlp_act(mlp)), 2)) + output = ( + output + + self.proj_lora(torch.cat((attn_1, attn.mlp_act(mlp)), 2)) + * self.lora_weight + ) + output = x + mod.gate * output + return output + + +class SingleStreamBlockProcessor: + def __call__( + self, attn: nn.Module, x: Tensor, vec: Tensor, pe: Tensor, **attention_kwargs + ) -> Tensor: + + mod, _ = attn.modulation(vec) + x_mod = (1 + mod.scale) * attn.pre_norm(x) + mod.shift + qkv, mlp = torch.split( + attn.linear1(x_mod), [3 * attn.hidden_size, attn.mlp_hidden_dim], dim=-1 + ) + + q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads) + q, k = attn.norm(q, k, v) + + # compute attention + attn_1 = attention(q, k, v, pe=pe) + + # compute activation in mlp stream, cat again and run second linear layer + output = attn.linear2(torch.cat((attn_1, attn.mlp_act(mlp)), 2)) + output = x + mod.gate * output + return output + + +class SingleStreamBlock(nn.Module): + """ + A DiT block with parallel linear layers as described in + https://arxiv.org/abs/2302.05442 and adapted modulation interface. 
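+
+    The attention and MLP branches share fused projections: linear1 emits the
+    qkv and MLP-in activations in a single matmul, and linear2 fuses the
+    attention output projection with the MLP-out projection (see
+    SingleStreamBlockProcessor above).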
+ """ + + def __init__( + self, + hidden_size: int, + num_heads: int, + mlp_ratio: float = 4.0, + qk_scale: float | None = None, + ): + super().__init__() + self.hidden_dim = hidden_size + self.num_heads = num_heads + self.head_dim = hidden_size // num_heads + self.scale = qk_scale or self.head_dim**-0.5 + + self.mlp_hidden_dim = int(hidden_size * mlp_ratio) + # qkv and mlp_in + self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim) + # proj and mlp_out + self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size) + + self.norm = QKNorm(self.head_dim) + + self.hidden_size = hidden_size + self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + + self.mlp_act = nn.GELU(approximate="tanh") + self.modulation = Modulation(hidden_size, double=False) + + processor = SingleStreamBlockProcessor() + self.set_processor(processor) + + def set_processor(self, processor) -> None: + self.processor = processor + + def get_processor(self): + return self.processor + + def forward( + self, + x: Tensor, + vec: Tensor, + pe: Tensor, + image_proj: Tensor | None = None, + ip_scale: float = 1.0, + ) -> Tensor: + if image_proj is None: + return self.processor(self, x, vec, pe) + else: + return self.processor(self, x, vec, pe, image_proj, ip_scale) + + +class LastLayer(nn.Module): + def __init__(self, hidden_size: int, patch_size: int, out_channels: int): + super().__init__() + self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.linear = nn.Linear( + hidden_size, patch_size * patch_size * out_channels, bias=True + ) + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True) + ) + + def forward(self, x: Tensor, vec: Tensor) -> Tensor: + shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1) + x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :] + x = self.linear(x) + return x + + +class SigLIPMultiFeatProjModel(torch.nn.Module): + """ + SigLIP Multi-Feature Projection Model for processing style features from different layers + and projecting them into a unified hidden space. 
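+
+    Features from three SigLIP hidden layers (-2, -11 and -20) are each squeezed
+    from siglip_token_nums to style_token_nums tokens, projected to hidden_size,
+    and concatenated along the token dimension, so the default configuration
+    yields 3 * 256 = 768 style tokens of width 3072.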
+ + Args: + siglip_token_nums (int): Number of SigLIP tokens, default 257 + style_token_nums (int): Number of style tokens, default 256 + siglip_token_dims (int): Dimension of SigLIP tokens, default 1536 + hidden_size (int): Hidden layer size, default 3072 + context_layer_norm (bool): Whether to use context layer normalization, default False + """ + + def __init__( + self, + siglip_token_nums: int = 257, + style_token_nums: int = 256, + siglip_token_dims: int = 1536, + hidden_size: int = 3072, + context_layer_norm: bool = False, + ): + super().__init__() + + # High-level feature processing (layer -2) + self.high_embedding_linear = nn.Sequential( + nn.Linear(siglip_token_nums, style_token_nums), + nn.SiLU() + ) + self.high_layer_norm = ( + nn.LayerNorm(siglip_token_dims) if context_layer_norm else nn.Identity() + ) + self.high_projection = nn.Linear(siglip_token_dims, hidden_size, bias=True) + + # Mid-level feature processing (layer -11) + self.mid_embedding_linear = nn.Sequential( + nn.Linear(siglip_token_nums, style_token_nums), + nn.SiLU() + ) + self.mid_layer_norm = ( + nn.LayerNorm(siglip_token_dims) if context_layer_norm else nn.Identity() + ) + self.mid_projection = nn.Linear(siglip_token_dims, hidden_size, bias=True) + + # Low-level feature processing (layer -20) + self.low_embedding_linear = nn.Sequential( + nn.Linear(siglip_token_nums, style_token_nums), + nn.SiLU() + ) + self.low_layer_norm = ( + nn.LayerNorm(siglip_token_dims) if context_layer_norm else nn.Identity() + ) + self.low_projection = nn.Linear(siglip_token_dims, hidden_size, bias=True) + + def forward(self, siglip_outputs): + """ + Forward pass function + + Args: + siglip_outputs: Output from SigLIP model, containing hidden_states + + Returns: + torch.Tensor: Concatenated multi-layer features with shape [bs, 3*style_token_nums, hidden_size] + """ + dtype = next(self.high_embedding_linear.parameters()).dtype + + # Process high-level features (layer -2) + high_embedding = self._process_layer_features( + siglip_outputs.hidden_states[-2], + self.high_embedding_linear, + self.high_layer_norm, + self.high_projection, + dtype + ) + + # Process mid-level features (layer -11) + mid_embedding = self._process_layer_features( + siglip_outputs.hidden_states[-11], + self.mid_embedding_linear, + self.mid_layer_norm, + self.mid_projection, + dtype + ) + + # Process low-level features (layer -20) + low_embedding = self._process_layer_features( + siglip_outputs.hidden_states[-20], + self.low_embedding_linear, + self.low_layer_norm, + self.low_projection, + dtype + ) + + # Concatenate features from all layers + return torch.cat((high_embedding, mid_embedding, low_embedding), dim=1) + + def _process_layer_features( + self, + hidden_states: torch.Tensor, + embedding_linear: nn.Module, + layer_norm: nn.Module, + projection: nn.Module, + dtype: torch.dtype + ) -> torch.Tensor: + """ + Helper function to process features from a single layer + + Args: + hidden_states: Input hidden states [bs, seq_len, dim] + embedding_linear: Embedding linear layer + layer_norm: Layer normalization + projection: Projection layer + dtype: Target data type + + Returns: + torch.Tensor: Processed features [bs, style_token_nums, hidden_size] + """ + # Transform dimensions: [bs, seq_len, dim] -> [bs, dim, seq_len] -> [bs, dim, style_token_nums] -> [bs, style_token_nums, dim] + embedding = embedding_linear( + hidden_states.to(dtype).transpose(1, 2) + ).transpose(1, 2) + + # Apply layer normalization + embedding = layer_norm(embedding) + + # Project to target hidden 
space
+        embedding = projection(embedding)
+
+        return embedding
diff --git a/uso/flux/pipeline.py b/uso/flux/pipeline.py
new file mode 100644
index 0000000000000000000000000000000000000000..15be80cb624cbc9403e520d1c3fed798d995602c
--- /dev/null
+++ b/uso/flux/pipeline.py
@@ -0,0 +1,390 @@
+# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. All rights reserved.
+# Copyright (c) 2024 Black Forest Labs and The XLabs-AI Team. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import math
+from typing import Literal, Optional
+from torch import Tensor
+
+import torch
+from einops import rearrange
+from PIL import ExifTags, Image
+import torchvision.transforms.functional as TVF
+
+from uso.flux.modules.layers import (
+    DoubleStreamBlockLoraProcessor,
+    DoubleStreamBlockProcessor,
+    SingleStreamBlockLoraProcessor,
+    SingleStreamBlockProcessor,
+)
+from uso.flux.sampling import denoise, get_noise, get_schedule, prepare_multi_ip, unpack
+from uso.flux.util import (
+    get_lora_rank,
+    load_ae,
+    load_checkpoint,
+    load_clip,
+    load_flow_model,
+    load_flow_model_only_lora,
+    load_t5,
+)
+
+
+def find_nearest_scale(image_h, image_w, predefined_scales):
+    """
+    Find the predefined scale closest to the given image height and width.
+
+    :param image_h: image height
+    :param image_w: image width
+    :param predefined_scales: list of predefined scales [(h1, w1), (h2, w2), ...]
+    :return: the nearest predefined scale (h, w)
+    """
+    # Compute the aspect ratio of the input image
+    image_ratio = image_h / image_w
+
+    # Track the smallest difference and the nearest scale found so far
+    min_diff = float("inf")
+    nearest_scale = None
+
+    # Iterate over all predefined scales and pick the one whose aspect ratio
+    # is closest to that of the input image
+    for scale_h, scale_w in predefined_scales:
+        predefined_ratio = scale_h / scale_w
+        diff = abs(predefined_ratio - image_ratio)
+
+        if diff < min_diff:
+            min_diff = diff
+            nearest_scale = (scale_h, scale_w)
+
+    return nearest_scale
+
+
+def preprocess_ref(raw_image: Image.Image, long_size: int = 512, scale_ratio: int = 1):
+    # Get the width and height of the original image
+    image_w, image_h = raw_image.size
+    if image_w == image_h and image_w == 16:
+        return raw_image
+
+    # Compute the new long and short sides
+    if image_w >= image_h:
+        new_w = long_size
+        new_h = int((long_size / image_w) * image_h)
+    else:
+        new_h = long_size
+        new_w = int((long_size / image_h) * image_w)
+
+    # Resize proportionally to the new width and height
+    raw_image = raw_image.resize((new_w, new_h), resample=Image.LANCZOS)
+
+    # Round down to a multiple of (16 * scale_ratio) so the canny img can be scaled
+    scale_ratio = int(scale_ratio)
+    target_w = new_w // (16 * scale_ratio) * (16 * scale_ratio)
+    target_h = new_h // (16 * scale_ratio) * (16 * scale_ratio)
+
+    # Compute the crop origin for a center crop
+    left = (new_w - target_w) // 2
+    top = (new_h - target_h) // 2
+    right = left + target_w
+    bottom = top + target_h
+
+    # Apply the center crop
+    raw_image = raw_image.crop((left, top, right, bottom))
+
+    # Convert to RGB mode
+    raw_image = raw_image.convert("RGB")
+    return raw_image
+
+
+def resize_and_centercrop_image(image, target_height_ref1, target_width_ref1):
+    target_height_ref1 = int(target_height_ref1 // 64 * 64)
+    target_width_ref1 = int(target_width_ref1 // 64 * 64)
+    h, w = image.shape[-2:]
+    if h < target_height_ref1 or w < target_width_ref1:
+        # Compute the aspect ratio
+        aspect_ratio = w / h
+        if h < target_height_ref1:
+            new_h = target_height_ref1
+            new_w = new_h * aspect_ratio
+            if new_w < target_width_ref1:
+                new_w = target_width_ref1
+                new_h = new_w / aspect_ratio
+        else:
+            new_w = target_width_ref1
+            new_h = new_w / aspect_ratio
+            if new_h < target_height_ref1:
+                new_h = target_height_ref1
+                new_w = new_h * aspect_ratio
+    else:
+        aspect_ratio = w / h
+        tgt_aspect_ratio = target_width_ref1 / target_height_ref1
+        if aspect_ratio > tgt_aspect_ratio:
+            new_h = target_height_ref1
+            new_w = new_h * aspect_ratio
+        else:
+            new_w = target_width_ref1
+            new_h = new_w / aspect_ratio
+    # Resize the image with TVF.resize
+    image = TVF.resize(image, (math.ceil(new_h), math.ceil(new_w)))
+    # Compute the center-crop parameters
+    top = (image.shape[-2] - target_height_ref1) // 2
+    left = (image.shape[-1] - target_width_ref1) // 2
+    # Center-crop with TVF.crop
+    image = TVF.crop(image, top, left, target_height_ref1, target_width_ref1)
+    return image
+
+
+class USOPipeline:
+    def __init__(
+        self,
+        model_type: str,
+        device: torch.device,
+        offload: bool = False,
+        only_lora: bool = False,
+        lora_rank: int = 16,
+        hf_download: bool = True,
+    ):
+        self.device = device
+        self.offload = offload
+        self.model_type = model_type
+
+        self.clip = load_clip(self.device)
+        self.t5 = load_t5(self.device, max_length=512)
+        self.ae = load_ae(model_type, device="cpu" if offload else self.device)
+        self.use_fp8 = "fp8" in model_type
+        if only_lora:
+            self.model = load_flow_model_only_lora(
+                model_type,
+                device="cpu" if offload else self.device,
+                lora_rank=lora_rank,
+                use_fp8=self.use_fp8,
+                hf_download=hf_download,
+            )
+        else:
+            self.model = load_flow_model(
+                model_type, device="cpu" if offload else self.device
+            )
+
+    def load_ckpt(self, ckpt_path):
+        if ckpt_path is not None:
+            from safetensors.torch import load_file as load_sft
+
+            print("Loading checkpoint to replace old keys")
+            # load_sft
doesn't support torch.device + if ckpt_path.endswith("safetensors"): + sd = load_sft(ckpt_path, device="cpu") + missing, unexpected = self.model.load_state_dict( + sd, strict=False, assign=True + ) + else: + dit_state = torch.load(ckpt_path, map_location="cpu") + sd = {} + for k in dit_state.keys(): + sd[k.replace("module.", "")] = dit_state[k] + missing, unexpected = self.model.load_state_dict( + sd, strict=False, assign=True + ) + self.model.to(str(self.device)) + print(f"missing keys: {missing}\n\n\n\n\nunexpected keys: {unexpected}") + + def set_lora( + self, + local_path: str = None, + repo_id: str = None, + name: str = None, + lora_weight: int = 0.7, + ): + checkpoint = load_checkpoint(local_path, repo_id, name) + self.update_model_with_lora(checkpoint, lora_weight) + + def set_lora_from_collection( + self, lora_type: str = "realism", lora_weight: int = 0.7 + ): + checkpoint = load_checkpoint( + None, self.hf_lora_collection, self.lora_types_to_names[lora_type] + ) + self.update_model_with_lora(checkpoint, lora_weight) + + def update_model_with_lora(self, checkpoint, lora_weight): + rank = get_lora_rank(checkpoint) + lora_attn_procs = {} + + for name, _ in self.model.attn_processors.items(): + lora_state_dict = {} + for k in checkpoint.keys(): + if name in k: + lora_state_dict[k[len(name) + 1 :]] = checkpoint[k] * lora_weight + + if len(lora_state_dict): + if name.startswith("single_blocks"): + lora_attn_procs[name] = SingleStreamBlockLoraProcessor( + dim=3072, rank=rank + ) + else: + lora_attn_procs[name] = DoubleStreamBlockLoraProcessor( + dim=3072, rank=rank + ) + lora_attn_procs[name].load_state_dict(lora_state_dict) + lora_attn_procs[name].to(self.device) + else: + if name.startswith("single_blocks"): + lora_attn_procs[name] = SingleStreamBlockProcessor() + else: + lora_attn_procs[name] = DoubleStreamBlockProcessor() + + self.model.set_attn_processor(lora_attn_procs) + + def __call__( + self, + prompt: str, + width: int = 512, + height: int = 512, + guidance: float = 4, + num_steps: int = 50, + seed: int = 123456789, + **kwargs, + ): + width = 16 * (width // 16) + height = 16 * (height // 16) + + device_type = self.device if isinstance(self.device, str) else self.device.type + with torch.autocast( + enabled=self.use_fp8, device_type=device_type, dtype=torch.bfloat16 + ): + return self.forward( + prompt, width, height, guidance, num_steps, seed, **kwargs + ) + + @torch.inference_mode() + def gradio_generate( + self, + prompt: str, + image_prompt1: Image.Image, + image_prompt2: Image.Image, + image_prompt3: Image.Image, + seed: int, + width: int = 1024, + height: int = 1024, + guidance: float = 4, + num_steps: int = 25, + keep_size: bool = False, + content_long_size: int = 512, + ): + ref_content_imgs = [image_prompt1] + ref_content_imgs = [img for img in ref_content_imgs if isinstance(img, Image.Image)] + ref_content_imgs = [preprocess_ref(img, content_long_size) for img in ref_content_imgs] + + ref_style_imgs = [image_prompt2, image_prompt3] + ref_style_imgs = [img for img in ref_style_imgs if isinstance(img, Image.Image)] + ref_style_imgs = [self.model.vision_encoder_processor(img, return_tensors="pt").to(self.device) for img in ref_style_imgs] + + seed = seed if seed != -1 else torch.randint(0, 10**8, (1,)).item() + + # whether keep input image size + if keep_size and len(ref_content_imgs)>0: + width, height = ref_content_imgs[0].size + width, height = int(width * (1024 / content_long_size)), int(height * (1024 / content_long_size)) + img = self( + prompt=prompt, + 
width=width, + height=height, + guidance=guidance, + num_steps=num_steps, + seed=seed, + ref_imgs=ref_content_imgs, + siglip_inputs=ref_style_imgs, + ) + + filename = f"output/gradio/{seed}_{prompt[:20]}.png" + os.makedirs(os.path.dirname(filename), exist_ok=True) + exif_data = Image.Exif() + exif_data[ExifTags.Base.Make] = "USO" + exif_data[ExifTags.Base.Model] = self.model_type + info = f"{prompt=}, {seed=}, {width=}, {height=}, {guidance=}, {num_steps=}" + exif_data[ExifTags.Base.ImageDescription] = info + img.save(filename, format="png", exif=exif_data) + return img, filename + + @torch.inference_mode + def forward( + self, + prompt: str, + width: int, + height: int, + guidance: float, + num_steps: int, + seed: int, + ref_imgs: list[Image.Image] | None = None, + pe: Literal["d", "h", "w", "o"] = "d", + siglip_inputs: list[Tensor] | None = None, + ): + x = get_noise( + 1, height, width, device=self.device, dtype=torch.bfloat16, seed=seed + ) + timesteps = get_schedule( + num_steps, + (width // 8) * (height // 8) // (16 * 16), + shift=True, + ) + if self.offload: + self.ae.encoder = self.ae.encoder.to(self.device) + x_1_refs = [ + self.ae.encode( + (TVF.to_tensor(ref_img) * 2.0 - 1.0) + .unsqueeze(0) + .to(self.device, torch.float32) + ).to(torch.bfloat16) + for ref_img in ref_imgs + ] + + if self.offload: + self.offload_model_to_cpu(self.ae.encoder) + self.t5, self.clip = self.t5.to(self.device), self.clip.to(self.device) + inp_cond = prepare_multi_ip( + t5=self.t5, + clip=self.clip, + img=x, + prompt=prompt, + ref_imgs=x_1_refs, + pe=pe, + ) + + if self.offload: + self.offload_model_to_cpu(self.t5, self.clip) + self.model = self.model.to(self.device) + + x = denoise( + self.model, + **inp_cond, + timesteps=timesteps, + guidance=guidance, + siglip_inputs=siglip_inputs, + ) + + if self.offload: + self.offload_model_to_cpu(self.model) + self.ae.decoder.to(x.device) + x = unpack(x.float(), height, width) + x = self.ae.decode(x) + self.offload_model_to_cpu(self.ae.decoder) + + x1 = x.clamp(-1, 1) + x1 = rearrange(x1[-1], "c h w -> h w c") + output_img = Image.fromarray((127.5 * (x1 + 1.0)).cpu().byte().numpy()) + return output_img + + def offload_model_to_cpu(self, *models): + if not self.offload: + return + for model in models: + model.cpu() + torch.cuda.empty_cache() diff --git a/uso/flux/sampling.py b/uso/flux/sampling.py new file mode 100644 index 0000000000000000000000000000000000000000..a4611fa8bef1c446d760a05a61f6d086dc5a8ad0 --- /dev/null +++ b/uso/flux/sampling.py @@ -0,0 +1,274 @@ +# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. All rights reserved. +# Copyright (c) 2024 Black Forest Labs and The XLabs-AI Team. All rights reserved. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
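+
+# Rough usage sketch of this module (it mirrors USOPipeline.forward in
+# uso/flux/pipeline.py); names such as `latent_refs`, `style_inputs`, `t5`,
+# `clip` and `model` below are illustrative placeholders, not part of this API:
+#
+#   x = get_noise(1, height, width, device=device, dtype=torch.bfloat16, seed=seed)
+#   timesteps = get_schedule(num_steps, (width // 8) * (height // 8) // (16 * 16), shift=True)
+#   inp = prepare_multi_ip(t5=t5, clip=clip, img=x, prompt=prompt, ref_imgs=latent_refs, pe="d")
+#   x = denoise(model, **inp, timesteps=timesteps, guidance=guidance, siglip_inputs=style_inputs)
+#   img = unpack(x.float(), height, width)  # latent image, to be decoded by the VAE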
+
+import math
+from typing import Literal
+
+import torch
+from einops import rearrange, repeat
+from torch import Tensor
+from tqdm import tqdm
+
+from .model import Flux
+from .modules.conditioner import HFEmbedder
+
+
+def get_noise(
+    num_samples: int,
+    height: int,
+    width: int,
+    device: torch.device,
+    dtype: torch.dtype,
+    seed: int,
+):
+    return torch.randn(
+        num_samples,
+        16,
+        # allow for packing
+        2 * math.ceil(height / 16),
+        2 * math.ceil(width / 16),
+        device=device,
+        dtype=dtype,
+        generator=torch.Generator(device=device).manual_seed(seed),
+    )
+
+
+def prepare(
+    t5: HFEmbedder,
+    clip: HFEmbedder,
+    img: Tensor,
+    prompt: str | list[str],
+    ref_img: None | Tensor = None,
+    pe: Literal["d", "h", "w", "o"] = "d",
+) -> dict[str, Tensor]:
+    assert pe in ["d", "h", "w", "o"]
+    bs, c, h, w = img.shape
+    if bs == 1 and not isinstance(prompt, str):
+        bs = len(prompt)
+
+    img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
+    if img.shape[0] == 1 and bs > 1:
+        img = repeat(img, "1 ... -> bs ...", bs=bs)
+
+    img_ids = torch.zeros(h // 2, w // 2, 3)
+    img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
+    img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
+    img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
+
+    if ref_img is not None:
+        _, _, ref_h, ref_w = ref_img.shape
+        ref_img = rearrange(
+            ref_img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2
+        )
+        if ref_img.shape[0] == 1 and bs > 1:
+            ref_img = repeat(ref_img, "1 ... -> bs ...", bs=bs)
+        ref_img_ids = torch.zeros(ref_h // 2, ref_w // 2, 3)
+        # offset the reference img ids along height/width by their respective maxima
+        h_offset = h // 2 if pe in {"d", "h"} else 0
+        w_offset = w // 2 if pe in {"d", "w"} else 0
+        ref_img_ids[..., 1] = (
+            ref_img_ids[..., 1] + torch.arange(ref_h // 2)[:, None] + h_offset
+        )
+        ref_img_ids[..., 2] = (
+            ref_img_ids[..., 2] + torch.arange(ref_w // 2)[None, :] + w_offset
+        )
+        ref_img_ids = repeat(ref_img_ids, "h w c -> b (h w) c", b=bs)
+
+    if isinstance(prompt, str):
+        prompt = [prompt]
+    txt = t5(prompt)
+    if txt.shape[0] == 1 and bs > 1:
+        txt = repeat(txt, "1 ... -> bs ...", bs=bs)
+    txt_ids = torch.zeros(bs, txt.shape[1], 3)
+
+    vec = clip(prompt)
+    if vec.shape[0] == 1 and bs > 1:
+        vec = repeat(vec, "1 ... -> bs ...", bs=bs)
+
+    if ref_img is not None:
+        return {
+            "img": img,
+            "img_ids": img_ids.to(img.device),
+            "ref_img": ref_img,
+            "ref_img_ids": ref_img_ids.to(img.device),
+            "txt": txt.to(img.device),
+            "txt_ids": txt_ids.to(img.device),
+            "vec": vec.to(img.device),
+        }
+    else:
+        return {
+            "img": img,
+            "img_ids": img_ids.to(img.device),
+            "txt": txt.to(img.device),
+            "txt_ids": txt_ids.to(img.device),
+            "vec": vec.to(img.device),
+        }
+
+
+def prepare_multi_ip(
+    t5: HFEmbedder,
+    clip: HFEmbedder,
+    img: Tensor,
+    prompt: str | list[str],
+    ref_imgs: list[Tensor] | None = None,
+    pe: Literal["d", "h", "w", "o"] = "d",
+) -> dict[str, Tensor]:
+    assert pe in ["d", "h", "w", "o"]
+    bs, c, h, w = img.shape
+    if bs == 1 and not isinstance(prompt, str):
+        bs = len(prompt)
+
+    # tgt img
+    img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
+    if img.shape[0] == 1 and bs > 1:
+        img = repeat(img, "1 ... -> bs ...", bs=bs)
+
+    img_ids = torch.zeros(h // 2, w // 2, 3)
+    img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
+    img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
+    img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
+
+    ref_img_ids = []
+    ref_imgs_list = []
+
+    pe_shift_w, pe_shift_h = w // 2, h // 2
+    for ref_img in ref_imgs:
+        _, _, ref_h1, ref_w1 = ref_img.shape
+        ref_img = rearrange(
+            ref_img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2
+        )
+        if ref_img.shape[0] == 1 and bs > 1:
+            ref_img = repeat(ref_img, "1 ... -> bs ...", bs=bs)
+        ref_img_ids1 = torch.zeros(ref_h1 // 2, ref_w1 // 2, 3)
+        # offset the reference img ids along height/width by the current maxima
+        h_offset = pe_shift_h if pe in {"d", "h"} else 0
+        w_offset = pe_shift_w if pe in {"d", "w"} else 0
+        ref_img_ids1[..., 1] = (
+            ref_img_ids1[..., 1] + torch.arange(ref_h1 // 2)[:, None] + h_offset
+        )
+        ref_img_ids1[..., 2] = (
+            ref_img_ids1[..., 2] + torch.arange(ref_w1 // 2)[None, :] + w_offset
+        )
+        ref_img_ids1 = repeat(ref_img_ids1, "h w c -> b (h w) c", b=bs)
+        ref_img_ids.append(ref_img_ids1)
+        ref_imgs_list.append(ref_img)
+
+        # update the pe shift
+        pe_shift_h += ref_h1 // 2
+        pe_shift_w += ref_w1 // 2
+
+    if isinstance(prompt, str):
+        prompt = [prompt]
+    txt = t5(prompt)
+    if txt.shape[0] == 1 and bs > 1:
+        txt = repeat(txt, "1 ... -> bs ...", bs=bs)
+    txt_ids = torch.zeros(bs, txt.shape[1], 3)
+
+    vec = clip(prompt)
+    if vec.shape[0] == 1 and bs > 1:
+        vec = repeat(vec, "1 ... -> bs ...", bs=bs)
+
+    return {
+        "img": img,
+        "img_ids": img_ids.to(img.device),
+        "ref_img": tuple(ref_imgs_list),
+        "ref_img_ids": [ref_img_id.to(img.device) for ref_img_id in ref_img_ids],
+        "txt": txt.to(img.device),
+        "txt_ids": txt_ids.to(img.device),
+        "vec": vec.to(img.device),
+    }
+
+
+def time_shift(mu: float, sigma: float, t: Tensor):
+    return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
+
+
+def get_lin_function(
+    x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15
+):
+    m = (y2 - y1) / (x2 - x1)
+    b = y1 - m * x1
+    return lambda x: m * x + b
+
+
+def get_schedule(
+    num_steps: int,
+    image_seq_len: int,
+    base_shift: float = 0.5,
+    max_shift: float = 1.15,
+    shift: bool = True,
+) -> list[float]:
+    # extra step for zero
+    timesteps = torch.linspace(1, 0, num_steps + 1)
+
+    # shifting the schedule to favor high timesteps for higher signal images
+    if shift:
+        # estimate mu based on linear interpolation between two points
+        mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)
+        timesteps = time_shift(mu, 1.0, timesteps)
+
+    return timesteps.tolist()
+
+
+def denoise(
+    model: Flux,
+    # model input
+    img: Tensor,
+    img_ids: Tensor,
+    txt: Tensor,
+    txt_ids: Tensor,
+    vec: Tensor,
+    # sampling parameters
+    timesteps: list[float],
+    guidance: float = 4.0,
+    ref_img: Tensor = None,
+    ref_img_ids: Tensor = None,
+    siglip_inputs: list[Tensor] | None = None,
+):
+    i = 0
+    guidance_vec = torch.full(
+        (img.shape[0],), guidance, device=img.device, dtype=img.dtype
+    )
+    for t_curr, t_prev in tqdm(
+        zip(timesteps[:-1], timesteps[1:]), total=len(timesteps) - 1
+    ):
+        # for t_curr, t_prev in zip(timesteps[:-1], timesteps[1:]):
+        t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
+        pred = model(
+            img=img,
+            img_ids=img_ids,
+            ref_img=ref_img,
+            ref_img_ids=ref_img_ids,
+            txt=txt,
+            txt_ids=txt_ids,
+            y=vec,
+            timesteps=t_vec,
+            guidance=guidance_vec,
+            siglip_inputs=siglip_inputs,
+        )
+        img = img + (t_prev - t_curr) * pred
+        i += 1
+    return img
+
+
+def unpack(x: Tensor, height: int,
width: int) -> Tensor: + return rearrange( + x, + "b (h w) (c ph pw) -> b c (h ph) (w pw)", + h=math.ceil(height / 16), + w=math.ceil(width / 16), + ph=2, + pw=2, + ) diff --git a/uso/flux/util.py b/uso/flux/util.py new file mode 100644 index 0000000000000000000000000000000000000000..38e40b64885b8d3ab2c466c5cd19368463caf09a --- /dev/null +++ b/uso/flux/util.py @@ -0,0 +1,511 @@ +# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. All rights reserved. +# Copyright (c) 2024 Black Forest Labs and The XLabs-AI Team. All rights reserved. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from dataclasses import dataclass + +import torch +import json +import numpy as np +from huggingface_hub import hf_hub_download +from safetensors import safe_open +from safetensors.torch import load_file as load_sft + +from .model import Flux, FluxParams +from .modules.autoencoder import AutoEncoder, AutoEncoderParams +from .modules.conditioner import HFEmbedder + +import re +from uso.flux.modules.layers import ( + DoubleStreamBlockLoraProcessor, + SingleStreamBlockLoraProcessor, +) + + +def load_model(ckpt, device="cpu"): + if ckpt.endswith("safetensors"): + from safetensors import safe_open + + pl_sd = {} + with safe_open(ckpt, framework="pt", device=device) as f: + for k in f.keys(): + pl_sd[k] = f.get_tensor(k) + else: + pl_sd = torch.load(ckpt, map_location=device) + return pl_sd + + +def load_safetensors(path): + tensors = {} + with safe_open(path, framework="pt", device="cpu") as f: + for key in f.keys(): + tensors[key] = f.get_tensor(key) + return tensors + + +def get_lora_rank(checkpoint): + for k in checkpoint.keys(): + if k.endswith(".down.weight"): + return checkpoint[k].shape[0] + + +def load_checkpoint(local_path, repo_id, name): + if local_path is not None: + if ".safetensors" in local_path: + print(f"Loading .safetensors checkpoint from {local_path}") + checkpoint = load_safetensors(local_path) + else: + print(f"Loading checkpoint from {local_path}") + checkpoint = torch.load(local_path, map_location="cpu") + elif repo_id is not None and name is not None: + print(f"Loading checkpoint {name} from repo id {repo_id}") + checkpoint = load_from_repo_id(repo_id, name) + else: + raise ValueError( + "LOADING ERROR: you must specify local_path or repo_id with name in HF to download" + ) + return checkpoint + + +def c_crop(image): + width, height = image.size + new_size = min(width, height) + left = (width - new_size) / 2 + top = (height - new_size) / 2 + right = (width + new_size) / 2 + bottom = (height + new_size) / 2 + return image.crop((left, top, right, bottom)) + + +def pad64(x): + return int(np.ceil(float(x) / 64.0) * 64 - x) + + +def HWC3(x): + assert x.dtype == np.uint8 + if x.ndim == 2: + x = x[:, :, None] + assert x.ndim == 3 + H, W, C = x.shape + assert C == 1 or C == 3 or C == 4 + if C == 3: + return x + if C == 1: + return np.concatenate([x, x, x], axis=2) + if C == 4: + color = x[:, :, 0:3].astype(np.float32) + alpha = x[:, :, 3:4].astype(np.float32) / 255.0 + y = color * alpha + 
255.0 * (1.0 - alpha) + y = y.clip(0, 255).astype(np.uint8) + return y + + +@dataclass +class ModelSpec: + params: FluxParams + ae_params: AutoEncoderParams + ckpt_path: str | None + ae_path: str | None + repo_id: str | None + repo_flow: str | None + repo_ae: str | None + repo_id_ae: str | None + + +configs = { + "flux-dev": ModelSpec( + repo_id="black-forest-labs/FLUX.1-dev", + repo_id_ae="black-forest-labs/FLUX.1-dev", + repo_flow="flux1-dev.safetensors", + repo_ae="ae.safetensors", + ckpt_path=os.getenv("FLUX_DEV"), + params=FluxParams( + in_channels=64, + vec_in_dim=768, + context_in_dim=4096, + hidden_size=3072, + mlp_ratio=4.0, + num_heads=24, + depth=19, + depth_single_blocks=38, + axes_dim=[16, 56, 56], + theta=10_000, + qkv_bias=True, + guidance_embed=True, + ), + ae_path=os.getenv("AE"), + ae_params=AutoEncoderParams( + resolution=256, + in_channels=3, + ch=128, + out_ch=3, + ch_mult=[1, 2, 4, 4], + num_res_blocks=2, + z_channels=16, + scale_factor=0.3611, + shift_factor=0.1159, + ), + ), + "flux-dev-fp8": ModelSpec( + repo_id="black-forest-labs/FLUX.1-dev", + repo_id_ae="black-forest-labs/FLUX.1-dev", + repo_flow="flux1-dev.safetensors", + repo_ae="ae.safetensors", + ckpt_path=os.getenv("FLUX_DEV_FP8"), + params=FluxParams( + in_channels=64, + vec_in_dim=768, + context_in_dim=4096, + hidden_size=3072, + mlp_ratio=4.0, + num_heads=24, + depth=19, + depth_single_blocks=38, + axes_dim=[16, 56, 56], + theta=10_000, + qkv_bias=True, + guidance_embed=True, + ), + ae_path=os.getenv("AE"), + ae_params=AutoEncoderParams( + resolution=256, + in_channels=3, + ch=128, + out_ch=3, + ch_mult=[1, 2, 4, 4], + num_res_blocks=2, + z_channels=16, + scale_factor=0.3611, + shift_factor=0.1159, + ), + ), + "flux-krea-dev": ModelSpec( + repo_id="black-forest-labs/FLUX.1-Krea-dev", + repo_id_ae="black-forest-labs/FLUX.1-Krea-dev", + repo_flow="flux1-krea-dev.safetensors", + repo_ae="ae.safetensors", + ckpt_path=os.getenv("FLUX_KREA_DEV"), + params=FluxParams( + in_channels=64, + vec_in_dim=768, + context_in_dim=4096, + hidden_size=3072, + mlp_ratio=4.0, + num_heads=24, + depth=19, + depth_single_blocks=38, + axes_dim=[16, 56, 56], + theta=10_000, + qkv_bias=True, + guidance_embed=True, + ), + ae_path=os.getenv("AE"), + ae_params=AutoEncoderParams( + resolution=256, + in_channels=3, + ch=128, + out_ch=3, + ch_mult=[1, 2, 4, 4], + num_res_blocks=2, + z_channels=16, + scale_factor=0.3611, + shift_factor=0.1159, + ), + ), + "flux-schnell": ModelSpec( + repo_id="black-forest-labs/FLUX.1-schnell", + repo_id_ae="black-forest-labs/FLUX.1-dev", + repo_flow="flux1-schnell.safetensors", + repo_ae="ae.safetensors", + ckpt_path=os.getenv("FLUX_SCHNELL"), + params=FluxParams( + in_channels=64, + vec_in_dim=768, + context_in_dim=4096, + hidden_size=3072, + mlp_ratio=4.0, + num_heads=24, + depth=19, + depth_single_blocks=38, + axes_dim=[16, 56, 56], + theta=10_000, + qkv_bias=True, + guidance_embed=False, + ), + ae_path=os.getenv("AE"), + ae_params=AutoEncoderParams( + resolution=256, + in_channels=3, + ch=128, + out_ch=3, + ch_mult=[1, 2, 4, 4], + num_res_blocks=2, + z_channels=16, + scale_factor=0.3611, + shift_factor=0.1159, + ), + ), +} + + +def print_load_warning(missing: list[str], unexpected: list[str]) -> None: + if len(missing) > 0 and len(unexpected) > 0: + print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing)) + print("\n" + "-" * 79 + "\n") + print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected)) + elif len(missing) > 0: + print(f"Got {len(missing)} missing keys:\n\t" 
+ "\n\t".join(missing)) + elif len(unexpected) > 0: + print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected)) + + +def load_from_repo_id(repo_id, checkpoint_name): + ckpt_path = hf_hub_download(repo_id, checkpoint_name) + sd = load_sft(ckpt_path, device="cpu") + return sd + + +def load_flow_model( + name: str, device: str | torch.device = "cuda", hf_download: bool = True +): + # Loading Flux + print("Init model") + ckpt_path = configs[name].ckpt_path + if ( + ckpt_path is None + and configs[name].repo_id is not None + and configs[name].repo_flow is not None + ): + ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_flow) + + # with torch.device("meta" if ckpt_path is not None else device): + with torch.device(device): + model = Flux(configs[name].params).to(torch.bfloat16) + + if ckpt_path is not None: + print("Loading main checkpoint") + # load_sft doesn't support torch.device + sd = load_model(ckpt_path, device="cpu") + missing, unexpected = model.load_state_dict(sd, strict=False, assign=True) + print_load_warning(missing, unexpected) + return model.to(str(device)) + + +def load_flow_model_only_lora( + name: str, + device: str | torch.device = "cuda", + hf_download: bool = True, + lora_rank: int = 16, + use_fp8: bool = False, +): + # Loading Flux + ckpt_path = configs[name].ckpt_path + if ( + ckpt_path is None + and configs[name].repo_id is not None + and configs[name].repo_flow is not None + ): + ckpt_path = hf_hub_download( + configs[name].repo_id, configs[name].repo_flow.replace("sft", "safetensors") + ) + + if hf_download: + try: + lora_ckpt_path = hf_hub_download( + "bytedance-research/USO", "uso_flux_v1.0/dit_lora.safetensors" + ) + except Exception as e: + print(f"Failed to download lora checkpoint: {e}") + print("Trying to load lora from local") + lora_ckpt_path = os.environ.get("LORA", None) + try: + proj_ckpt_path = hf_hub_download( + "bytedance-research/USO", "uso_flux_v1.0/projector.safetensors" + ) + except Exception as e: + print(f"Failed to download projection_model checkpoint: {e}") + print("Trying to load projection_model from local") + proj_ckpt_path = os.environ.get("PROJECTION_MODEL", None) + else: + lora_ckpt_path = os.environ.get("LORA", None) + proj_ckpt_path = os.environ.get("PROJECTION_MODEL", None) + with torch.device("meta" if ckpt_path is not None else device): + model = Flux(configs[name].params) + + model = set_lora( + model, lora_rank, device="meta" if lora_ckpt_path is not None else device + ) + + if ckpt_path is not None: + print(f"Loading lora from {lora_ckpt_path}") + lora_sd = ( + load_sft(lora_ckpt_path, device=str(device)) + if lora_ckpt_path.endswith("safetensors") + else torch.load(lora_ckpt_path, map_location="cpu") + ) + proj_sd = ( + load_sft(proj_ckpt_path, device=str(device)) + if proj_ckpt_path.endswith("safetensors") + else torch.load(proj_ckpt_path, map_location="cpu") + ) + lora_sd.update(proj_sd) + + print("Loading main checkpoint") + # load_sft doesn't support torch.device + + if ckpt_path.endswith("safetensors"): + if use_fp8: + print( + "####\n" + "We are in fp8 mode right now, since the fp8 checkpoint of XLabs-AI/flux-dev-fp8 seems broken\n" + "we convert the fp8 checkpoint on flight from bf16 checkpoint\n" + "If your storage is constrained" + "you can save the fp8 checkpoint and replace the bf16 checkpoint by yourself\n" + ) + sd = load_sft(ckpt_path, device="cpu") + sd = { + k: v.to(dtype=torch.float8_e4m3fn, device=device) + for k, v in sd.items() + } + else: + sd = load_sft(ckpt_path, 
device=str(device)) + + sd.update(lora_sd) + missing, unexpected = model.load_state_dict(sd, strict=False, assign=True) + else: + dit_state = torch.load(ckpt_path, map_location="cpu") + sd = {} + for k in dit_state.keys(): + sd[k.replace("module.", "")] = dit_state[k] + sd.update(lora_sd) + missing, unexpected = model.load_state_dict(sd, strict=False, assign=True) + model.to(str(device)) + print_load_warning(missing, unexpected) + return model + + +def set_lora( + model: Flux, + lora_rank: int, + double_blocks_indices: list[int] | None = None, + single_blocks_indices: list[int] | None = None, + device: str | torch.device = "cpu", +) -> Flux: + double_blocks_indices = ( + list(range(model.params.depth)) + if double_blocks_indices is None + else double_blocks_indices + ) + single_blocks_indices = ( + list(range(model.params.depth_single_blocks)) + if single_blocks_indices is None + else single_blocks_indices + ) + + lora_attn_procs = {} + with torch.device(device): + for name, attn_processor in model.attn_processors.items(): + match = re.search(r"\.(\d+)\.", name) + if match: + layer_index = int(match.group(1)) + + if ( + name.startswith("double_blocks") + and layer_index in double_blocks_indices + ): + lora_attn_procs[name] = DoubleStreamBlockLoraProcessor( + dim=model.params.hidden_size, rank=lora_rank + ) + elif ( + name.startswith("single_blocks") + and layer_index in single_blocks_indices + ): + lora_attn_procs[name] = SingleStreamBlockLoraProcessor( + dim=model.params.hidden_size, rank=lora_rank + ) + else: + lora_attn_procs[name] = attn_processor + model.set_attn_processor(lora_attn_procs) + return model + + +def load_flow_model_quintized( + name: str, device: str | torch.device = "cuda", hf_download: bool = True +): + # Loading Flux + from optimum.quanto import requantize + + print("Init model") + ckpt_path = configs[name].ckpt_path + if ( + ckpt_path is None + and configs[name].repo_id is not None + and configs[name].repo_flow is not None + and hf_download + ): + ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_flow) + json_path = hf_hub_download(configs[name].repo_id, "flux_dev_quantization_map.json") + + model = Flux(configs[name].params).to(torch.bfloat16) + + print("Loading checkpoint") + # load_sft doesn't support torch.device + sd = load_sft(ckpt_path, device="cpu") + sd = {k: v.to(dtype=torch.float8_e4m3fn, device=device) for k, v in sd.items()} + model.load_state_dict(sd, assign=True) + return model + with open(json_path, "r") as f: + quantization_map = json.load(f) + print("Start a quantization process...") + requantize(model, sd, quantization_map, device=device) + print("Model is quantized!") + return model + + +def load_t5(device: str | torch.device = "cuda", max_length: int = 512) -> HFEmbedder: + # max length 64, 128, 256 and 512 should work (if your sequence is short enough) + version = os.environ.get("T5", "xlabs-ai/xflux_text_encoders") + return HFEmbedder(version, max_length=max_length, torch_dtype=torch.bfloat16).to( + device + ) + + +def load_clip(device: str | torch.device = "cuda") -> HFEmbedder: + version = os.environ.get("CLIP", "openai/clip-vit-large-patch14") + return HFEmbedder(version, max_length=77, torch_dtype=torch.bfloat16).to(device) + + +def load_ae( + name: str, device: str | torch.device = "cuda", hf_download: bool = True +) -> AutoEncoder: + ckpt_path = configs[name].ae_path + if ( + ckpt_path is None + and configs[name].repo_id is not None + and configs[name].repo_ae is not None + and hf_download + ): + ckpt_path = 
hf_hub_download(configs[name].repo_id_ae, configs[name].repo_ae) + + # Loading the autoencoder + print("Init AE") + with torch.device("meta" if ckpt_path is not None else device): + ae = AutoEncoder(configs[name].ae_params) + + if ckpt_path is not None: + sd = load_sft(ckpt_path, device=str(device)) + missing, unexpected = ae.load_state_dict(sd, strict=False, assign=True) + print_load_warning(missing, unexpected) + return ae
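+
+
+# Rough usage sketch of the loaders above (it mirrors USOPipeline.__init__ in
+# uso/flux/pipeline.py); the variable names and the "flux-dev" choice are
+# illustrative only:
+#
+#   device = "cuda" if torch.cuda.is_available() else "cpu"
+#   clip = load_clip(device)
+#   t5 = load_t5(device, max_length=512)
+#   ae = load_ae("flux-dev", device=device)
+#   model = load_flow_model_only_lora(
+#       "flux-dev", device=device, lora_rank=128, use_fp8=False, hf_download=True
+#   )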