import spaces
import gradio as gr
import torch
from transformers import GPT2Config
from diffusers import StableDiffusionXLPipeline
from huggingface_hub import hf_hub_download
from model import EmotionInjectionTransformer

device = 'cuda' if torch.cuda.is_available() else 'cpu'
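
# On ZeroGPU Spaces the GPU is attached only while a @spaces.GPU-decorated
# function executes, hence the decorator on generate_image below.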

# Initialize the emotion injection model on top of a GPT-2 sized config
config = GPT2Config.from_pretrained('gpt2')
emotion_add_method = {"a": "cross", "v": "cross"}
model = EmotionInjectionTransformer(config, final_out_type="Linear+LN").to(device)
model = torch.nn.DataParallel(model)
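# DataParallel prefixes every parameter key with "module."; the hosted
# checkpoints appear to have been saved from a wrapped model, so
# load_state_dict below works without remapping keys.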

# Initialize the Stable Diffusion XL pipeline (fp16 on GPU, fp32 on CPU)
pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    use_safetensors=True
)
pipe.to(device)
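# from_pretrained downloads several GB of SDXL weights on first run; later
# runs reuse the local Hugging Face cache.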

@spaces.GPU  # request a GPU for the duration of this call (ZeroGPU Spaces)
def generate_image(prompt, arousal, valence, model_scale, seed=24):
    # Map scales to checkpoint filenames in the Hugging Face repo
    model_checkpoints = {
        1.0: 'scale_factor_1.0.pth',
        1.25: 'scale_factor_1.25.pth',
        1.5: 'scale_factor_1.5.pth',
        1.75: 'scale_factor_1.75.pth',
        2.0: 'scale_factor_2.0.pth'
    }
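    # Each checkpoint presumably encodes a different emotion-injection strength
    # (the scale factor in its filename), exposed in the UI as "Model Scale".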
    # Download the corresponding checkpoint from the Hugging Face Hub
    if model_scale in model_checkpoints:
        filename = model_checkpoints[model_scale]
        model_path = hf_hub_download(
            repo_id="idvxlab/EmotiCrafter",
            filename=filename
        )
        state_dict = torch.load(model_path, map_location=device)
        model.load_state_dict(state_dict)
    else:
        raise ValueError(f"Model scale {model_scale} not found in hosted checkpoints.")
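    # hf_hub_download caches checkpoints locally, so switching scales only
    # downloads each file once.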
    model.eval()

    # Encode the prompt into SDXL's two embedding streams
    (prompt_embeds_ori,
     negative_prompt_embeds,
     pooled_prompt_embeds_ori,
     negative_pooled_prompt_embeds) = pipe.encode_prompt(
        prompt=[prompt],
        prompt_2=[prompt],
        device=device,
        num_images_per_prompt=1,
        do_classifier_free_guidance=True,
        negative_prompt=None,
        negative_prompt_2=None
    )
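    # encode_prompt returns per-token and pooled embeddings for both the
    # positive and negative prompts; only the positive pair is used below,
    # and SDXL zero-fills the negatives (force_zeros_for_empty_prompt).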
    resolution = 1024
    with torch.no_grad():
        # Inject the target arousal/valence into the prompt embeddings
        out = model(
            inputs_embeds=prompt_embeds_ori.to(torch.float32),
            arousal=torch.FloatTensor([[arousal]]).to(device),
            valence=torch.FloatTensor([[valence]]).to(device)
        )
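    # out[0] holds the emotion-shifted token embeddings; the pooled embedding
    # is passed through unchanged, so injection reshapes only the token-level
    # conditioning.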
    # Generate the image, with a fixed generator when a seed is given
    gen_kwargs = dict(
        prompt_embeds=out[0].to(pipe.dtype),  # match the pipeline dtype (fp16 on GPU, fp32 on CPU)
        pooled_prompt_embeds=pooled_prompt_embeds_ori,
        guidance_scale=7.5,
        num_inference_steps=40,
        height=resolution,
        width=resolution
    )
    if seed is not None:
        # A dedicated Generator keeps seeding local to this call instead of
        # reseeding torch's global RNG via torch.manual_seed
        gen_kwargs['generator'] = torch.Generator(device=device).manual_seed(seed)
    image = pipe(**gen_kwargs).images[0]
    return image
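
# A minimal usage sketch (commented out) for calling the generator directly,
# bypassing the UI; the prompt and output filename are illustrative only:
#   img = generate_image("a quiet harbor at dawn", arousal=-1.5, valence=2.0,
#                        model_scale=1.5, seed=42)
#   img.save("emoticrafter_demo.png")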

# Gradio UI
css = """
#small-image {
    width: 50%;
    margin: 0 auto;
}
"""

def gradio_interface(prompt, arousal, valence, model_scale, seed=42):
    return generate_image(prompt, arousal, valence, model_scale, seed)

html_content = """
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
    <div>
        <h1>EmotiCrafter</h1>
        <span>Emotion-based image generation using Stable Diffusion XL</span>
        <br>
        <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
            <a href="http://arxiv.org/abs/2501.05710"><img src="https://img.shields.io/badge/arXiv-2501.05710-red"></a>
            <a href="https://github.com/idvxlab/EmotiCrafter"><img src="https://img.shields.io/badge/Github-Code-blue"></a>
        </div>
    </div>
</div>
"""

with gr.Blocks(css=css) as iface:
    gr.HTML(html_content)
    description = """
    **You can inject emotions into pictures by adjusting the values of arousal and valence!**

    The Arousal-Valence model is a two-dimensional framework used in psychology and affective computing to describe emotional states.
    - **Valence**: the degree of emotional pleasantness, from negative (e.g., sadness, anger) to positive (e.g., happiness, satisfaction). Scale: -3 (very unpleasant) to 3 (very pleasant).
    - **Arousal**: the level of emotional activation, from low (e.g., calm) to high (e.g., excited). Scale: -3 (very calm) to 3 (very excited).
    """
    gr.Markdown(description)
    with gr.Row():
        with gr.Column(scale=3):  # Gradio requires integer scales; 3:4 keeps the original 1.5:2 ratio
            gr.Markdown("<i>Arousal-Valence Model</i>")
            gr.Image("assets/emotion.png", label="Emotion Coordinate System")
        with gr.Column(scale=4):
            gr.Markdown("<i>From left to right: Valence increases</i>")
            gr.Image("assets/output_image.png", label="Valence increasing")
            gr.Markdown("<i>From left to right: Arousal increases</i>")
            gr.Image("assets/output_image3.png", label="Arousal increasing")
    with gr.Row():
        with gr.Column(scale=9):  # integer scales again; 9:20 matches the original 2.25:5 ratio
            prompt = gr.Textbox(label="Prompt", placeholder="Enter the prompt for image generation")
            arousal_slider = gr.Slider(minimum=-3.0, maximum=3.0, step=0.1, label="Arousal", value=0.0)
            valence_slider = gr.Slider(minimum=-3.0, maximum=3.0, step=0.1, label="Valence", value=0.0)
            model_slider = gr.Slider(minimum=1.0, maximum=2.0, step=0.25, label="Model Scale", value=1.5)
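            # The 0.25 step lines up exactly with the hosted checkpoint grid
            # (scale_factor_1.0.pth through scale_factor_2.0.pth).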
            seed = gr.Slider(0, 10000000, step=1, label="Seed", value=42)
            submit_btn = gr.Button("Generate")
        with gr.Column(scale=20):
            output_image = gr.Image(type="pil", height=1024, width=1024)
    submit_btn.click(fn=gradio_interface, inputs=[prompt, arousal_slider, valence_slider, model_slider, seed], outputs=output_image)

if __name__ == "__main__":
    iface.launch(debug=True)