import spaces
import gradio as gr
import torch
from diffusers import StableDiffusionXLPipeline
from huggingface_hub import hf_hub_download
from transformers import GPT2Config

from model import EmotionInjectionTransformer
# Use the GPU when available; fall back to CPU (and float32) otherwise.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
dtype = torch.float16 if torch.cuda.is_available() else torch.float32

# Initialize the emotion injection model on top of a GPT-2 configuration.
config = GPT2Config.from_pretrained('gpt2')
emotion_add_method = {"a": "cross", "v": "cross"}  # injection mode per axis (arousal/valence); not referenced below
model = EmotionInjectionTransformer(config, final_out_type="Linear+LN").to(device)
model = torch.nn.DataParallel(model)
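# Note: DataParallel prefixes parameter names with "module.", which presumably
# matches how the hosted checkpoints were saved; load_state_dict below would
# raise a key-mismatch error otherwise.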
# Initialize the Stable Diffusion XL pipeline.
pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=dtype,
    use_safetensors=True,
)
pipe.to(device)
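# float16 weights roughly halve GPU memory for SDXL; on CPU the pipeline stays in float32.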
@spaces.GPU  # request a GPU for this call (assumes this app runs as a ZeroGPU Space, which the `spaces` import suggests)
def generate_image(prompt, arousal, valence, model_scale, seed=24):
    # Map scale factors to checkpoint filenames in the Hugging Face repo.
    model_checkpoints = {
        1.0: 'scale_factor_1.0.pth',
        1.25: 'scale_factor_1.25.pth',
        1.5: 'scale_factor_1.5.pth',
        1.75: 'scale_factor_1.75.pth',
        2.0: 'scale_factor_2.0.pth',
    }
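    # Each supported injection strength appears to ship as its own checkpoint
    # (scale_factor_*.pth), so the slider is restricted to these exact values.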
    # Download the matching checkpoint from the Hugging Face Hub.
    if model_scale in model_checkpoints:
        filename = model_checkpoints[model_scale]
        model_path = hf_hub_download(
            repo_id="idvxlab/EmotiCrafter",
            filename=filename,
        )
        state_dict = torch.load(model_path, map_location=device)
        model.load_state_dict(state_dict)
    else:
        raise ValueError(f"Model scale {model_scale} not found in hosted checkpoints.")
    model.eval()
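    # hf_hub_download caches files locally, so switching scales re-downloads a
    # checkpoint only once; eval() disables dropout for deterministic inference.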
    # Encode the prompt with both SDXL text encoders.
    (prompt_embeds_ori,
     negative_prompt_embeds,
     pooled_prompt_embeds_ori,
     negative_pooled_prompt_embeds) = pipe.encode_prompt(
        prompt=[prompt],
        prompt_2=[prompt],
        device=device,
        num_images_per_prompt=1,
        do_classifier_free_guidance=True,
        negative_prompt=None,
        negative_prompt_2=None,
    )
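    # encode_prompt returns per-token and pooled embeddings for both the prompt
    # and the negative prompt; with no negative prompt given, SDXL's
    # force_zeros_for_empty_prompt config yields all-zero negative embeddings.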
    resolution = 1024
    with torch.no_grad():
        # Inject the target arousal and valence into the prompt embeddings.
        out = model(
            inputs_embeds=prompt_embeds_ori.to(torch.float32),
            arousal=torch.FloatTensor([[arousal]]).to(device),
            valence=torch.FloatTensor([[valence]]).to(device),
        )
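    # out[0] holds the emotion-conditioned embeddings; they are cast back to the
    # pipeline's dtype below before sampling.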
    # Generate the image, optionally seeding the sampler for reproducibility.
    # The (zero) negative embeddings computed above are passed explicitly.
    gen_kwargs = dict(
        prompt_embeds=out[0].to(dtype),
        negative_prompt_embeds=negative_prompt_embeds,
        pooled_prompt_embeds=pooled_prompt_embeds_ori,
        negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
        guidance_scale=7.5,
        num_inference_steps=40,
        height=resolution,
        width=resolution,
    )
    if seed is not None:
        # torch.manual_seed returns the default CPU generator, which diffusers
        # accepts for reproducible latent initialization.
        gen_kwargs['generator'] = torch.manual_seed(seed)
    image = pipe(**gen_kwargs).images[0]
    return image
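# Example usage (hypothetical values; model_scale must be one of the checkpoint
# scales above):
#   image = generate_image("a quiet street at dusk", arousal=2.0, valence=-2.0,
#                          model_scale=1.5, seed=42)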
# Gradio UI
css = """
#small-image {
    width: 50%;
    margin: 0 auto;
}
"""
def gradio_interface(prompt, arousal, valence, model_scale, seed=42):
    return generate_image(prompt, arousal, valence, model_scale, seed)
html_content = """
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
    <div>
        <h1>EmotiCrafter</h1>
        <span>Emotion-based image generation using Stable Diffusion XL</span>
        <br>
        <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
            <a href="http://arxiv.org/abs/2501.05710"><img src="https://img.shields.io/badge/arXiv-2501.05710-red"></a>
            <a href="https://github.com/idvxlab/EmotiCrafter"><img src="https://img.shields.io/badge/Github-Code-blue"></a>
        </div>
    </div>
</div>
"""
with gr.Blocks(css=css) as iface:  # pass the CSS so the #small-image rule defined above is actually applied
    gr.HTML(html_content)
| description = """ | |
| **You can inject emotions into pictures by adjusting the values of arousal and valence!** | |
| The Arousal-Valence model is a two-dimensional framework used in psychology and affective computing to describe emotional states. | |
| - **Valence**: Measures the degree of emotional pleasantness, ranging from negative (e.g., sadness, anger) to positive (e.g., happiness, satisfaction). Scale: -3 (very unpleasant) to 3 (very pleasant). | |
| - **Arousal**: Measures level of emotional activation, from low (e.g., calm) to high (e.g., excited). Scale: -3 (very calm) to 3 (very excited). | |
| """ | |
| gr.Markdown(description) | |
    with gr.Row():
        with gr.Column(scale=3):  # scale must be an integer; 3:4 preserves the intended 1.5:2 ratio
            gr.Markdown("<i>Arousal-Valence Model</i>")
            gr.Image("assets/emotion.png", label="Emotion Coordinate System")
        with gr.Column(scale=4):
            gr.Markdown("<i>From left to right: valence increases</i>")
            gr.Image("assets/output_image.png", label="Valence increasing")
            gr.Markdown("<i>From left to right: arousal increases</i>")
            gr.Image("assets/output_image3.png", label="Arousal increasing")
    with gr.Row():
        with gr.Column(scale=9):  # integer scales; 9:20 preserves the intended 2.25:5 ratio
            prompt = gr.Textbox(label="Prompt", placeholder="Enter the prompt for image generation")
            arousal_slider = gr.Slider(minimum=-3.0, maximum=3.0, step=0.1, label="Arousal", value=0.0)
            valence_slider = gr.Slider(minimum=-3.0, maximum=3.0, step=0.1, label="Valence", value=0.0)
            model_slider = gr.Slider(minimum=1.0, maximum=2.0, step=0.25, label="Model Scale", value=1.5)
            seed = gr.Slider(0, 10000000, step=1, label="Seed", value=42)
            submit_btn = gr.Button("Generate")
        with gr.Column(scale=20):
            output_image = gr.Image(type="pil", height=1024, width=1024)

    submit_btn.click(
        fn=gradio_interface,
        inputs=[prompt, arousal_slider, valence_slider, model_slider, seed],
        outputs=output_image,
    )
| if __name__ == "__main__": | |
| iface.launch(debug=True) |