import torch
from torch.optim import Optimizer


class PerPointAdam(Optimizer):
    """Implements the Adam optimizer with per-point learning rates.

    Allows a unique learning rate for each point in the specified parameter
    tensors, which is useful for point cloud optimization.

    Args:
        params: Iterable of parameters to optimize or parameter groups
        lr (float, optional): Default learning rate (default: 1e-3)
        betas (tuple, optional): Coefficients for the moving averages (default: (0.9, 0.999))
        eps (float, optional): Term added for numerical stability (default: 1e-8)
        weight_decay (float, optional): Weight decay (L2 penalty) (default: 0)
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0):
        if not all(0.0 <= x for x in [lr, eps, weight_decay]):
            raise ValueError(f"Invalid learning parameters: lr={lr}, eps={eps}, weight_decay={weight_decay}")
        if not all(0.0 <= beta < 1.0 for beta in betas):
            raise ValueError(f"Invalid beta parameters: {betas}")
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, per_point_lr=None)
        super().__init__(params, defaults)

    def _adjust_per_point_lr(self, per_point_lr, grad, mask):
        """Adjusts per-point learning rates based on gradient magnitudes."""
        grad_magnitude = grad.norm(dim=-1)
        scaling_factor = torch.ones_like(grad_magnitude)
        # sigmoid of a non-zero magnitude lies in (0.5, 1), so masked points get a
        # scaling factor in (1.0, 1.01): points with larger gradients grow their rate slightly faster.
        grad_sigmoid = torch.sigmoid(grad_magnitude[mask])
        scaling_factor[mask] = 0.99 + (grad_sigmoid * 0.02)
        return per_point_lr * scaling_factor.unsqueeze(1)

    def step(self, closure=None):
        """Performs a single optimization step."""
        loss = closure() if closure is not None else None
        for group in self.param_groups:
            per_point_lr = group.get('per_point_lr')
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError('PerPointAdam does not support sparse gradients')

                # Initialize state if needed
                state = self.state[p]
                if len(state) == 0:
                    state['step'] = 0
                    state['exp_avg'] = torch.zeros_like(p.data)
                    state['exp_avg_sq'] = torch.zeros_like(p.data)

                # Get state values
                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                beta1, beta2 = group['betas']
                state['step'] += 1

                # Apply weight decay if specified
                if group['weight_decay'] != 0:
                    grad = grad.add(p.data, alpha=group['weight_decay'])

                # Per-point mask of points whose gradient is non-zero
                grad_norm = grad.norm(dim=-1)
                mask = grad_norm > 0

                # Update biased first and second moment estimates for masked points
                exp_avg[mask] = exp_avg[mask].mul(beta1).add(grad[mask], alpha=1 - beta1)
                exp_avg_sq[mask] = exp_avg_sq[mask].mul(beta2).addcmul(grad[mask], grad[mask], value=1 - beta2)

                # Compute bias corrections
                bias_correction1 = 1 - beta1 ** state['step']
                bias_correction2 = 1 - beta2 ** state['step']

                # Compute step size (bias_correction2 is folded into the step size)
                denom = exp_avg_sq.sqrt().add_(group['eps'])
                step_size = group['lr'] * (bias_correction2 ** 0.5 / bias_correction1)

                # Apply updates
                if per_point_lr is not None:
                    if not isinstance(per_point_lr, torch.Tensor):
                        raise TypeError("per_point_lr must be a torch.Tensor")
                    if per_point_lr.device != p.data.device:
                        raise ValueError("per_point_lr must be on the same device as the parameter")
                    expected_shape = p.data.shape[:1] + (1,) * (p.data.dim() - 1)
                    if per_point_lr.shape != expected_shape:
                        raise ValueError(
                            f"{group.get('name', 'param group')}: Invalid per_point_lr shape. "
                            f"Expected {expected_shape}, got {per_point_lr.shape}")
                    scaled_step_size = step_size * per_point_lr
                    p.data.add_(-scaled_step_size * (exp_avg / denom))
                    # Persist the adjusted rates so they carry over to the next step
                    per_point_lr = self._adjust_per_point_lr(per_point_lr, grad, mask)
                    group['per_point_lr'] = per_point_lr
                else:
                    p.data.addcdiv_(exp_avg, denom, value=-step_size)
        return loss
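

# Minimal usage sketch (an assumption added for illustration, not part of the
# original file). It shows the expected shapes: for a (N, 3) point parameter the
# per-point learning-rate tensor must have shape (N, 1) and is passed through the
# param group. The names `points`, `target`, and the toy L2 objective are invented
# for this example.
if __name__ == "__main__":
    points = torch.nn.Parameter(torch.randn(1024, 3))
    target = torch.zeros(1024, 3)
    per_point_lr = torch.ones(1024, 1)  # one rate multiplier per point

    optimizer = PerPointAdam(
        [{"params": [points], "per_point_lr": per_point_lr, "name": "points"}],
        lr=1e-2,
    )

    for _ in range(10):
        optimizer.zero_grad()
        loss = (points - target).pow(2).sum()
        loss.backward()
        optimizer.step()
    print(f"final loss: {loss.item():.4f}")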