Spaces:
Running
on
Zero
Running
on
Zero
File size: 4,561 Bytes
9867d34 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 |
import math
import torch
import torch.nn as nn
from ...utils.helper import to_2tuple, to_1tuple
class PatchEmbed1D(nn.Module):
    """1D Audio to Patch Embedding.

    A convolution based approach to patchifying a 1D audio sequence w/ embedding projection:
    a ``Conv1d`` with ``kernel_size == stride == patch_size`` maps each non-overlapping
    window of ``patch_size`` steps to one ``embed_dim``-wide token.

    Based on the impl in https://github.com/google-research/vision_transformer
    Hacked together by / Copyright 2020 Ross Wightman

    Args:
        patch_size: number of input steps merged into one token (normalized with ``to_1tuple``).
        in_chans (int): channel count of the input (dim 1 of ``x``).
        embed_dim (int): output channel count / token width.
        norm_layer: optional factory called as ``norm_layer(embed_dim)``; ``nn.Identity`` when falsy.
        flatten (bool): if True, transpose the conv output from (B, C, N) to (B, N, C).
            Despite the name, no reshape beyond this transpose happens.
        bias (bool): whether the projection conv has a bias (zero-initialized when present).
        dtype, device: forwarded to ``nn.Conv1d``.
    """

    def __init__(
        self,
        patch_size=1,
        in_chans=768,
        embed_dim=768,
        norm_layer=None,
        flatten=True,
        bias=True,
        dtype=None,
        device=None,
    ):
        factory_kwargs = {"dtype": dtype, "device": device}
        super().__init__()
        patch_size = to_1tuple(patch_size)
        self.patch_size = patch_size
        self.flatten = flatten

        # kernel_size == stride: every output step covers exactly one non-overlapping patch.
        self.proj = nn.Conv1d(
            in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias, **factory_kwargs
        )
        # Xavier-init the kernel as if it were an (embed_dim, in_chans * patch_size) Linear weight.
        # `.view` shares storage with proj.weight, so the in-place init writes through.
        nn.init.xavier_uniform_(self.proj.weight.view(self.proj.weight.size(0), -1))
        if bias:
            nn.init.zeros_(self.proj.bias)
        # NOTE(review): the norm is applied after the optional transpose, so with flatten=False a
        # last-dim norm such as LayerNorm(embed_dim) would act on the sequence axis -- confirm
        # flatten=False is only used without a norm_layer.
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x):
        """Embed ``x`` of shape (B, in_chans, N) into patch tokens.

        Returns:
            (B, N // patch_size, embed_dim) when ``flatten`` is True, otherwise
            (B, embed_dim, N // patch_size).
        """
        # A stride == kernel conv would silently drop a trailing partial patch, so require
        # the sequence length to be an exact multiple of the patch size.
        assert x.shape[2] % self.patch_size[0] == 0, (
            f"The token number ({x.shape[2]}) of x must be divisible by "
            f"the patch_size ({self.patch_size[0]})."
        )
        x = self.proj(x)
        if self.flatten:
            x = x.transpose(1, 2)  # BCN -> BNC
        x = self.norm(x)
        return x
class ConditionProjection(nn.Module):
    """
    Two-layer MLP that projects condition (e.g. caption) embeddings to the model width.

    Computes ``linear_2(act_1(linear_1(caption)))`` over the last dimension of the input.
    Note: this module applies NO dropout / classifier-free-guidance condition dropping;
    any such masking is not done here.

    Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py

    Args:
        in_channels (int): size of the last dimension of the incoming embeddings.
        hidden_size (int): hidden and output feature size.
        act_layer: zero-argument factory for the activation module (e.g. ``nn.GELU``).
        dtype, device: forwarded to both ``nn.Linear`` layers.
    """

    def __init__(self, in_channels, hidden_size, act_layer, dtype=None, device=None):
        factory_kwargs = {'dtype': dtype, 'device': device}
        super().__init__()
        # Attribute names are part of the state_dict keys; keep them stable for checkpoints.
        self.linear_1 = nn.Linear(in_features=in_channels, out_features=hidden_size, bias=True, **factory_kwargs)
        self.act_1 = act_layer()
        self.linear_2 = nn.Linear(in_features=hidden_size, out_features=hidden_size, bias=True, **factory_kwargs)

    def forward(self, caption):
        """Project ``caption`` (last dim == in_channels) to a tensor whose last dim is hidden_size."""
        hidden_states = self.linear_1(caption)
        hidden_states = self.act_1(hidden_states)
        hidden_states = self.linear_2(hidden_states)
        return hidden_states
def timestep_embedding(t, dim, max_period=10000):
    """
    Create sinusoidal timestep embeddings.

    Args:
        t (torch.Tensor): a 1-D Tensor of N indices, one per batch element. These may be fractional.
            Tensors with other leading shapes (e.g. per-token timesteps of shape (B, L), or a 0-d
            scalar) are also accepted; the embedding is appended as a new last axis.
        dim (int): the dimension of the output.
        max_period (int): controls the minimum frequency of the embeddings.

    Returns:
        embedding (torch.Tensor): An (N, D) float32 Tensor of positional embeddings
            (``(*t.shape, D)`` in general), laid out as ``[cos(t * f_k), sin(t * f_k)]`` for
            ``f_k = max_period ** (-k / (D // 2))``, k = 0 .. D//2 - 1; when D is odd a zero
            column is appended.

    .. ref_link: https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
    """
    half = dim // 2
    # Build the frequency table directly on t's device (no per-call host->device copy).
    freqs = torch.exp(
        -math.log(max_period)
        * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device)
        / half
    )
    # `[..., None]` (instead of `[:, None]`) lets any leading shape broadcast against freqs;
    # for 1-D t this is exactly the original (N, half) product.
    args = t[..., None].float() * freqs
    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if dim % 2:
        # Odd dim: pad one zero feature so the last axis is exactly `dim` wide.
        embedding = torch.cat([embedding, torch.zeros_like(embedding[..., :1])], dim=-1)
    return embedding
class TimestepEmbedder(nn.Module):
    """
    Map scalar timesteps to learned vectors.

    Timesteps are first expanded with the fixed sinusoidal ``timestep_embedding``
    (``frequency_embedding_size`` features) and then fed through a two-layer MLP:
    Linear -> act_layer() -> Linear.

    Args:
        hidden_size (int): width of the MLP hidden layer (and of the output when ``out_size`` is None).
        act_layer: zero-argument factory for the activation module.
        frequency_embedding_size (int): width of the sinusoidal features fed to the MLP.
        max_period (int): forwarded to ``timestep_embedding``.
        out_size (int | None): output width; falls back to ``hidden_size``.
        dtype, device: forwarded to both ``nn.Linear`` layers.
    """

    def __init__(self,
                 hidden_size,
                 act_layer,
                 frequency_embedding_size=256,
                 max_period=10000,
                 out_size=None,
                 dtype=None,
                 device=None
                 ):
        tensor_opts = dict(dtype=dtype, device=device)
        super().__init__()
        self.frequency_embedding_size = frequency_embedding_size
        self.max_period = max_period
        width_out = hidden_size if out_size is None else out_size

        # Index layout (0: Linear, 1: activation, 2: Linear) defines the state_dict keys
        # ("mlp.0.*", "mlp.2.*") and is relied on below.
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True, **tensor_opts),
            act_layer(),
            nn.Linear(hidden_size, width_out, bias=True, **tensor_opts),
        )
        # Small-std normal init for both projection weights; biases keep nn.Linear's default init.
        for linear in (self.mlp[0], self.mlp[2]):
            nn.init.normal_(linear.weight, std=0.02)

    def forward(self, t):
        """Embed timesteps ``t`` with the sinusoidal features followed by the MLP."""
        param_dtype = self.mlp[0].weight.dtype
        sinusoid = timestep_embedding(t, self.frequency_embedding_size, self.max_period)
        # The sinusoidal features are float32; cast to the MLP's parameter dtype before projecting.
        return self.mlp(sinusoid.type(param_dtype))
|