Spaces:
				
			
			
	
			
			
		Running
		
			on 
			
			Zero
	
	
	
			
			
	
	
	
	
		
		
		Running
		
			on 
			
			Zero
	| import torch | |
| import torch.nn as nn | |
| from .util import timestep_embedding | |
class PooledMLP(nn.Module):
    """Timestep-conditioned network built from 1x1 convolutions and ResBlocks.

    The input is embedded with a kernel-size-1 ``Conv1d``, a timestep
    embedding is added to every position along the last axis, the result is
    passed through a stack of ``ResBlock`` modules and finally projected to
    ``output_channels`` by another kernel-size-1 ``Conv1d``.

    The output projection is zero-initialised, so a freshly constructed
    model returns all zeros.

    Args:
        device: Device on which all parameters are created.
        input_channels: Channel count expected by the input convolution.
        output_channels: Channel count produced by the output convolution.
        hidden_size: Width of every hidden layer.
        resblocks: Number of ``ResBlock`` modules in the stack.
        pool_op: Pooling operation name forwarded to each ``ResBlock``.
    """

    def __init__(
        self,
        device: torch.device,
        *,
        input_channels: int = 3,
        output_channels: int = 6,
        hidden_size: int = 256,
        resblocks: int = 4,
        pool_op: str = "max",
    ):
        super().__init__()
        self.input_embed = nn.Conv1d(input_channels, hidden_size, kernel_size=1, device=device)
        self.time_embed = nn.Linear(hidden_size, hidden_size, device=device)
        self.sequence = nn.Sequential(
            *(ResBlock(hidden_size, pool_op, device=device) for _ in range(resblocks))
        )
        self.out = nn.Conv1d(hidden_size, output_channels, kernel_size=1, device=device)
        # Start the output projection at exactly zero.
        with torch.no_grad():
            for param in (self.out.bias, self.out.weight):
                param.zero_()

    def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """Run the network on ``x`` conditioned on timesteps ``t``.

        Args:
            x: Input accepted by the input ``Conv1d`` (channels on dim 1).
            t: Timesteps handed to ``timestep_embedding``; presumably one
                value per batch element -- confirm against that helper.

        Returns:
            The output of the final convolution.
        """
        features = self.input_embed(x)
        width = features.shape[1]
        t_features = self.time_embed(timestep_embedding(t, width))
        # The trailing axis lets the time embedding broadcast over the last dim.
        hidden = features + t_features[..., None]
        return self.out(self.sequence(hidden))
class ResBlock(nn.Module):
    """Residual block whose update is scaled by a gate from pooled features.

    ``body`` is two rounds of SiLU -> LayerNorm -> Linear applied along the
    channel axis. ``gate`` is a Linear followed by Tanh, applied to the input
    after it has been pooled over its last dimension.

    Args:
        hidden_size: Channel width of the input and of every layer.
        pool_op: ``"mean"`` or ``"max"``; selects the pooling used for the gate.
        device: Device on which all parameters are created.
    """

    def __init__(self, hidden_size: int, pool_op: str, device: torch.device):
        super().__init__()
        assert pool_op in ["mean", "max"]
        self.pool_op = pool_op

        # Two identical activation/norm/projection stages, built in order.
        stages = []
        for _ in range(2):
            stages.append(nn.SiLU())
            stages.append(nn.LayerNorm((hidden_size,), device=device))
            stages.append(nn.Linear(hidden_size, hidden_size, device=device))
        self.body = nn.Sequential(*stages)

        gate_proj = nn.Linear(hidden_size, hidden_size, device=device)
        self.gate = nn.Sequential(gate_proj, nn.Tanh())

    def forward(self, x: torch.Tensor):
        """Return ``x`` plus the gated output of ``body``.

        Args:
            x: 3-D tensor unpacked as ``(N, C, T)``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        batch, channels, length = x.shape
        # Move channels last and flatten so each position becomes one row.
        rows = x.permute(0, 2, 1).reshape(batch * length, channels)
        rows = self.body(rows)
        update = rows.reshape([batch, length, channels]).permute(0, 2, 1)
        # One gate vector per batch element, broadcast over the last axis.
        scale = self.gate(pool(self.pool_op, x))
        return x + update * scale[..., None]
def pool(op_name: str, x: torch.Tensor) -> torch.Tensor:
    """Reduce the last dimension of ``x`` with the named pooling operation.

    Args:
        op_name: ``"max"`` or ``"mean"``.
        x: Tensor to pool; the reduction runs over ``dim=-1``.

    Returns:
        ``x`` with its last dimension reduced away.

    Raises:
        ValueError: If ``op_name`` is not a supported operation.
    """
    if op_name == "max":
        # torch.max with ``dim`` returns (values, indices); keep the values.
        pooled, _ = torch.max(x, dim=-1)
    elif op_name == "mean":
        # torch.mean returns a single tensor, so it must not be unpacked.
        pooled = torch.mean(x, dim=-1)
    else:
        raise ValueError(f"unknown pool op: {op_name}")
    return pooled