import torch
import torch.nn as nn

# Simple test of nn.Linear to determine its properties.
# The layer is built with 20 input features and 30 output features.
m = nn.Linear(20, 30)
# Batch of 128 random samples with 20 features each. Named `batch` rather
# than `input` so the builtin input() is not shadowed.
batch = torch.randn(128, 20)
output = m(batch)
# Only the last dimension is transformed, so the expected size is (128, 30).
print(output.size())
print(output)
import torch
import torch.nn as nn

# Simple test of nn.Linear to determine its properties.
# The layer is built with 20 input features and 30 output features.
m = nn.Linear(20, 30)
# Batch of 128 random samples with 20 features each. Named `batch` rather
# than `input` so the builtin input() is not shadowed.
batch = torch.randn(128, 20)
output = m(batch)
# Only the last dimension is transformed, so the expected size is (128, 30).
print(output.size())
print(output)