File size: 4,040 Bytes
b197ccc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
399e621
b197ccc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54d5b9f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import os
import torch
from diffusers import StableDiffusionXLPipeline
import gradio as gr
from huggingface_hub import hf_hub_download, snapshot_download
from nested_attention_pipeline import NestedAdapterInference, add_special_token_to_tokenizer
from utils import align_face


# ----------------------
# Configuration (update paths as needed)
# ----------------------
# Hugging Face Hub id of the SDXL base checkpoint the adapter is attached to.
base_model_path = "stabilityai/stable-diffusion-xl-base-1.0"

# Download only the `image_encoder/` subfolder of the encoder repo and point
# at that subfolder inside the local snapshot.
image_encoder_path = os.path.join(
    snapshot_download(
        repo_id="orpatashnik/NestedAttentionEncoder",
        allow_patterns=["image_encoder/**"],
    ),
    "image_encoder",
)

# Local file path of the downloaded personalization-encoder weights.
personalization_ckpt = hf_hub_download(
    repo_id="orpatashnik/NestedAttentionEncoder",
    filename="personalization_encoder/pytorch_model.bin",
)

# Hard-coded to CUDA; the pipeline below is also loaded in float16.
device = "cuda"

# Special token settings: the placeholder users write in prompts, and the
# existing word it is (presumably) initialised from — see
# add_special_token_to_tokenizer.
placeholder_token, initializer_token = "<person>", "person"

# ----------------------
# Load models
# ----------------------
pipe = StableDiffusionXLPipeline.from_pretrained(base_model_path, torch_dtype=torch.float16)

# Register the placeholder token on the pipeline before the adapter wraps it.
add_special_token_to_tokenizer(pipe, placeholder_token, initializer_token)

ip_model = NestedAdapterInference(
    pipe,
    image_encoder_path,
    personalization_ckpt,
    # NOTE(review): meaning of this positional 1024 is defined by
    # NestedAdapterInference (presumably a feature/token size) — confirm.
    1024,
    vq_normalize_factor=2.0,
    device=device,
)

# Generation defaults applied to every request from the UI.
negative_prompt: str = "monochrome, lowres, bad anatomy, worst quality, low quality"
num_inference_steps: int = 30
guidance_scale: float = 5.0

# ----------------------
# Inference function with alignment
# ----------------------
def generate_images(img1, img2, img3, prompt, w, num_samples, seed):
    """Generate personalized images from up to three reference photos.

    Every non-empty reference image is face-aligned with ``align_face``,
    resized to 512x512, and passed together with ``prompt`` to
    ``ip_model.generate``.

    Args:
        img1, img2, img3: Reference images from the Gradio inputs (PIL);
            ``None`` entries (empty upload slots) are skipped.
        prompt: Text prompt; it should contain ``placeholder_token`` so the
            reference identity can be injected.
        w: Weight forwarded to ``ip_model.generate`` as
            ``special_token_weight``.
        num_samples: Number of images to generate. Cast to ``int`` because
            slider values may arrive as floats.
        seed: Random seed. Any negative value (the UI uses -1) or ``None``
            means "no fixed seed"; 0 is a valid fixed seed.

    Returns:
        The result of ``ip_model.generate`` (presumably a list of images for
        the gallery), or an empty list when no reference image was given.
    """
    refs = [img for img in (img1, img2, img3) if img is not None]
    if not refs:
        return []

    # Align directly on PIL, then resize to the model's reference resolution.
    # NOTE(review): align_face's behaviour when no face is detected is not
    # visible here — confirm it raises or returns a usable image.
    pil_images = [align_face(img).resize((512, 512)) for img in refs]
    placeholder_token_ids = ip_model.pipe.tokenizer.convert_tokens_to_ids([placeholder_token])

    # Only negative seeds mean "random". The previous `seed > 0` check
    # silently treated a user-chosen seed of 0 as random.
    fixed_seed = int(seed) if seed is not None and seed >= 0 else None

    # Generate personalized samples.
    return ip_model.generate(
        pil_image=pil_images,
        prompt=prompt,
        negative_prompt=negative_prompt,
        num_samples=int(num_samples),
        num_inference_steps=num_inference_steps,
        placeholder_token_ids=placeholder_token_ids,
        seed=fixed_seed,
        guidance_scale=guidance_scale,
        multiple_images=True,
        special_token_weight=w,
    )

# ----------------------
# Gradio UI
# ----------------------
with gr.Blocks() as demo:
    # Header and usage instructions.
    gr.Markdown("## Personalized Image Generation Demo")
    gr.Markdown(
        "Upload up to 3 reference images. "
        "Faces will be auto-aligned before personalization. "
        "Include the placeholder token (e.g., \\<person\\>) in your prompt, "
        "set token weight, and choose how many outputs you want."
    )

    with gr.Row():
        # Left column: reference uploads and generation controls.
        with gr.Column(scale=1):
            with gr.Row():
                # Three upload slots, created left to right; generate_images
                # skips any slot left empty.
                img1, img2, img3 = (
                    gr.Image(type="pil", label=slot_label)
                    for slot_label in (
                        "Reference Image 1",
                        "Reference Image 2 (optional)",
                        "Reference Image 3 (optional)",
                    )
                )
            prompt_input = gr.Textbox(
                label="Prompt",
                placeholder="e.g., an abstract pencil drawing of a <person>",
            )
            w_input = gr.Slider(
                minimum=1.0, maximum=5.0, step=0.5, value=1.0,
                label="Special Token Weight (w)",
            )
            num_samples_input = gr.Slider(
                minimum=1, maximum=6, step=1, value=4,
                label="Number of Images to Generate",
            )
            seed_input = gr.Slider(
                minimum=-1, maximum=100000, step=1, value=-1,
                label="Random Seed (use -1 for random and up to 100000)",
            )
            generate_button = gr.Button("Generate Images")

        # Right column: generated results.
        with gr.Column(scale=1):
            output_gallery = gr.Gallery(label="Generated Images", columns=3)

    # The input order must match the positional parameters of
    # generate_images(img1, img2, img3, prompt, w, num_samples, seed).
    generate_button.click(
        fn=generate_images,
        inputs=[
            img1,
            img2,
            img3,
            prompt_input,
            w_input,
            num_samples_input,
            seed_input,
        ],
        outputs=output_gallery,
    )

demo.launch()