File size: 1,519 Bytes
90a9dd3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
# Reference implementation of the conjugate gradient method (MATLAB code from Wikipedia):
# function x = conjgrad(A, b, x)
#     r = b - A * x;
#     p = r;
#     rsold = r' * r;
# 
#     for i = 1:length(b)
#         Ap = A * p;
#         alpha = rsold / (p' * Ap);
#         x = x + alpha * p;
#         r = r - alpha * Ap;
#         rsnew = r' * r;
#         if sqrt(rsnew) < 1e-10
#             break
#         end
#         p = r + (rsnew / rsold) * p;
#         rsold = rsnew;
#     end
# end

from typing import Callable, Optional

import torch


def CG(A: Callable,
       b: torch.Tensor,
       x: torch.Tensor,
       m: Optional[int]=5,
       eps: Optional[float]=1e-4,
       damping: float=0.0,
       use_mm: bool=False) -> torch.Tensor:
    """Approximately solve ``(A + damping * I) x = b`` with conjugate gradient.

    The operator is never materialised; only products ``A(v)`` are used.
    All inner products are taken over every element of the tensors involved,
    so a batched input is treated as a single flattened linear system.

    Args:
        A: Callable returning the product of the linear operator with a
            tensor. It is called with ``x`` reshaped to
            ``(x.shape[0], -1)`` and with search directions that have the
            shape of ``b - A(x)``.
            NOTE(review): the method presumes the operator is symmetric
            positive (semi-)definite; nothing here checks that.
        b: Right-hand side. It must broadcast against ``A(x)``.
        x: Initial guess. The result is returned in this tensor's shape.
        m: Maximum number of iterations. ``None`` falls back to
            ``b.numel()``, the bound used by the textbook algorithm.
        eps: Stop once the residual norm drops below this value. ``None``
            disables the tolerance test (iteration still stops on an exactly
            zero residual).
        damping: Coefficient of the identity added to the operator.
        use_mm: Compute inner products with ``torch.mm`` on flattened row
            vectors instead of an element-wise multiply-and-sum.

    Returns:
        The approximate solution, reshaped to the original shape of ``x``.

    Raises:
        AssertionError: If a NaN shows up in the squared residual norm or in
            the step size. The message is ``'NaN detected <k>'`` where ``k``
            identifies the failing check (1: initial residual, 2: step size,
            3: updated residual).
    """
    if use_mm:
        def dot(u, v):
            return torch.mm(u.reshape(1, -1), v.reshape(1, -1).T)
    else:
        def dot(u, v):
            return (u * v).flatten().sum()

    def damped(v):
        # The system actually being solved is (A + damping * I).
        return A(v) + damping * v

    def check_finite(value, tag):
        # Raised explicitly (rather than via ``assert``) so that the check
        # survives ``python -O`` and the exception carries its message.
        if torch.isnan(value).any():
            raise AssertionError(f'NaN detected {tag}')

    def converged(rs):
        # An exactly zero residual must stop the iteration regardless of
        # ``eps``: continuing would compute alpha = 0 / 0.
        if bool((rs == 0).all()):
            return True
        return eps is not None and bool((rs.sqrt().abs() < eps).all())

    orig_shape = x.shape
    x = x.reshape(x.shape[0], -1)

    # The initial residual has to use the same damped operator as the loop,
    # otherwise a non-zero initial guess converges to the wrong solution.
    r = b - damped(x)
    p = r.clone()

    rsold = dot(r, r)
    check_finite(rsold, 1)

    # The initial guess may already be good enough.
    if converged(rsold):
        return x.reshape(orig_shape)

    max_iters = b.numel() if m is None else m
    for _ in range(max_iters):
        Ap = damped(p)
        alpha = rsold / dot(p, Ap)
        check_finite(alpha, 2)

        x = x + alpha * p
        r = r - alpha * Ap

        rsnew = dot(r, r)
        check_finite(rsnew, 3)

        if converged(rsnew):
            break

        p = r + (rsnew / rsold) * p
        rsold = rsnew

    return x.reshape(orig_shape)