import spaces
import gradio as gr
import torch
from PIL import Image
from transformers import GPT2Config
from diffusers import StableDiffusionXLPipeline
from huggingface_hub import hf_hub_download
from model import EmotionInjectionTransformer

device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Initialize Emotion Injection Model
config = GPT2Config.from_pretrained('gpt2')
emotion_add_method = {"a": "cross", "v": "cross"}
model = EmotionInjectionTransformer(config, final_out_type="Linear+LN").to(device)
model = torch.nn.DataParallel(model)
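# Wrapping in DataParallel makes the state_dict keys carry a "module." prefix,
# which presumably matches how the hosted checkpoints were saved (assumption:
# they load below without any key remapping).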
# Initialize Stable Diffusion XL Pipeline
pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    use_safetensors=True
)
pipe.to(device)
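# Optional memory savers (an assumption, not part of the original app): on smaller
# GPUs, pipe.enable_attention_slicing() or pipe.enable_model_cpu_offload() can
# trade generation speed for lower VRAM use.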
@spaces.GPU
def generate_image(prompt, arousal, valence, model_scale, seed=24):
    # Map scales to checkpoint filenames in the Hugging Face repo
    model_checkpoints = {
        1.0: 'scale_factor_1.0.pth',
        1.25: 'scale_factor_1.25.pth',
        1.5: 'scale_factor_1.5.pth',
        1.75: 'scale_factor_1.75.pth',
        2.0: 'scale_factor_2.0.pth'
    }
    # Download the corresponding checkpoint from the Hugging Face Hub
    # (hf_hub_download caches locally, so repeated calls reuse the file)
    if model_scale in model_checkpoints:
        filename = model_checkpoints[model_scale]
        model_path = hf_hub_download(
            repo_id="idvxlab/EmotiCrafter",
            filename=filename
        )
        state_dict = torch.load(model_path, map_location=device)
        model.load_state_dict(state_dict)
    else:
        raise ValueError(f"Model scale {model_scale} not found in hosted checkpoints.")
    model.eval()
    # Encode the prompt into SDXL text embeddings
    (prompt_embeds_ori,
     negative_prompt_embeds,
     pooled_prompt_embeds_ori,
     negative_pooled_prompt_embeds) = pipe.encode_prompt(
        prompt=[prompt],
        prompt_2=[prompt],
        device=device,
        num_images_per_prompt=1,
        do_classifier_free_guidance=True,
        negative_prompt=None,
        negative_prompt_2=None
    )
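    # encode_prompt returns (prompt_embeds, negative_prompt_embeds,
    # pooled_prompt_embeds, negative_pooled_prompt_embeds); only the positive
    # pair is used below, so the negative embeddings are left untouched.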
    resolution = 1024
    with torch.no_grad():
        # Inject the target emotion into the prompt embeddings
        out = model(
            inputs_embeds=prompt_embeds_ori.to(torch.float32),
            arousal=torch.FloatTensor([[arousal]]).to(device),
            valence=torch.FloatTensor([[valence]]).to(device)
        )
        # Generate the image, with or without a fixed seed
        gen_kwargs = dict(
            prompt_embeds=out[0].to(pipe.dtype),  # match the pipeline dtype (fp16 on GPU, fp32 on CPU)
            pooled_prompt_embeds=pooled_prompt_embeds_ori,
            guidance_scale=7.5,
            num_inference_steps=40,
            height=resolution,
            width=resolution
        )
        if seed is not None:
            gen_kwargs['generator'] = torch.manual_seed(seed)
        image = pipe(**gen_kwargs).images[0]
    return image
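# Quick smoke test outside the UI (hypothetical values, for illustration only):
#   img = generate_image("a quiet street at dusk", arousal=1.5, valence=-2.0,
#                        model_scale=1.5, seed=7)
#   img.save("preview.png")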
# Gradio UI
css = """
#small-image {
width: 50%;
margin: 0 auto;
}
"""
def gradio_interface(prompt, arousal, valence, model_scale, seed=42):
    return generate_image(prompt, arousal, valence, model_scale, seed)
html_content = """
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
<div>
<h1>EmotiCrafter</h1>
<span>Emotion-based image generation using Stable Diffusion XL</span>
<br>
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
<a href="http://arxiv.org/abs/2501.05710"><img src="https://img.shields.io/badge/arXiv-2501.05710-red"></a>
<a href="https://github.com/idvxlab/EmotiCrafter"><img src="https://img.shields.io/badge/Github-Code-blue"></a>
</div>
</div>
</div>
"""
with gr.Blocks(css=css) as iface:
    gr.HTML(html_content)
    description = """
**You can inject emotions into pictures by adjusting the values of arousal and valence!**

The Arousal-Valence model is a two-dimensional framework used in psychology and affective computing to describe emotional states.
- **Valence**: Measures the degree of emotional pleasantness, ranging from negative (e.g., sadness, anger) to positive (e.g., happiness, satisfaction). Scale: -3 (very unpleasant) to 3 (very pleasant).
- **Arousal**: Measures the level of emotional activation, ranging from low (e.g., calm) to high (e.g., excited). Scale: -3 (very calm) to 3 (very excited).
"""
    gr.Markdown(description)
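    # Illustrative (arousal, valence) coordinates, assumed readings of the circumplex
    # model rather than values from the paper: excitement ≈ (2.5, 2.5),
    # calm ≈ (-2.5, 1.5), anger ≈ (2.0, -2.5), sadness ≈ (-2.0, -2.5).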
    with gr.Row():
        with gr.Column(scale=3):  # Gradio expects integer scales; 3:4 keeps the original 1.5:2 ratio
            gr.Markdown("<i>Arousal-Valence Model</i>")
            gr.Image("assets/emotion.png", label="Emotion Coordinate System")
        with gr.Column(scale=4):
            gr.Markdown("<i>From left to right: Valence increases</i>")
            gr.Image("assets/output_image.png", label="Valence increasing")
            gr.Markdown("<i>From left to right: Arousal increases</i>")
            gr.Image("assets/output_image3.png", label="Arousal increasing")
    with gr.Row():
        with gr.Column(scale=9):  # integer scales again; 9:20 keeps the original 2.25:5 ratio
            prompt = gr.Textbox(label="Prompt", placeholder="Enter the prompt for image generation")
            arousal_slider = gr.Slider(minimum=-3.0, maximum=3.0, step=0.1, label="Arousal", value=0.0)
            valence_slider = gr.Slider(minimum=-3.0, maximum=3.0, step=0.1, label="Valence", value=0.0)
            model_slider = gr.Slider(minimum=1.0, maximum=2.0, step=0.25, label="Model Scale", value=1.5)
            seed = gr.Slider(0, 10000000, step=1, label="Seed", value=42)
            submit_btn = gr.Button("Generate")
        with gr.Column(scale=20):
            output_image = gr.Image(type="pil", height=1024, width=1024)
    submit_btn.click(
        fn=gradio_interface,
        inputs=[prompt, arousal_slider, valence_slider, model_slider, seed],
        outputs=output_image
    )
if __name__ == "__main__":
iface.launch(debug=True)