import spaces
import gradio as gr
import torch
from PIL import Image
from transformers import AutoTokenizer, GPT2Config
from diffusers import StableDiffusionXLPipeline
from huggingface_hub import hf_hub_download

from model import EmotionInjectionTransformer

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Initialize the emotion-injection model (a GPT-2 backbone that conditions
# prompt embeddings on arousal and valence).
config = GPT2Config.from_pretrained('gpt2')
emotion_add_method = {"a": "cross", "v": "cross"}  # arousal ("a") and valence ("v") injected via cross-attention
model = EmotionInjectionTransformer(config, final_out_type="Linear+LN").to(device)
model = torch.nn.DataParallel(model)

# Initialize the Stable Diffusion XL pipeline (fp16 on GPU, fp32 on CPU).
pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    use_safetensors=True,
)
pipe.to(device)


@spaces.GPU
def generate_image(prompt, arousal, valence, model_scale, seed=24):
    # Map scale factors to checkpoint filenames in the Hugging Face repo.
    model_checkpoints = {
        1.0: 'scale_factor_1.0.pth',
        1.25: 'scale_factor_1.25.pth',
        1.5: 'scale_factor_1.5.pth',
        1.75: 'scale_factor_1.75.pth',
        2.0: 'scale_factor_2.0.pth',
    }

    # Download the corresponding checkpoint from the Hugging Face Hub.
    if model_scale in model_checkpoints:
        filename = model_checkpoints[model_scale]
        model_path = hf_hub_download(
            repo_id="idvxlab/EmotiCrafter",
            filename=filename,
        )
        state_dict = torch.load(model_path, map_location=device)
        model.load_state_dict(state_dict)
    else:
        raise ValueError(f"Model scale {model_scale} not found in hosted checkpoints.")
    model.eval()

    # Encode the prompt into SDXL text embeddings.
    (prompt_embeds_ori,
     negative_prompt_embeds,
     pooled_prompt_embeds_ori,
     negative_pooled_prompt_embeds) = pipe.encode_prompt(
        prompt=[prompt],
        prompt_2=[prompt],
        device=device,
        num_images_per_prompt=1,
        do_classifier_free_guidance=True,
        negative_prompt=None,
        negative_prompt_2=None,
    )

    resolution = 1024
    with torch.no_grad():
        # Inject the target emotion into the prompt embeddings.
        out = model(
            inputs_embeds=prompt_embeds_ori.to(torch.float32),
            arousal=torch.FloatTensor([[arousal]]).to(device),
            valence=torch.FloatTensor([[valence]]).to(device),
        )

        # Generate the image, optionally with a fixed seed. The negative
        # embeddings computed above are reused for classifier-free guidance,
        # and the injected embeddings are cast back to the pipeline's dtype
        # (fp16 on GPU, fp32 on CPU).
        gen_kwargs = dict(
            prompt_embeds=out[0].to(pipe.dtype),
            negative_prompt_embeds=negative_prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds_ori,
            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
            guidance_scale=7.5,
            num_inference_steps=40,
            height=resolution,
            width=resolution,
        )
        if seed is not None:
            gen_kwargs['generator'] = torch.manual_seed(seed)
        image = pipe(**gen_kwargs).images[0]
    return image


# Gradio UI
css = """
#small-image {
    width: 50%;
    margin: 0 auto;
}
"""


def gradio_interface(prompt, arousal, valence, model_scale, seed=42):
    return generate_image(prompt, arousal, valence, model_scale, seed)


html_content = """
""" with gr.Blocks() as iface: gr.HTML(html_content) description = """ **You can inject emotions into pictures by adjusting the values of arousal and valence!** The Arousal-Valence model is a two-dimensional framework used in psychology and affective computing to describe emotional states. - **Valence**: Measures the degree of emotional pleasantness, ranging from negative (e.g., sadness, anger) to positive (e.g., happiness, satisfaction). Scale: -3 (very unpleasant) to 3 (very pleasant). - **Arousal**: Measures level of emotional activation, from low (e.g., calm) to high (e.g., excited). Scale: -3 (very calm) to 3 (very excited). """ gr.Markdown(description) with gr.Row(): with gr.Column(scale=1.5): gr.Markdown("Arousal-Valence Model") gr.Image("assets/emotion.png", label="Emotion Coordinate System") with gr.Column(scale=2): gr.Markdown("From left to right: Valence increases") gr.Image("assets/output_image.png", label="Valence increasing") gr.Markdown("From left to right: Arousal increases") gr.Image("assets/output_image3.png", label="Arousal increasing") with gr.Row(): with gr.Column(scale=2.25): prompt = gr.Textbox(label="Prompt", placeholder="Enter the prompt for image generation") arousal_slider = gr.Slider(minimum=-3.0, maximum=3.0, step=0.1, label="Arousal", value=0.0) valence_slider = gr.Slider(minimum=-3.0, maximum=3.0, step=0.1, label="Valence", value=0.0) model_slider = gr.Slider(minimum=1.0, maximum=2.0, step=0.25, label="Model Scale", value=1.5) seed = gr.Slider(0, 10000000, step=1, label="Seed", value=42) submit_btn = gr.Button("Generate") with gr.Column(scale=5): output_image = gr.Image(type="pil", height=1024, width=1024) submit_btn.click(fn=gradio_interface, inputs=[prompt, arousal_slider, valence_slider, model_slider, seed], outputs=output_image) if __name__ == "__main__": iface.launch(debug=True)