Spaces:
Running
on
Zero
Running
on
Zero
File size: 1,519 Bytes
90a9dd3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 |
# Reference implementation from Wikipedia (MATLAB code):
# function x = conjgrad(A, b, x)
# r = b - A * x;
# p = r;
# rsold = r' * r;
#
# for i = 1:length(b)
# Ap = A * p;
# alpha = rsold / (p' * Ap);
# x = x + alpha * p;
# r = r - alpha * Ap;
# rsnew = r' * r;
# if sqrt(rsnew) < 1e-10
# break
# end
# p = r + (rsnew / rsold) * p;
# rsold = rsnew;
# end
# end
from typing import Callable, Optional
import torch
def _raise_if_nan(value: torch.Tensor, stage: int) -> None:
    """Raise ``AssertionError`` if *value* contains a NaN; *stage* tags the message."""
    if torch.isnan(value).any():
        raise AssertionError(f'NaN detected {stage}')


def CG(A: Callable,
       b: torch.Tensor,
       x: torch.Tensor,
       m: Optional[int]=5,
       eps: Optional[float]=1e-4,
       damping: float=0.0,
       use_mm: bool=False) -> torch.Tensor:
    """Run conjugate-gradient iterations for ``A(x) + damping * x = b``.

    Args:
        A: Callable applying the linear operator. It is called on ``x``
            reshaped to ``(x.shape[0], -1)`` and on the search direction,
            which has the shape of ``b - A(x)``.
        b: Right-hand side; must be compatible with ``A(x)`` in ``b - A(x)``.
        x: Initial guess. It is not modified in place.
        m: Maximum number of iterations. ``None`` uses ``x.numel()``
            iterations, mirroring the ``1:length(b)`` loop of the reference
            MATLAB code.
        eps: Stop once the square root of the residual inner product drops
            below this value. ``None`` disables this test.
        damping: Multiple of the search direction added to every operator
            application inside the loop.
        use_mm: Compute inner products with ``torch.mm`` on flattened
            vectors (yielding a ``(1, 1)`` tensor) instead of an
            elementwise product followed by a sum (yielding a scalar).

    Returns:
        The final iterate, reshaped to the original shape of ``x``.

    Raises:
        AssertionError: If a NaN appears in the residual inner product or
            in the step size.

    NOTE(review): the initial residual is ``b - A(x)`` without the damping
    term, while the loop applies ``A + damping * I``. For a non-zero initial
    guess this is equivalent to solving
    ``(A + damping * I) x = b + damping * x0`` -- confirm this is intended.
    """
    if use_mm:
        def dot(u: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
            return torch.mm(u.reshape(1, -1), v.reshape(-1, 1))
    else:
        def dot(u: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
            return (u * v).flatten().sum()

    orig_shape = x.shape
    # reshape (rather than view) so non-contiguous initial guesses also work.
    x = x.reshape(x.shape[0], -1)
    max_iters = x.numel() if m is None else m

    r = b - A(x)
    p = r.clone()
    rsold = dot(r, r)
    _raise_if_nan(rsold, 1)
    # A zero residual means x already solves the system; iterating would
    # compute alpha = 0 / 0 = NaN and trip the NaN check below.
    if rsold.item() == 0:
        return x.reshape(orig_shape)

    for _ in range(max_iters):
        Ap = A(p) + damping * p
        alpha = rsold / dot(p, Ap)
        _raise_if_nan(alpha, 2)
        x = x + alpha * p
        r = r - alpha * Ap
        rsnew = dot(r, r)
        _raise_if_nan(rsnew, 3)
        # Also stop on an exactly-zero residual so that disabling eps
        # cannot lead to a 0 / 0 step on the next iteration.
        if rsnew.item() == 0 or (eps is not None and rsnew.sqrt().abs() < eps):
            break
        p = r + (rsnew / rsold) * p
        rsold = rsnew
    return x.reshape(orig_shape)
|