idvxlab committed
Commit a156452 · verified · 1 parent: c6d7b6f

Update app.py

Files changed (1)
  app.py +150 -150 — recorded as a full-file rewrite, though the only visible content change is a single line (the example-image column scale, marked in the diff below)
app.py CHANGED
@@ -1,151 +1,151 @@
import spaces
import gradio as gr
import torch
from PIL import Image
from transformers import AutoTokenizer
from diffusers import StableDiffusionXLPipeline
from huggingface_hub import hf_hub_download
from model import EmotionInjectionTransformer
from transformers import GPT2Config

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Initialize Emotion Injection Model
config = GPT2Config.from_pretrained('gpt2')
emotion_add_method = {"a": "cross", "v": "cross"}
model = EmotionInjectionTransformer(config, final_out_type="Linear+LN").to(device)
model = torch.nn.DataParallel(model)

# Initialize Stable Diffusion XL Pipeline
pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    use_safetensors=True
)
pipe.to(device)

@spaces.GPU
def generate_image(prompt, arousal, valence, model_scale, seed=24):
    # Map scales to checkpoint filenames in the Hugging Face repo
    model_checkpoints = {
        1.0: 'scale_factor_1.0.pth',
        1.25: 'scale_factor_1.25.pth',
        1.5: 'scale_factor_1.5.pth',
        1.75: 'scale_factor_1.75.pth',
        2.0: 'scale_factor_2.0.pth'
    }

    # Download the corresponding checkpoint from the Hugging Face Hub
    if model_scale in model_checkpoints:
        filename = model_checkpoints[model_scale]
        model_path = hf_hub_download(
            repo_id="idvxlab/EmotiCrafter",
            filename=filename
        )
        state_dict = torch.load(model_path, map_location=device)
        model.load_state_dict(state_dict)
    else:
        raise ValueError(f"Model scale {model_scale} not found in hosted checkpoints.")

    model.eval()

    # Encode prompt into embeddings
    (prompt_embeds_ori,
     negative_prompt_embeds,
     pooled_prompt_embeds_ori,
     negative_pooled_prompt_embeds) = pipe.encode_prompt(
        prompt=[prompt],
        prompt_2=[prompt],
        device=device,
        num_images_per_prompt=1,
        do_classifier_free_guidance=True,
        negative_prompt=None,
        negative_prompt_2=None
    )

    resolution = 1024

    with torch.no_grad():
        # Inject emotions into embeddings
        out = model(
            inputs_embeds=prompt_embeds_ori.to(torch.float32),
            arousal=torch.FloatTensor([[arousal]]).to(device),
            valence=torch.FloatTensor([[valence]]).to(device)
        )

    # Generate image with or without seed
    gen_kwargs = dict(
        prompt_embeds=out[0].to(torch.float16),
        pooled_prompt_embeds=pooled_prompt_embeds_ori,
        guidance_scale=7.5,
        num_inference_steps=40,
        height=resolution,
        width=resolution
    )
    if seed is not None:
        gen_kwargs['generator'] = torch.manual_seed(seed)

    image = pipe(**gen_kwargs).images[0]
    return image

# Gradio UI
css = """
#small-image {
    width: 50%;
    margin: 0 auto;
}
"""

def gradio_interface(prompt, arousal, valence, model_scale, seed=42):
    return generate_image(prompt, arousal, valence, model_scale, seed)

html_content = """
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
    <div>
        <h1>Emoticrafter</h1>
        <span>Emotion-based image generation using Stable Diffusion XL</span>
        <br>
        <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
            <a href="http://arxiv.org/abs/2501.05710"><img src="https://img.shields.io/badge/arXiv-2407.03168-red"></a>
            <a href="https://github.com/idvxlab/EmotiCrafter"><img src="https://img.shields.io/badge/Github-Code-blue"></a>
        </div>
    </div>
</div>
"""

with gr.Blocks() as iface:
    gr.HTML(html_content)
    description = """
**You can inject emotions into pictures by adjusting the values of arousal and valence!**
The Arousal-Valence model is a two-dimensional framework used in psychology and affective computing to describe emotional states.
- **Valence**: Measures the degree of emotional pleasantness, ranging from negative (e.g., sadness, anger) to positive (e.g., happiness, satisfaction). Scale: -3 (very unpleasant) to 3 (very pleasant).
- **Arousal**: Measures level of emotional activation, from low (e.g., calm) to high (e.g., excited). Scale: -3 (very calm) to 3 (very excited).
"""
    gr.Markdown(description)

    with gr.Row():
        with gr.Column(scale=2.25):
            gr.Markdown("<i>Arousal-Valence Model</i>")
            gr.Image("assets/emotion.png", label="Emotion Coordinate System")
-        with gr.Column(scale=5):
+        with gr.Column(scale=2):
            gr.Markdown("<i>From left to right: Valence increases</i>")
            gr.Image("assets/output_image.png", label="Valence increasing")
            gr.Markdown("<i>From left to right: Arousal increases</i>")
            gr.Image("assets/output_image3.png", label="Arousal increasing")

    with gr.Row():
        with gr.Column(scale=2.25):
            prompt = gr.Textbox(label="Prompt", placeholder="Enter the prompt for image generation")
            arousal_slider = gr.Slider(minimum=-3.0, maximum=3.0, step=0.1, label="Arousal", value=0.0)
            valence_slider = gr.Slider(minimum=-3.0, maximum=3.0, step=0.1, label="Valence", value=0.0)
            model_slider = gr.Slider(minimum=1.0, maximum=2.0, step=0.25, label="Model Scale", value=1.5)
            seed = gr.Slider(0, 10000000, step=1, label="Seed", value=42)
            submit_btn = gr.Button("Generate")

        with gr.Column(scale=5):
            output_image = gr.Image(type="pil", height=1024, width=1024)

    submit_btn.click(fn=gradio_interface, inputs=[prompt, arousal_slider, valence_slider, model_slider, seed], outputs=output_image)

if __name__ == "__main__":
    iface.launch(debug=True)
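
A few hedged notes on the script as committed, as sketches rather than fixes. First, checkpoint loading: `model` is wrapped in `torch.nn.DataParallel` before `load_state_dict`, which expects every checkpoint key to carry a `module.` prefix. Assuming the hosted `scale_factor_*.pth` files might have been saved from an unwrapped model (not verified here), a prefix shim would look like this:

    # Sketch, assuming checkpoint keys may lack DataParallel's "module." prefix.
    # If the hosted checkpoints already include it, the plain call in app.py suffices.
    state_dict = torch.load(model_path, map_location=device)
    if not any(key.startswith("module.") for key in state_dict):
        state_dict = {f"module.{key}": value for key, value in state_dict.items()}
    model.load_state_dict(state_dict)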
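
Second, seeding: `torch.manual_seed(seed)` seeds the process-global default generator and returns it, so passing its return value as `generator` works, but it shares state with every other RNG call in the process. A more isolated variant of the block in `generate_image` uses a dedicated generator pinned to the sampling device:

    # Hedged variant: a dedicated torch.Generator avoids mutating global RNG state.
    if seed is not None:
        gen_kwargs["generator"] = torch.Generator(device=device).manual_seed(seed)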
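
Last, the `css` string targeting `#small-image` is defined but never passed to the UI, so it currently has no effect. If the intent was to shrink the output image (an assumption, not confirmed by the commit), wiring it up would take two small changes:

    # Sketch only: pass the css to Blocks and tag the component; elem_id is a guess at intent.
    with gr.Blocks(css=css) as iface:
        ...
        output_image = gr.Image(type="pil", elem_id="small-image")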