import spaces
import os
import torch
from diffusers import StableDiffusionXLPipeline
import gradio as gr
from huggingface_hub import hf_hub_download, snapshot_download
from nested_attention_pipeline import NestedAdapterInference, add_special_token_to_tokenizer
from utils import align_face
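# nested_attention_pipeline and utils are local modules in this repository (not
# pip packages); they provide the personalization adapter and face alignment.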


# ----------------------
# Configuration (update paths as needed)
# ----------------------
base_model_path = "stabilityai/stable-diffusion-xl-base-1.0"
image_encoder_path = snapshot_download("orpatashnik/NestedAttentionEncoder", allow_patterns=["image_encoder/**"])
image_encoder_path = os.path.join(image_encoder_path, "image_encoder")
personalization_ckpt = hf_hub_download("orpatashnik/NestedAttentionEncoder", "personalization_encoder/model.safetensors")
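# Both downloads are cached by huggingface_hub, so later launches reuse local files.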
device = "cuda"

# Special token settings
placeholder_token = "<person>"
initializer_token = "person"

# ----------------------
# Load models
# ----------------------
pipe = StableDiffusionXLPipeline.from_pretrained(
    base_model_path,
    torch_dtype=torch.float16,
)
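# Register "<person>" as a new tokenizer token; its text embedding is
# initialized from the embedding of the initializer token "person".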
add_special_token_to_tokenizer(pipe, placeholder_token, initializer_token)
ip_model = NestedAdapterInference(
    pipe,
    image_encoder_path,
    personalization_ckpt,
    1024,
    vq_normalize_factor=2.0,
    device=device
)

# Generation defaults
negative_prompt = "bad anatomy, monochrome, lowres, worst quality, low quality"
num_inference_steps = 30
guidance_scale = 5.0

# ----------------------
# Inference function with alignment
# ----------------------
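# @spaces.GPU allocates a GPU for each call when running on Hugging Face Spaces
# (ZeroGPU); it is effectively a no-op on machines with a dedicated GPU.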
@spaces.GPU
def generate_images(img1, img2, img3, prompt, w, num_samples, seed):
    # Collect non-empty reference images
    refs = [img for img in (img1, img2, img3) if img is not None]
    if not refs:
        return []

    # Align each reference face (align_face operates directly on PIL images)
    aligned_refs = [align_face(img) for img in refs]

    # Resize reference images to 512x512 before encoding
    pil_images = [aligned.resize((512, 512)) for aligned in aligned_refs]
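    # Map the placeholder token string to its tokenizer id(s) so the adapter
    # knows which prompt position to personalize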
    placeholder_token_ids = ip_model.pipe.tokenizer.convert_tokens_to_ids([placeholder_token])

    # Generate personalized samples
    results = ip_model.generate(
        pil_image=pil_images,
        prompt=prompt,
        negative_prompt=negative_prompt,
        num_samples=num_samples,
        num_inference_steps=num_inference_steps,
        placeholder_token_ids=placeholder_token_ids,
        seed=seed if seed >= 0 else None,  # negative seed -> random
        guidance_scale=guidance_scale,
        multiple_images=True,
        special_token_weight=w
    )
    return results

# ----------------------
# Gradio UI
# ----------------------
with gr.Blocks() as demo:
    gr.Markdown("## Nested Attention: Semantic-aware Attention Values for Concept Personalization")
    gr.Markdown(
        "Upload up to 3 reference images. "
        "Faces will be auto-aligned before personalization. Include the placeholder token (e.g., \\<person\\>) in your prompt, "
        "set token weight, and choose how many outputs you want."
    )
    with gr.Row():
        with gr.Column(scale=1):
            # Reference images
            with gr.Row():
                img1 = gr.Image(type="pil", label="Reference Image 1")
                img2 = gr.Image(type="pil", label="Reference Image 2 (optional)")
                img3 = gr.Image(type="pil", label="Reference Image 3 (optional)")
            prompt_input = gr.Textbox(label="Prompt", placeholder="e.g., an abstract pencil drawing of a <person>")
            w_input = gr.Slider(minimum=1.0, maximum=5.0, step=0.5, value=1.0, label="Special Token Weight (w)")
            num_samples_input = gr.Slider(minimum=1, maximum=6, step=1, value=4, label="Number of Images to Generate")
            seed_input = gr.Slider(minimum=-1, maximum=100000, step=1, value=-1, label="Random Seed (-1 for random)")
            generate_button = gr.Button("Generate Images")

            # Add examples
            gr.Examples(
                examples=[
                    ["example_images/01.jpg", None, None, "a pop figure of a <person>, she stands on a white background", 2.0, 4, 1],
                    ["example_images/01.jpg", None, None, "a watercolor painting of a <person>, closeup", 1.0, 4, 42],
                    ["example_images/01.jpg", None, None, "a high quality photo of a <person> as a firefighter", 3.0, 4, 10],
                ],
                inputs=[img1, img2, img3, prompt_input, w_input, num_samples_input, seed_input],
                label="Example Prompts"
            )

        with gr.Column(scale=1):
            output_gallery = gr.Gallery(label="Generated Images", columns=3)

    generate_button.click(
        fn=generate_images,
        inputs=[img1, img2, img3, prompt_input, w_input, num_samples_input, seed_input],
        outputs=output_gallery
    )

demo.launch()