import os
import torch
from diffusers import StableDiffusionXLPipeline
import gradio as gr
from huggingface_hub import hf_hub_download, snapshot_download
from nested_attention_pipeline import NestedAdapterInference, add_special_token_to_tokenizer
from utils import align_face
import dlib


# ----------------------
# Configuration (checkpoints are downloaded from the Hugging Face Hub)
# ----------------------
SHAPE_PREDICTOR_PATH = hf_hub_download("orpatashnik/NestedAttentionEncoder", "shape_predictor_68_face_landmarks.dat")
FACE_DETECTOR_PATH = hf_hub_download("orpatashnik/NestedAttentionEncoder", "mmod_human_face_detector.dat")
base_model_path = "stabilityai/stable-diffusion-xl-base-1.0"
image_encoder_path = snapshot_download("orpatashnik/NestedAttentionEncoder", allow_patterns=["image_encoder/**"])
image_encoder_path = os.path.join(image_encoder_path, "image_encoder")
personalization_ckpt = hf_hub_download("orpatashnik/NestedAttentionEncoder", "personalization_encoder/pytorch_model.bin")
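# NOTE: assumes a CUDA-capable GPU; the fp16 weights loaded below are not
# intended to run on CPU.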
device = "cuda"

# Special token settings
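# <person> is the token users write in prompts to refer to the subject; its
# embedding is initialized from the embedding of the word "person".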
placeholder_token = "<person>"
initializer_token = "person"

# ----------------------
# Load models
# ----------------------
pipe = StableDiffusionXLPipeline.from_pretrained(
    base_model_path,
    torch_dtype=torch.float16,
)
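
# Add the placeholder token to the tokenizer and initialize its embedding
# from the initializer token.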
add_special_token_to_tokenizer(pipe, placeholder_token, initializer_token)
ip_model = NestedAdapterInference(
    pipe,
    image_encoder_path,
    personalization_ckpt,
    1024,
    vq_normalize_factor=2.0,
    device=device
)

# Initialize the dlib landmark predictor and CNN face detector used for alignment
predictor = dlib.shape_predictor(SHAPE_PREDICTOR_PATH)
detector = dlib.cnn_face_detection_model_v1(FACE_DETECTOR_PATH)

# Generation defaults
negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
num_inference_steps = 30
guidance_scale = 5.0

# ----------------------
# Inference function with alignment
# ----------------------
def generate_images(img1, img2, img3, prompt, w, num_samples, seed):
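    """Align the provided reference faces and generate personalized samples.

    Any of img1-img3 may be None. `w` scales the special token's attention
    weight, and a seed > 0 fixes the random seed (otherwise it is random).
    """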
    # Collect the reference images that were actually provided
    refs = [img for img in (img1, img2, img3) if img is not None]
    if not refs:
        return []

    # Align each face, operating directly on the PIL images
    aligned_refs = [align_face(img, predictor, detector) for img in refs]

    # Resize to model resolution
    pil_images = [aligned.resize((512, 512)) for aligned in aligned_refs]
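
    # Look up the id of the placeholder token so the adapter can locate it in the prompt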
    placeholder_token_ids = ip_model.pipe.tokenizer.convert_tokens_to_ids([placeholder_token])

    # Generate personalized samples
    results = ip_model.generate(
        pil_image=pil_images,
        prompt=prompt,
        negative_prompt=negative_prompt,
        num_samples=num_samples,
        num_inference_steps=num_inference_steps,
        placeholder_token_ids=placeholder_token_ids,
        seed=seed if seed > 0 else None,
        guidance_scale=guidance_scale,
        multiple_images=True,
        special_token_weight=w
    )
    return results
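
# Programmatic use without the Gradio UI (a minimal sketch; "reference.jpg"
# and the output filename are placeholders, and this assumes the models
# above have finished loading):
#
#   from PIL import Image
#   ref = Image.open("reference.jpg")
#   images = generate_images(ref, None, None,
#                            "a watercolor painting of a <person>",
#                            w=2.0, num_samples=1, seed=42)
#   images[0].save("sample.png")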

# ----------------------
# Gradio UI
# ----------------------
with gr.Blocks() as demo:
    gr.Markdown("## Personalized Image Generation Demo")
    gr.Markdown(
        "Upload up to 3 reference images. "
        "Faces are automatically aligned before personalization. Include the placeholder token (e.g., \\<person\\>) in your prompt, "
        "set the special token weight, and choose how many outputs to generate."
    )
    with gr.Row():
        with gr.Column(scale=1):
            # Reference images
            with gr.Row():
                img1 = gr.Image(type="pil", label="Reference Image 1")
                img2 = gr.Image(type="pil", label="Reference Image 2 (optional)")
                img3 = gr.Image(type="pil", label="Reference Image 3 (optional)")
            prompt_input = gr.Textbox(label="Prompt", placeholder="e.g., an abstract pencil drawing of a <person>")
            w_input = gr.Slider(minimum=1.0, maximum=5.0, step=0.5, value=1.0, label="Special Token Weight (w)")
            num_samples_input = gr.Slider(minimum=1, maximum=6, step=1, value=4, label="Number of Images to Generate")
            seed_input = gr.Slider(minimum=-1, maximum=100000, step=1, value=-1, label="Random Seed (-1 for random)")
            generate_button = gr.Button("Generate Images")
        with gr.Column(scale=1):
            output_gallery = gr.Gallery(label="Generated Images", columns=3)

    generate_button.click(
        fn=generate_images,
        inputs=[img1, img2, img3, prompt_input, w_input, num_samples_input, seed_input],
        outputs=output_gallery
    )

if __name__ == "__main__":
    demo.launch(share=True, debug=True)