import spaces
import gradio as gr
import torch
from PIL import Image
from transformers import GPT2Config
from diffusers import StableDiffusionXLPipeline
from huggingface_hub import hf_hub_download
from model import EmotionInjectionTransformer

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Initialize Emotion Injection Model
config = GPT2Config.from_pretrained('gpt2')
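# Per-axis injection mode ("a" = arousal, "v" = valence); "cross" presumably selects
# cross-attention injection. Note: this dict is not passed to the constructor below,
# so the model relies on its own defaults.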
emotion_add_method = {"a": "cross", "v": "cross"}
model = EmotionInjectionTransformer(config, final_out_type="Linear+LN").to(device)
model = torch.nn.DataParallel(model)
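# DataParallel stores parameters under a "module." prefix; the hosted checkpoints are
# assumed to have been saved from a DataParallel-wrapped model, so their state_dicts
# load here without key remapping.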

# Initialize Stable Diffusion XL Pipeline
pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    use_safetensors=True
)
pipe.to(device)
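# fp16 roughly halves VRAM use and speeds up inference on GPU; fp32 is kept on CPU,
# where half-precision inference is poorly supported.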

@spaces.GPU
def generate_image(prompt, arousal, valence, model_scale, seed=24):
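    """Generate an image whose emotional tone is steered by (arousal, valence).

    model_scale selects one of the hosted checkpoints (see the mapping below);
    with seed=None the sample is non-deterministic.
    """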
    # Map scales to checkpoint filenames in the Hugging Face repo
    model_checkpoints = {
        1.0: 'scale_factor_1.0.pth',
        1.25: 'scale_factor_1.25.pth',
        1.5: 'scale_factor_1.5.pth',
        1.75: 'scale_factor_1.75.pth',
        2.0: 'scale_factor_2.0.pth'
    }
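    # Each scale factor presumably corresponds to a separately trained injection model;
    # larger values are expected to produce stronger emotion shifts in the output.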

    # Download the corresponding checkpoint from the Hugging Face Hub
    if model_scale in model_checkpoints:
        filename = model_checkpoints[model_scale]
        model_path = hf_hub_download(
            repo_id="idvxlab/EmotiCrafter",
            filename=filename
        )
        state_dict = torch.load(model_path, map_location=device)
        model.load_state_dict(state_dict)
    else:
        raise ValueError(f"Model scale {model_scale} not found in hosted checkpoints.")

    model.eval()

    # Encode prompt into embeddings
    (prompt_embeds_ori,
     negative_prompt_embeds,
     pooled_prompt_embeds_ori,
     negative_pooled_prompt_embeds) = pipe.encode_prompt(
        prompt=[prompt],
        prompt_2=[prompt],
        device=device,
        num_images_per_prompt=1,
        do_classifier_free_guidance=True,
        negative_prompt=None,
        negative_prompt_2=None
    )
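    # For SDXL, prompt_embeds has shape [batch, 77, 2048] (hidden states of the two
    # CLIP text encoders concatenated, 768 + 1280); pooled_prompt_embeds is [batch, 1280].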

    resolution = 1024

    with torch.no_grad():
        # Inject emotions into embeddings
        out = model(
            inputs_embeds=prompt_embeds_ori.to(torch.float32),
            arousal=torch.FloatTensor([[arousal]]).to(device),
            valence=torch.FloatTensor([[valence]]).to(device)
        )
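        # out[0] is the emotion-conditioned embedding sequence; it is assumed to keep
        # the input shape so it can replace the original prompt embeddings downstream.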

        # Generate the image, reusing the negative embeddings from encode_prompt
        # so the pipeline does not re-encode an empty negative prompt. Casting to
        # pipe.dtype (instead of hard-coded float16) keeps the CPU path working.
        gen_kwargs = dict(
            prompt_embeds=out[0].to(pipe.dtype),
            negative_prompt_embeds=negative_prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds_ori,
            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
            guidance_scale=7.5,
            num_inference_steps=40,
            height=resolution,
            width=resolution
        )
        # Seed for reproducibility; omit the generator for a random sample
        if seed is not None:
            gen_kwargs['generator'] = torch.Generator(device=device).manual_seed(seed)

        image = pipe(**gen_kwargs).images[0]
        return image
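
# Headless usage sketch (hypothetical values; assumes a GPU and the checkpoints above):
#   img = generate_image("a quiet forest lake", arousal=1.5, valence=2.0, model_scale=1.5, seed=42)
#   img.save("forest_lake.png")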

# Gradio UI
css = """
#small-image {
    width: 50%;
    margin: 0 auto;
}
"""

def gradio_interface(prompt, arousal, valence, model_scale, seed=42):
    return generate_image(prompt, arousal, valence, model_scale, seed)

html_content = """
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
  <div>
    <h1>EmotiCrafter</h1>
    <span>Emotion-based image generation using Stable Diffusion XL</span>
    <br>
    <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
      <a href="http://arxiv.org/abs/2501.05710"><img src="https://img.shields.io/badge/arXiv-2501.05710-red"></a>
      <a href="https://github.com/idvxlab/EmotiCrafter"><img src="https://img.shields.io/badge/Github-Code-blue"></a>
    </div>
  </div>
</div>
"""

with gr.Blocks(css=css) as iface:
    gr.HTML(html_content)
    description = """
        **You can inject emotions into pictures by adjusting the values of arousal and valence!**
        The Arousal-Valence model is a two-dimensional framework used in psychology and affective computing to describe emotional states.
        - **Valence**: Measures the degree of emotional pleasantness, ranging from negative (e.g., sadness, anger) to positive (e.g., happiness, satisfaction). Scale: -3 (very unpleasant) to 3 (very pleasant).
        - **Arousal**: Measures level of emotional activation, from low (e.g., calm) to high (e.g., excited). Scale: -3 (very calm) to 3 (very excited).
    """
    gr.Markdown(description)

    with gr.Row():
        with gr.Column(scale=3):  # Gradio scales must be integers; 3:4 keeps the 1.5:2 ratio
            gr.Markdown("<i>Arousal-Valence Model</i>")
            gr.Image("assets/emotion.png", label="Emotion Coordinate System", elem_id="small-image")
        with gr.Column(scale=4):
            gr.Markdown("<i>From left to right: Valence increases</i>")
            gr.Image("assets/output_image.png", label="Valence increasing")
            gr.Markdown("<i>From left to right: Arousal increases</i>")
            gr.Image("assets/output_image3.png", label="Arousal increasing")

    with gr.Row():
        with gr.Column(scale=9):  # 9:20 preserves the original 2.25:5 ratio with integer scales
            prompt = gr.Textbox(label="Prompt", placeholder="Enter the prompt for image generation")
            arousal_slider = gr.Slider(minimum=-3.0, maximum=3.0, step=0.1, label="Arousal", value=0.0)
            valence_slider = gr.Slider(minimum=-3.0, maximum=3.0, step=0.1, label="Valence", value=0.0)
            model_slider = gr.Slider(minimum=1.0, maximum=2.0, step=0.25, label="Model Scale", value=1.5)
            seed = gr.Slider(0, 10000000, step=1, label="Seed", value=42)
            submit_btn = gr.Button("Generate")

        with gr.Column(scale=20):
            output_image = gr.Image(type="pil", height=1024, width=1024)

    submit_btn.click(fn=gradio_interface, inputs=[prompt, arousal_slider, valence_slider, model_slider, seed], outputs=output_image)

if __name__ == "__main__":
    iface.launch(debug=True)