import spaces
import gradio as gr
import torch
from PIL import Image
from transformers import AutoTokenizer, GPT2Config
from diffusers import StableDiffusionXLPipeline
from huggingface_hub import hf_hub_download
from model import EmotionInjectionTransformer
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Initialize Emotion Injection Model
config = GPT2Config.from_pretrained('gpt2')
emotion_add_method = {"a": "cross", "v": "cross"}
model = EmotionInjectionTransformer(config, final_out_type="Linear+LN").to(device)
model = torch.nn.DataParallel(model)
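# Note: generate_image() below loads the downloaded checkpoints directly into this
# DataParallel wrapper, which assumes their state-dict keys carry the "module." prefix;
# a checkpoint saved from the bare model would need model.module.load_state_dict() instead.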
# Initialize Stable Diffusion XL Pipeline
pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    use_safetensors=True
)
pipe.to(device)
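# The pipeline runs in float16 on GPU, so the emotion-injected prompt embeddings are
# cast back to float16 in generate_image() before being passed to the pipeline.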
@spaces.GPU
def generate_image(prompt, arousal, valence, model_scale, seed=24):
    # Map scales to checkpoint filenames in the Hugging Face repo
    model_checkpoints = {
        1.0: 'scale_factor_1.0.pth',
        1.25: 'scale_factor_1.25.pth',
        1.5: 'scale_factor_1.5.pth',
        1.75: 'scale_factor_1.75.pth',
        2.0: 'scale_factor_2.0.pth'
    }

    # Download the corresponding checkpoint from the Hugging Face Hub
    if model_scale in model_checkpoints:
        filename = model_checkpoints[model_scale]
        model_path = hf_hub_download(
            repo_id="idvxlab/EmotiCrafter",
            filename=filename
        )
        state_dict = torch.load(model_path, map_location=device)
        model.load_state_dict(state_dict)
    else:
        raise ValueError(f"Model scale {model_scale} not found in hosted checkpoints.")

    model.eval()
    # Encode prompt into embeddings
    (prompt_embeds_ori,
     negative_prompt_embeds,
     pooled_prompt_embeds_ori,
     negative_pooled_prompt_embeds) = pipe.encode_prompt(
        prompt=[prompt],
        prompt_2=[prompt],
        device=device,
        num_images_per_prompt=1,
        do_classifier_free_guidance=True,
        negative_prompt=None,
        negative_prompt_2=None
    )
    resolution = 1024
    with torch.no_grad():
        # Inject emotions into embeddings
        out = model(
            inputs_embeds=prompt_embeds_ori.to(torch.float32),
            arousal=torch.FloatTensor([[arousal]]).to(device),
            valence=torch.FloatTensor([[valence]]).to(device)
        )

        # Generate image with or without seed
        gen_kwargs = dict(
            prompt_embeds=out[0].to(torch.float16),
            pooled_prompt_embeds=pooled_prompt_embeds_ori,
            guidance_scale=7.5,
            num_inference_steps=40,
            height=resolution,
            width=resolution
        )
        if seed is not None:
            gen_kwargs['generator'] = torch.manual_seed(seed)

        image = pipe(**gen_kwargs).images[0]

    return image
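
# Illustrative usage (not part of the app flow): arousal/valence lie in [-3, 3] and
# model_scale must be one of the keys in model_checkpoints above.
# example = generate_image("a quiet street at dusk", arousal=-1.5, valence=2.0, model_scale=1.5, seed=42)
# example.save("example.png")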
# Gradio UI
css = """
#small-image {
width: 50%;
margin: 0 auto;
}
"""
def gradio_interface(prompt, arousal, valence, model_scale, seed=42):
    return generate_image(prompt, arousal, valence, model_scale, seed)
html_content = """
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
    <div>
        <h1>EmotiCrafter</h1>
        <span>Emotion-based image generation using Stable Diffusion XL</span>
        <br>
        <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
            <a href="http://arxiv.org/abs/2501.05710"><img src="https://img.shields.io/badge/arXiv-2501.05710-red"></a>
            <a href="https://github.com/idvxlab/EmotiCrafter"><img src="https://img.shields.io/badge/Github-Code-blue"></a>
        </div>
    </div>
</div>
"""
with gr.Blocks(css=css) as iface:
    gr.HTML(html_content)

    description = """
**You can inject emotions into pictures by adjusting the values of arousal and valence!**

The Arousal-Valence model is a two-dimensional framework used in psychology and affective computing to describe emotional states.

- **Valence**: Measures the degree of emotional pleasantness, ranging from negative (e.g., sadness, anger) to positive (e.g., happiness, satisfaction). Scale: -3 (very unpleasant) to 3 (very pleasant).
- **Arousal**: Measures the level of emotional activation, from low (e.g., calm) to high (e.g., excited). Scale: -3 (very calm) to 3 (very excited).
"""
    gr.Markdown(description)
    with gr.Row():
        with gr.Column(scale=1.5):
            gr.Markdown("<i>Arousal-Valence Model</i>")
            gr.Image("assets/emotion.png", label="Emotion Coordinate System")
        with gr.Column(scale=2):
            gr.Markdown("<i>From left to right: Valence increases</i>")
            gr.Image("assets/output_image.png", label="Valence increasing")
            gr.Markdown("<i>From left to right: Arousal increases</i>")
            gr.Image("assets/output_image3.png", label="Arousal increasing")
    with gr.Row():
        with gr.Column(scale=2.25):
            prompt = gr.Textbox(label="Prompt", placeholder="Enter the prompt for image generation")
            arousal_slider = gr.Slider(minimum=-3.0, maximum=3.0, step=0.1, label="Arousal", value=0.0)
            valence_slider = gr.Slider(minimum=-3.0, maximum=3.0, step=0.1, label="Valence", value=0.0)
            model_slider = gr.Slider(minimum=1.0, maximum=2.0, step=0.25, label="Model Scale", value=1.5)
            seed = gr.Slider(0, 10000000, step=1, label="Seed", value=42)
            submit_btn = gr.Button("Generate")
        with gr.Column(scale=5):
            output_image = gr.Image(type="pil", height=1024, width=1024)

    submit_btn.click(
        fn=gradio_interface,
        inputs=[prompt, arousal_slider, valence_slider, model_slider, seed],
        outputs=output_image
    )
if __name__ == "__main__":
    iface.launch(debug=True)