File size: 3,940 Bytes
8896a5f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
import numpy as np

import torch
import torch.nn as nn
import torch.functional as F

class LogisticActivation(nn.Module):
    r"""
    Implementation of Generalized Sigmoid
    Applies the element-wise function:
    .. math::
        \text{Sigmoid}(x) = \sigma(x) = \frac{1}{1 + \exp(-k(x-x_0))}
    Shape:
        - Input: :math:`(N, *)` where `*` means, any number of additional dimensions
        - Output: the input's shape with every size-1 dimension removed
          (the result is passed through ``squeeze()``)
    Parameters:
        - x0: The value of the sigmoid midpoint
        - k: The slope of the sigmoid - trainable only when ``train=True``
    Examples:
        >>> logAct = LogisticActivation(0, 5)
        >>> x = torch.randn(256)
        >>> x = logAct(x)
    """

    def __init__(self, x0=0, k=1, train=False):
        """
        Initialization
        INPUT:
            - x0: The value of the sigmoid midpoint
            - k: The slope of the sigmoid (stored as a 1-element Parameter)
            - train: Whether to make k a trainable parameter
            x0 and k are initialized to 0,1 respectively
            Behaves the same as torch.sigmoid by default
        """
        super(LogisticActivation, self).__init__()
        self.x0 = x0
        # autograd reads ``requires_grad``; the previous ``requiresGrad`` spelling
        # only created an unrelated attribute and left k always trainable.
        # k stays a registered Parameter either way, so state_dicts still load.
        self.k = nn.Parameter(torch.FloatTensor([float(k)]), requires_grad=bool(train))

    def forward(self, x):
        """
        Apply the generalized sigmoid to ``x`` elementwise, then squeeze size-1 dims.
        """
        # torch.sigmoid(z) == 1 / (1 + exp(-z)) but never overflows, so the gradient
        # stays finite for large |k * (x - x0)|. Its range is already [0, 1], so no
        # clamp is needed.
        return torch.sigmoid(self.k * (x - self.x0)).squeeze()

    def clip(self):
        """Project k back onto the non-negative half-line (in place)."""
        self.k.data.clamp_(min=0)

class ModelInteraction(nn.Module):
    """
    Score a pair of inputs by summarising a predicted contact map into one value.

    ``map_predict(z0, z1)`` embeds both inputs, asks ``contact`` for a contact map C,
    optionally multiplies C by a learnable positional weighting matrix (``use_W``),
    max-pools it, averages the pooled entries that exceed ``mean + gamma * variance``,
    and squashes that average through a steep logistic centred at 0.5.

    Args:
        embedding: module applied to each input first; ``None`` passes inputs through.
        contact: object exposing ``cmap(e0, e1)``, ``predict(B)`` and ``clip()``.
        use_cuda: retained for backward compatibility. Tensors created inside
            ``map_predict`` now follow the device of the contact map instead.
        pool_size: kernel size of the max-pool (``nn.MaxPool2d`` stride defaults to
            the kernel size, so the map is also downsampled).
        theta_init, lambda_init, gamma_init: initial values of theta, lambda, gamma.
        use_W: whether to apply the positional weighting matrix.
    """

    def __init__(self, embedding, contact, use_cuda, pool_size=9, theta_init=1, lambda_init=0, gamma_init=0, use_W=True):
        super(ModelInteraction, self).__init__()
        self.use_cuda = use_cuda
        self.use_W = use_W
        self.activation = LogisticActivation(x0=0.5, k=20)

        self.embedding = embedding
        self.contact = contact

        if self.use_W:
            self.theta = nn.Parameter(torch.FloatTensor([theta_init]))
            self.lambda_ = nn.Parameter(torch.FloatTensor([lambda_init]))

        self.maxPool = nn.MaxPool2d(pool_size, padding=pool_size // 2)
        self.gamma = nn.Parameter(torch.FloatTensor([gamma_init]))

        self.clip()

    def clip(self):
        """Project parameters back into their valid ranges (in place)."""
        self.contact.clip()

        if self.use_W:
            self.theta.data.clamp_(min=0, max=1)
            self.lambda_.data.clamp_(min=0)

        self.gamma.data.clamp_(min=0)

    def embed(self, x):
        """Apply ``embedding`` to ``x``, or return ``x`` unchanged if it is None."""
        if self.embedding is None:
            return x
        else:
            return self.embedding(x)

    def cpred(self, z0, z1):
        """Return the predicted contact map for the pair (z0, z1)."""
        e0 = self.embed(z0)
        e1 = self.embed(z1)
        B = self.contact.cmap(e0, e1)
        C = self.contact.predict(B)
        return C

    @staticmethod
    def _position_profile(n, device):
        """Return the float32 profile ``-((i - c) / c) ** 2`` for i = 1..n, c = (n + 1) / 2."""
        # Computed in float64 on the CPU, then cast and moved. This matches the
        # previous NumPy-based values exactly and avoids float64 on devices lacking it.
        idx = torch.arange(1, n + 1, dtype=torch.float64)
        centre = (n + 1) / 2
        profile = -(((idx - centre) / centre) ** 2)
        return profile.to(device=device, dtype=torch.float32)

    def map_predict(self, z0, z1):
        """
        Return ``(C, phat)``: the contact map and the summarised interaction score.
        """
        C = self.cpred(z0, z1)

        if self.use_W:
            # C is unpacked as rank 4; the last two dims are treated as the map axes
            # (presumably (batch, channel, N, M) — confirm against contact.predict).
            N, M = C.shape[2:]

            # exp(lambda * profile) is 1 where the profile is 0 (the centre) and
            # shrinks toward the edges when lambda > 0; theta blends W toward 1.
            x1 = torch.exp(self.lambda_ * self._position_profile(N, C.device))
            x2 = torch.exp(self.lambda_ * self._position_profile(M, C.device))

            W = x1.unsqueeze(1) * x2
            W = (1 - self.theta) * W + self.theta

            yhat = C * W
        else:
            yhat = C

        yhat = self.maxPool(yhat)

        # Mean of pooled predictions where p_ij > mu + gamma * var.
        # NOTE(review): torch.var is the (unbiased) variance, not the standard
        # deviation. It is kept as-is to preserve the behaviour of trained models.
        mu = torch.mean(yhat)
        var = torch.var(yhat)
        Q = torch.relu(yhat - mu - (self.gamma * var))
        # sign(Q) is 1 exactly where Q > 0; the +1 avoids dividing by zero.
        phat = torch.sum(Q) / (torch.sum(torch.sign(Q)) + 1)
        phat = self.activation(phat)
        return C, phat

    def predict(self, z0, z1):
        """Return only the interaction score from ``map_predict``."""
        _, phat = self.map_predict(z0, z1)
        return phat