Spaces: Running on Zero
Upload 6 files
- .gitattributes +2 -0
- app.py +151 -0
- assets/emotion.png +0 -0
- assets/output_image.png +3 -0
- assets/output_image3.png +3 -0
- model.py +267 -0
- requirements.txt +241 -0
.gitattributes
CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+assets/output_image.png filter=lfs diff=lfs merge=lfs -text
+assets/output_image3.png filter=lfs diff=lfs merge=lfs -text
app.py
ADDED
@@ -0,0 +1,151 @@
import spaces
import gradio as gr
import torch
from PIL import Image
from transformers import AutoTokenizer
from diffusers import StableDiffusionXLPipeline
from huggingface_hub import hf_hub_download
from model import EmotionInjectionTransformer
from transformers import GPT2Config

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Initialize Emotion Injection Model
config = GPT2Config.from_pretrained('gpt2')
emotion_add_method = {"a": "cross", "v": "cross"}
model = EmotionInjectionTransformer(config, final_out_type="Linear+LN").to(device)
model = torch.nn.DataParallel(model)

# Initialize Stable Diffusion XL Pipeline
pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    use_safetensors=True
)
pipe.to(device)

@spaces.GPU
def generate_image(prompt, arousal, valence, model_scale, seed=24):
    # Map scales to checkpoint filenames in the Hugging Face repo
    model_checkpoints = {
        1.0: 'scale_factor_1.0.pth',
        1.25: 'scale_factor_1.25.pth',
        1.5: 'scale_factor_1.5.pth',
        1.75: 'scale_factor_1.75.pth',
        2.0: 'scale_factor_2.0.pth'
    }

    # Download the corresponding checkpoint from the Hugging Face Hub
    if model_scale in model_checkpoints:
        filename = model_checkpoints[model_scale]
        model_path = hf_hub_download(
            repo_id="idvxlab/EmotiCrafter",
            filename=filename
        )
        state_dict = torch.load(model_path, map_location=device)
        model.load_state_dict(state_dict)
    else:
        raise ValueError(f"Model scale {model_scale} not found in hosted checkpoints.")

    model.eval()

    # Encode prompt into embeddings
    (prompt_embeds_ori,
     negative_prompt_embeds,
     pooled_prompt_embeds_ori,
     negative_pooled_prompt_embeds) = pipe.encode_prompt(
        prompt=[prompt],
        prompt_2=[prompt],
        device=device,
        num_images_per_prompt=1,
        do_classifier_free_guidance=True,
        negative_prompt=None,
        negative_prompt_2=None
    )

    resolution = 1024

    with torch.no_grad():
        # Inject emotions into embeddings
        out = model(
            inputs_embeds=prompt_embeds_ori.to(torch.float32),
            arousal=torch.FloatTensor([[arousal]]).to(device),
            valence=torch.FloatTensor([[valence]]).to(device)
        )

        # Generate image with or without seed
        gen_kwargs = dict(
            prompt_embeds=out[0].to(torch.float16),
            pooled_prompt_embeds=pooled_prompt_embeds_ori,
            guidance_scale=7.5,
            num_inference_steps=40,
            height=resolution,
            width=resolution
        )
        if seed is not None:
            gen_kwargs['generator'] = torch.manual_seed(seed)

        image = pipe(**gen_kwargs).images[0]

    return image

# Gradio UI
css = """
#small-image {
    width: 50%;
    margin: 0 auto;
}
"""

def gradio_interface(prompt, arousal, valence, model_scale, seed=42):
    return generate_image(prompt, arousal, valence, model_scale, seed)

html_content = """
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
    <div>
        <h1>Emoticrafter</h1>
        <span>Emotion-based image generation using Stable Diffusion XL</span>
        <br>
        <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
            <a href="http://arxiv.org/abs/2501.05710"><img src="https://img.shields.io/badge/arXiv-2407.03168-red"></a>
            <a href="https://github.com/idvxlab/EmotiCrafter"><img src="https://img.shields.io/badge/Github-Code-blue"></a>
        </div>
    </div>
</div>
"""

with gr.Blocks() as iface:
    gr.HTML(html_content)
    description = """
**You can inject emotions into pictures by adjusting the values of arousal and valence!**

The Arousal-Valence model is a two-dimensional framework used in psychology and affective computing to describe emotional states.
- **Valence**: Measures the degree of emotional pleasantness, ranging from negative (e.g., sadness, anger) to positive (e.g., happiness, satisfaction). Scale: -3 (very unpleasant) to 3 (very pleasant).
- **Arousal**: Measures the level of emotional activation, from low (e.g., calm) to high (e.g., excited). Scale: -3 (very calm) to 3 (very excited).
"""
    gr.Markdown(description)

    with gr.Row():
        with gr.Column(scale=2.25):
            gr.Markdown("<i>Arousal-Valence Model</i>")
            gr.Image("assets/emotion.png", label="Emotion Coordinate System")
        with gr.Column(scale=5):
            gr.Markdown("<i>From left to right: Valence increases</i>")
            gr.Image("assets/output_image.png", label="Valence increasing")
            gr.Markdown("<i>From left to right: Arousal increases</i>")
            gr.Image("assets/output_image3.png", label="Arousal increasing")

    with gr.Row():
        with gr.Column(scale=2.25):
            prompt = gr.Textbox(label="Prompt", placeholder="Enter the prompt for image generation")
            arousal_slider = gr.Slider(minimum=-3.0, maximum=3.0, step=0.1, label="Arousal", value=0.0)
            valence_slider = gr.Slider(minimum=-3.0, maximum=3.0, step=0.1, label="Valence", value=0.0)
            model_slider = gr.Slider(minimum=1.0, maximum=2.0, step=0.25, label="Model Scale", value=1.5)
            seed = gr.Slider(0, 10000000, step=1, label="Seed", value=42)
            submit_btn = gr.Button("Generate")

        with gr.Column(scale=5):
            output_image = gr.Image(type="pil", height=1024, width=1024)

    submit_btn.click(fn=gradio_interface, inputs=[prompt, arousal_slider, valence_slider, model_slider, seed], outputs=output_image)

if __name__ == "__main__":
    iface.launch(debug=True)
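For orientation, here is a minimal sketch of driving the Space's core function without the Gradio UI. It assumes importing `app` locally is acceptable (module import builds the SDXL pipeline, falling back to CPU without CUDA), that the `spaces` package is installed and `@spaces.GPU` acts as a pass-through outside a ZeroGPU Space, and that the checkpoint download from `idvxlab/EmotiCrafter` succeeds; the prompt and emotion values are illustrative only, not part of the Space.

```python
# Illustrative driver for app.generate_image (hypothetical example, not part of the repo).
from app import generate_image

# model_scale must match one of the hosted checkpoints: 1.0, 1.25, 1.5, 1.75, 2.0
image = generate_image(
    prompt="a quiet street at dusk",  # example prompt (made up for illustration)
    arousal=1.5,                      # -3 (very calm) .. 3 (very excited)
    valence=-2.0,                     # -3 (very unpleasant) .. 3 (very pleasant)
    model_scale=1.5,
    seed=42,
)
image.save("emoticrafter_example.png")
```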
assets/emotion.png
ADDED
assets/output_image.png
ADDED
(binary image stored with Git LFS)
assets/output_image3.png
ADDED
(binary image stored with Git LFS)
model.py
ADDED
@@ -0,0 +1,267 @@
import torch
from transformers import GPT2Model, GPT2Config
from transformers.modeling_utils import PreTrainedModel, Conv1D, prune_conv1d_layer, SequenceSummary
from transformers.models.gpt2.modeling_gpt2 import (
    GPT2Block, GPT2Attention, GPT2MLP
)
from torch import nn

# GPT-2 attention variant whose keys/values can be taken from a separate conditioning sequence z.
class Cond_Attention(GPT2Attention):
    def __init__(self, nx, n_ctx, config, is_cross_attention=False):
        super(GPT2Attention, self).__init__()
        self.output_attentions = config.output_attentions
        n_state = nx
        assert n_state % config.n_head == 0
        self.embed_dim = config.n_embd
        self.num_heads = config.n_head
        self.head_dim = self.embed_dim // self.num_heads
        self.split_size = n_state
        self.scale_attn_weights = config.scale_attn_weights
        self.is_cross_attention = is_cross_attention
        self.c_attn = Conv1D(n_state * 3, nx)
        self.c_proj = Conv1D(n_state, nx)
        self.attn_dropout = nn.Dropout(config.attn_pdrop)
        self.resid_dropout = nn.Dropout(config.resid_pdrop)
        self.pruned_heads = set()
        self.c_z = Conv1D(n_state * 2, nx)  # projects the conditioning sequence z into keys and values

    def _attn(self, query, key, value, attention_mask=None, head_mask=None):
        attn_weights = torch.matmul(query, key.transpose(-1, -2))
        if self.scale_attn_weights:
            attn_weights = attn_weights / torch.full(
                [], value.size(-1) ** 0.5, dtype=attn_weights.dtype, device=attn_weights.device
            )
        attn_weights = nn.functional.softmax(attn_weights, dim=-1)
        attn_weights = attn_weights.type(value.dtype)
        attn_weights = self.attn_dropout(attn_weights)

        if head_mask is not None:
            attn_weights = attn_weights * head_mask

        attn_output = torch.matmul(attn_weights, value)
        return attn_output, attn_weights

    def forward(self, x, z, layer_past=None, attention_mask=None, head_mask=None, use_cache=True, output_attentions=False):
        x = self.c_attn(x)
        query, key, value = x.split(self.split_size, dim=2)
        query = self._split_heads(query, self.num_heads, self.head_dim)
        key = self._split_heads(key, self.num_heads, self.head_dim)
        value = self._split_heads(value, self.num_heads, self.head_dim)
        if layer_past is not None:
            past_key, past_value = layer_past
            key = torch.cat((past_key, key), dim=-2)
            value = torch.cat((past_value, value), dim=-2)
        if use_cache:
            present = (key, value)
        else:
            present = None

        z_conv = self.c_z(z)
        key_z, value_z = z_conv.split(self.split_size, dim=2)
        key_z = self._split_heads(key_z, self.num_heads, self.head_dim)
        value_z = self._split_heads(value_z, self.num_heads, self.head_dim)

        # Cross-attention: queries come from x, keys/values from the conditioning sequence z.
        key = key_z
        value = value_z
        attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
        attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
        attn_output = self.c_proj(attn_output)
        attn_output = self.resid_dropout(attn_output)

        outputs = (attn_output, present)
        if output_attentions:
            outputs += (attn_weights,)
        return outputs

# GPT-2 block extended with optional cross-attention branches for arousal (a) and valence (v) features.
class Cond_Block(GPT2Block):
    def __init__(self, config, activate_a=False, activate_v=False):
        super(GPT2Block, self).__init__()
        self.activate_a = activate_a
        self.activate_v = activate_v
        nx = config.n_embd
        self.ln_1 = nn.LayerNorm(nx, eps=config.layer_norm_epsilon)

        self.attn = Cond_Attention(nx, config.n_ctx, config)

        self.attn_a = None if not self.activate_a else Cond_Attention(nx, config.n_ctx, config)
        self.ln_a = None if not self.activate_a else nn.LayerNorm(nx, eps=config.layer_norm_epsilon)

        self.attn_v = None if not self.activate_v else Cond_Attention(nx, config.n_ctx, config)
        self.ln_v = None if not self.activate_v else nn.LayerNorm(nx, eps=config.layer_norm_epsilon)

        self.ln_2 = nn.LayerNorm(nx, eps=config.layer_norm_epsilon)
        self.mlp = GPT2MLP(4 * nx, config)

    def forward(self, x, a, v, layer_past=None, attention_mask=None, head_mask=None):
        residual = x
        x = self.ln_1(x)
        attn_outputs = self.attn(
            x=x,
            z=x
        )
        attn_output = attn_outputs[0]
        # outputs = attn_outputs[1:]
        x = x + attn_output
        if self.activate_a:
            x = self.ln_a(x)
            cross_attn_outputs = self.attn_a(
                x=x,
                z=a
            )
            cross_attn_output = cross_attn_outputs[0]
            x = x + cross_attn_output
        if self.activate_v:
            x = self.ln_v(x)
            cross_attn_outputs = self.attn_v(
                x=x,
                z=v
            )
            cross_attn_output = cross_attn_outputs[0]
            x = x + cross_attn_output
        m = self.mlp(self.ln_2(x))
        x = x + m
        outputs = (x,)
        return outputs

# GPT-2 backbone that projects SDXL prompt embeddings (sd_feature_dim=2048) into GPT-2 space,
# injects arousal/valence through cross-attention in every block, and maps the result back
# to the SDXL embedding space.
class EmotionInjectionTransformer(GPT2Model):
    def __init__(self, config, final_out_type="Linear+LN", sd_feature_dim=2048):
        super(GPT2Model, self).__init__(config)
        self.add_attn = True
        self.sd_feature_dim = sd_feature_dim
        self.activate_a = True
        self.activate_v = True
        self.output_hidden_states = config.output_hidden_states
        self.output_attentions = config.output_attentions
        self.use_cache = config.use_cache
        self.embed_dim = config.n_embd
        self.wte = nn.Embedding(config.vocab_size, self.embed_dim)
        self.wpe = nn.Embedding(config.n_positions, self.embed_dim)
        self.drop = nn.Dropout(config.embd_pdrop)
        self.xl_feature2gpt_feature = nn.Linear(self.sd_feature_dim, config.n_embd, bias=False)
        self.gpt_feature2xl_feature = nn.Linear(config.n_embd, self.sd_feature_dim, bias=False)
        if final_out_type == "Linear+LN" or final_out_type == "Linear+LN+noResidual":
            self.ln_xl_feature = nn.LayerNorm(self.sd_feature_dim, eps=1e-5)
        elif final_out_type == "Linear+LN+Linear" or final_out_type == "Linear+LN+Linear+noResidual":
            self.ln_xl_feature = nn.LayerNorm(self.sd_feature_dim, eps=1e-5)
            self.ff = nn.Linear(self.sd_feature_dim, self.sd_feature_dim, bias=False)
        else:
            raise NotImplementedError
        self.init_weights()
        self.cross_token = 16
        # MLPs that expand a scalar arousal/valence value into cross_token conditioning tokens.
        self.a_f = nn.Sequential(
            nn.Linear(1, 256),
            nn.ReLU(),
            nn.Linear(256, config.n_embd * self.cross_token if self.activate_a else config.n_embd)
        )
        self.v_f = nn.Sequential(
            nn.Linear(1, 256),
            nn.ReLU(),
            nn.Linear(256, config.n_embd * self.cross_token if self.activate_v else config.n_embd)
        )
        if self.add_attn:
            self.attn_proj = nn.Linear(config.n_embd, config.n_embd, bias=False)
            self.h = nn.ModuleList([Cond_Block(config, self.activate_a, self.activate_v) for _ in range(config.n_layer)])
        else:
            self.h = nn.ModuleList([GPT2Block(config) for _ in range(config.n_layer)])
        self.final_out_type = final_out_type
        self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)

    def forward(
        self,
        input_ids=None,
        past_key_values=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        arousal=None,
        valence=None,
    ):
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = input_ids.size()
            input_ids = input_ids.view(-1, input_shape[-1])
            batch_size = input_ids.shape[0]
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
            batch_size = inputs_embeds.shape[0]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        device = input_ids.device if input_ids is not None else inputs_embeds.device

        if token_type_ids is not None:
            token_type_ids = token_type_ids.view(-1, input_shape[-1])
        if position_ids is not None:
            position_ids = position_ids.view(-1, input_shape[-1])

        if past_key_values is None:
            past_length = 0
            past_key_values = [None] * len(self.h)
        else:
            past_length = past_key_values[0][0].size(-2)
        if position_ids is None:
            position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
            position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
        head_mask = self.get_head_mask(head_mask, self.config.n_layer)

        if inputs_embeds is None:
            inputs_embeds = self.wte(input_ids)
        else:
            residual = inputs_embeds
            inputs_embeds = self.xl_feature2gpt_feature(inputs_embeds)

        position_embeds = self.wpe(position_ids)
        hidden_states = inputs_embeds + position_embeds

        hidden_states = self.drop(hidden_states)

        a_feature = self.attn_proj(self.a_f(arousal).view(-1, self.cross_token, self.config.n_embd))
        v_feature = self.attn_proj(self.v_f(valence).view(-1, self.cross_token, self.config.n_embd))

        output_shape = input_shape + (hidden_states.size(-1),)

        all_self_attentions = () if self.output_attentions else None
        all_hidden_states = () if self.output_hidden_states else None
        for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
            if self.output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)
            outputs = block(
                hidden_states, a=a_feature, v=v_feature, layer_past=layer_past, attention_mask=attention_mask, head_mask=head_mask[i]
            )
            hidden_states = outputs[0]
            if self.output_attentions:
                all_self_attentions = all_self_attentions + (outputs[2 if self.use_cache else 1],)

        hidden_states = self.ln_f(hidden_states)

        if self.final_out_type == "Linear+LN":
            hidden_states = residual + self.ln_xl_feature(self.gpt_feature2xl_feature(hidden_states))
        elif self.final_out_type == "Linear+LN+noResidual":
            hidden_states = self.ln_xl_feature(self.gpt_feature2xl_feature(hidden_states))
        elif self.final_out_type == "Linear+LN+Linear":
            hidden_states = residual + self.ff(self.ln_xl_feature(self.gpt_feature2xl_feature(hidden_states)))
        elif self.final_out_type == "Linear+LN+Linear+noResidual":
            hidden_states = self.ff(self.ln_xl_feature(self.gpt_feature2xl_feature(hidden_states)))
        elif self.final_out_type == "Linear+noResidual":
            hidden_states = self.gpt_feature2xl_feature(hidden_states)
        else:
            hidden_states = residual + self.gpt_feature2xl_feature(hidden_states)

        if self.output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)
        outputs = (hidden_states,)
        if self.output_hidden_states:
            outputs = outputs + (all_hidden_states,)
        if self.output_attentions:
            attention_output_shape = input_shape[:-1] + (-1,) + all_self_attentions[0].shape[-2:]
            all_attentions = tuple(t.view(*attention_output_shape) for t in all_self_attentions)
            outputs = outputs + (all_attentions,)

        return outputs
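A minimal sketch of a standalone forward pass through the injection model, for orientation only: random tensors stand in for the output of `pipe.encode_prompt` (app.py feeds batches of 77-token, 2048-dim SDXL prompt embeddings), and without loading one of the hosted `scale_factor_*.pth` checkpoints the weights are just the untrained initialization.

```python
import torch
from transformers import GPT2Config
from model import EmotionInjectionTransformer

config = GPT2Config.from_pretrained('gpt2')
model = EmotionInjectionTransformer(config, final_out_type="Linear+LN").eval()

# Stand-ins for pipe.encode_prompt() output: (batch, 77 tokens, 2048 dims) for SDXL.
prompt_embeds = torch.randn(1, 77, 2048)
arousal = torch.FloatTensor([[1.5]])   # scalar arousal in -3 .. 3
valence = torch.FloatTensor([[-2.0]])  # scalar valence in -3 .. 3

with torch.no_grad():
    out = model(inputs_embeds=prompt_embeds, arousal=arousal, valence=valence)

print(out[0].shape)  # expected: torch.Size([1, 77, 2048]), same shape as the input embeddings
```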
requirements.txt
ADDED
@@ -0,0 +1,241 @@
absl-py==2.1.0
accelerate
aiofiles==23.2.1
aiohttp==3.9.5
aiosignal==1.3.1
alembic==1.13.2
altair==5.4.0
annotated-types==0.7.0
antlr4-python3-runtime==4.9.3
anyio==4.4.0
astunparse==1.6.3
attrs==23.2.0
banal==1.0.6
bleach==6.1.0
blinker==1.8.2
blis==0.7.11
braceexpand==0.1.7
cachetools==5.5.0
catalogue==2.0.10
certifi==2024.6.2
cfgv==3.4.0
charset-normalizer==3.3.2
click==8.1.7
cloudpathlib==0.18.1
confection==0.1.5
contexttimer==0.3.3
contourpy==1.2.1
cycler==0.12.1
cymem==2.0.8
datasets==2.20.0
decord==0.6.0
diffusers==0.31.0
transformers==4.46.3
dill==0.3.8
distlib==0.3.8
distro==1.9.0
docker-pycreds==0.4.0
einops==0.8.0
fairscale
fastapi==0.112.0
ffmpy==0.4.0
filelock==3.15.1
Flask==3.0.3
Flask-Cors==4.0.1
flatbuffers==24.3.25
fonttools==4.53.0
frozenlist==1.4.1
fsspec==2024.5.0
ftfy==6.2.3
gast==0.6.0
gitdb==4.0.11
gitlab==1.0.2
GitPython==3.1.43
google-pasta==0.2.0
gradio==4.41.0
gradio_client==1.3.0
greenlet==3.0.3
grpcio==1.64.1
h11==0.14.0
h5py==3.11.0
httpcore==1.0.5
httpx==0.27.0
huggingface-hub==0.23.4
identify==2.6.0
idna==3.7
imageio==2.34.1
importlib_metadata==7.1.0
importlib_resources==6.4.0
invisible-watermark
iopath==0.1.10
itsdangerous==2.2.0
Jinja2==3.1.4
joblib==1.4.2
jsonlines==4.0.0
jsonschema==4.23.0
jsonschema-specifications==2023.12.1
kaggle==1.6.17
keras==3.4.1
kiwisolver==1.4.5
langcodes==3.4.0
language_data==1.2.0
lazy_loader==0.4
libclang==18.1.1
lightning-utilities==0.11.6
lmdb==1.4.1
lmdbdict==0.2.2
Mako==1.3.5
marisa-trie==1.2.0
Markdown==3.6
markdown-it-py==3.0.0
MarkupSafe==2.1.5
matplotlib==3.9.0
mdurl==0.1.2
ml-dtypes==0.4.0
modelscope==1.16.1
mpmath==1.3.0
multidict==6.0.5
multiprocess==0.70.16
murmurhash==1.0.10
namex==0.0.8
narwhals==1.4.2
networkx
nodeenv==1.9.1
numpy==1.26.4
nvidia-cublas-cu12==12.1.3.1
nvidia-cuda-cupti-cu12==12.1.105
nvidia-cuda-nvrtc-cu12==12.1.105
nvidia-cuda-runtime-cu12==12.1.105
nvidia-cudnn-cu12==8.9.2.26
nvidia-cufft-cu12==11.0.2.54
nvidia-curand-cu12==10.3.2.106
nvidia-cusolver-cu12==11.4.5.107
nvidia-cusparse-cu12==12.1.0.106
nvidia-nccl-cu12
nvidia-nvjitlink-cu12==12.5.40
nvidia-nvtx-cu12==12.1.105
omegaconf==2.3.0
openai==0.28.0
opencv-python-headless==4.5.5.64
opendatasets==0.1.22
opt-einsum==3.3.0
optree==0.12.1
orjson==3.10.7
pandas==2.2.2
peft
pillow==10.3.0
piq==0.8.0
plotly==5.23.0
portalocker==2.10.1
pre-commit==3.8.0
preshed==3.0.9
protobuf==4.25.3
psutil==6.0.0
pyarrow==17.0.0
pyarrow-hotfix==0.6
pycocoevalcap==1.2
pycocotools==2.0.8
pydantic==2.8.2
pydantic_core==2.20.1
pydeck==0.9.1
pydub==0.25.1
pyparsing==3.1.2
python-gitlab==4.6.0
python-magic==0.4.27
python-multipart==0.0.9
python-slugify==8.0.4
pytz==2024.1
PyYAML==6.0.1
referencing==0.35.1
regex==2024.5.15
requests==2.32.3
requests-toolbelt==1.0.0
rich==13.7.1
rpds-py==0.20.0
ruff==0.5.7
safetensors==0.4.3
salesforce-lavis
scikit-image==0.24.0
scikit-learn==1.5.0
scipy==1.13.1
seaborn==0.13.2
semantic-version==2.10.0
sentencepiece==0.2.0
sentry-sdk==2.10.0
setproctitle==1.3.3
shellingham==1.5.4
smart-open==7.0.4
smmap==5.0.1
sniffio==1.3.1
spacy==3.7.5
spacy-legacy==3.0.12
spacy-loggers==1.0.5
SQLAlchemy==1.4.52
srsly==2.4.8
starlette==0.37.2
streamlit==1.37.1
sympy==1.12.1
tenacity==8.5.0
tensorboard==2.17.0
tensorboard-data-server==0.7.2
tensorflow==2.17.0
tensorflow-io-gcs-filesystem==0.37.1
termcolor==2.4.0
text-unidecode==1.3
thinc==8.2.5
threadpoolctl==3.5.0
tifffile==2024.6.18
timm
tokenizers
toml==0.10.2
tomlkit==0.12.0
torch==2.2.0
torchmetrics
torchstat
torchsummary==1.5.1
torchvision
tqdm==4.66.4
triton==2.2.0
typer==0.12.3
tzdata==2024.1
urllib3==2.2.2
uvicorn==0.30.5
virtualenv==20.26.3
wandb==0.17.5
wasabi==1.1.3
watchdog==4.0.2
weasel==0.4.1
webdataset==0.2.96
webencodings==0.5.1
websockets==12.0
Werkzeug==3.0.3
wrapt==1.16.0
xformer==1.0.1
xxhash==3.4.1
yarl==1.9.4
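As a quick post-install sanity check (a sketch, assuming the pins above have already been installed in the target environment), the versions of the libraries that app.py and model.py import directly can be confirmed with:

```python
# Verify that the core runtime dependencies import cleanly and report their versions.
import torch, transformers, diffusers, gradio, huggingface_hub

print("torch          ", torch.__version__)            # pinned to 2.2.0
print("transformers   ", transformers.__version__)     # pinned to 4.46.3
print("diffusers      ", diffusers.__version__)        # pinned to 0.31.0
print("gradio         ", gradio.__version__)           # pinned to 4.41.0
print("huggingface_hub", huggingface_hub.__version__)  # pinned to 0.23.4
```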