import torch
import torch.nn as nn
from torch.utils.cpp_extension import load
import os
import time
import random
from torch.autograd import Function
from functools import partial
import warnings
# JIT-compile the CUDA extension on first import; the kernels live in
# extension/tdp.cu with the bindings in extension/torch_extension.cpp.
curr_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "extension")
src_files = ['tdp.cu', 'torch_extension.cpp']
src_files = [os.path.join(curr_path, file) for file in src_files]
tdp = load('tdp', src_files, verbose = True)
# import tdp  # alternative: import a pre-built tdp package instead of JIT compilation
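
# Time-dependent parameters (TDP): for each scalar t in `times`, blend two
# parameter tensors element-wise through a learned sigmoid gate,
#     result[t] = s * param0 + (1 - s) * param1,  where s = sigmoid(t * weight + bias).
# tdp_torch below is the pure-PyTorch reference; the tdp extension provides
# fused fp16/fp32 CUDA kernels for the forward and backward passes.
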
def exported_tdp(param0, param1, weight, bias, times, custom = True):
    original_shape = param0.shape
    param0 = param0.reshape(-1)
    param1 = param1.reshape(-1)
    weight = weight.reshape(-1)
    bias = bias.reshape(-1)
    if custom and param0.shape[0] % 2 == 0:
        result = TDP.apply(param0, param1, weight, bias, times)
    else:
        warnings.warn(f'Using slower tdp_torch implementation for a tensor with shape {tuple(original_shape)}')
        result = tdp_torch(param0, param1, weight, bias, times)
    result = result.reshape(times.shape[0], *original_shape)
    return result
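
# Illustrative usage (assumes a CUDA device is available): blending a 64x64
# parameter matrix over a batch of 8 timesteps adds a leading batch dimension.
#     w0 = torch.randn(64, 64).cuda()
#     w1 = torch.randn(64, 64).cuda()
#     gate_w = torch.randn(64, 64).cuda()
#     gate_b = torch.randn(64, 64).cuda()
#     t = torch.rand(8).cuda()
#     out = exported_tdp(w0, w1, gate_w, gate_b, t)  # shape (8, 64, 64)
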
class TDP(Function):
    @staticmethod
    def forward(ctx, param0, param1, weight, bias, times):
        assert param0.shape[0] % 2 == 0
        param0 = param0.contiguous()
        param1 = param1.contiguous()
        weight = weight.contiguous()
        bias = bias.contiguous()
        times = times.contiguous()
        assert param0.shape[0] == param1.shape[0] and param0.shape[0] == weight.shape[0] and param0.shape[0] == bias.shape[0]
        assert param0.dim() == 1 and param1.dim() == 1 and weight.dim() == 1 and bias.dim() == 1 and times.dim() == 1
        ctx.save_for_backward(param0, param1, weight, bias, times)
        return tdp_cuda(param0, param1, weight, bias, times)

    @staticmethod
    def backward(ctx, g_result):
        g_result = g_result.contiguous()
        param0, param1, weight, bias, times = ctx.saved_tensors
        g_param0, g_param1, g_weight, g_bias = backward_tdp_cuda(param0, param1, weight, bias, times, g_result)
        return g_param0, g_param1, g_weight, g_bias, None
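
# Hand-derived gradients of result = s * param0 + (1 - s) * param1 with
# s = sigmoid(a) and a = times * weight + bias; each parameter is broadcast
# over the batch axis, so its gradient sums over that axis:
#     dresult/dparam0 = s,  dresult/dparam1 = 1 - s
#     dresult/da      = (param0 - param1) * s * (1 - s)
#     dresult/dweight = dresult/da * times,  dresult/dbias = dresult/da
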
def backward_tdp_torch(param0, param1, weight, bias, times, g_result):
    param0 = param0[None]
    param1 = param1[None]
    weight = weight[None]
    bias = bias[None]
    a = times[:, None] * weight + bias
    s = torch.sigmoid(a)
    g_param0 = (s * g_result).sum(0)
    g_param1 = ((1 - s) * g_result).sum(0)
    g_s = (param0 - param1) * g_result
    g_a = g_s * s * (1 - s)
    g_weight = (g_a * times[:, None]).sum(0)
    g_bias = g_a.sum(0)
    return g_param0, g_param1, g_weight, g_bias
def backward_tdp_cuda(param0, param1, weight, bias, times, g_result):
    g_param0 = torch.empty_like(param0)
    g_param1 = torch.empty_like(param0)
    g_weight = torch.empty_like(param0)
    g_bias = torch.empty_like(param0)
    if param0.dtype == torch.half:
        tdp.backward_tdp_fp16(param0, param1, weight, bias, times, g_result, g_param0, g_param1, g_weight, g_bias)
    elif param0.dtype == torch.float:
        tdp.backward_tdp_fp32(param0, param1, weight, bias, times, g_result, g_param0, g_param1, g_weight, g_bias)
    else:
        raise NotImplementedError
    return g_param0, g_param1, g_weight, g_bias
def tdp_torch(param0, param1, weight, bias, times):
    a = torch.addcmul(bias[None], times[:, None], weight[None])
    s = torch.sigmoid(a)
    result = torch.addcmul(param1[None], s, param0[None] - param1[None])
    return result
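
# tdp_cuda allocates the (batch, num_params) output and dispatches to the
# fp16 or fp32 kernel; any other dtype raises NotImplementedError.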
def tdp_cuda(param0, param1, weight, bias, times):
    result = torch.empty(times.shape[0], param0.shape[0], dtype = param0.dtype, device = param0.device)
    if param0.dtype == torch.half:
        tdp.tdp_fp16(param0, param1, weight, bias, times, result)
    elif param0.dtype == torch.float:
        tdp.tdp_fp32(param0, param1, weight, bias, times, result)
    else:
        raise NotImplementedError
    return result
def corrcoef(x, y):
    return torch.corrcoef(torch.stack([x.reshape(-1).float(), y.reshape(-1).float()], dim = 0))[0, 1]
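
# ***** correctness tests: compare the CUDA kernels against tdp_torch *****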
def tdp_cuda_unit_test():
    print("***** tdp_cuda_unit_test *****")
    batch_size = random.randrange(1, 128)
    num_params = random.randrange(1, 1000000) * 2
    print("batch_size", batch_size, "num_params", num_params)
    param0 = torch.randn(num_params).cuda()
    param1 = torch.randn(num_params).cuda()
    weight = torch.randn(num_params).cuda()
    bias = torch.randn(num_params).cuda()
    times = torch.rand(batch_size).cuda()
    ref = tdp_torch(param0, param1, weight, bias, times)
    out = tdp_cuda(param0, param1, weight, bias, times)
    print(corrcoef(ref, out), (ref - out).abs().max())
    out = tdp_cuda(param0.half(), param1.half(), weight.half(), bias.half(), times.half()).float()
    print(corrcoef(ref, out), (ref - out).abs().max())
def backward_tdp_cuda_unit_test():
    print("***** backward_tdp_cuda_unit_test *****")
    batch_size = random.randrange(1, 128)
    num_params = random.randrange(1, 100000) * 2
    print("batch_size", batch_size, "num_params", num_params)
    param0 = torch.randn(num_params).cuda()
    param1 = torch.randn(num_params).cuda()
    weight = torch.randn(num_params).cuda()
    bias = torch.randn(num_params).cuda()
    times = torch.rand(batch_size).cuda()
    g_result = torch.randn(batch_size, num_params).cuda()
    refs = backward_tdp_torch(param0, param1, weight, bias, times, g_result)
    outs = backward_tdp_cuda(param0, param1, weight, bias, times, g_result)
    for r, o in zip(refs, outs):
        print(corrcoef(r, o), (r - o).abs().max())
    outs = backward_tdp_cuda(param0.half(), param1.half(), weight.half(), bias.half(), times.half(), g_result.half())
    for r, o in zip(refs, outs):
        print(corrcoef(r, o), (r - o).abs().max())
def autograd_unit_test():
    print("***** autograd_unit_test *****")
    batch_size = random.randrange(1, 128)
    num_params = random.randrange(1, 100000) * 2
    print("batch_size", batch_size, "num_params", num_params)
    def get_outputs(fn):
        torch.manual_seed(1)
        # .cuda() turns these into non-leaf tensors, so retain_grad() below is
        # required for .grad to be populated after backward().
        param0 = torch.randn(num_params, requires_grad = True).cuda()
        param1 = torch.randn(num_params, requires_grad = True).cuda()
        weight = torch.randn(num_params, requires_grad = True).cuda()
        bias = torch.randn(num_params, requires_grad = True).cuda()
        times = torch.rand(batch_size).cuda()
        out = fn(param0, param1, weight, bias, times)
        loss = ((out - 1.5) ** 2).mean()
        param0.retain_grad()
        param1.retain_grad()
        weight.retain_grad()
        bias.retain_grad()
        loss.backward()
        return out, param0.grad, param1.grad, weight.grad, bias.grad
    refs = get_outputs(tdp_torch)
    outs = get_outputs(TDP.apply)
    for r, o in zip(refs, outs):
        print(corrcoef(r, o), (r - o).abs().max())
def exported_tdp_unit_test():
    print("***** exported_tdp_unit_test *****")
    batch_size = random.randrange(1, 128)
    num_params = random.randrange(1, 100000) * 2
    print("batch_size", batch_size, "num_params", num_params)
    def get_outputs(fn):
        torch.manual_seed(1)
        param0 = torch.randn(num_params, requires_grad = True).cuda()
        param1 = torch.randn(num_params, requires_grad = True).cuda()
        weight = torch.randn(num_params, requires_grad = True).cuda()
        bias = torch.randn(num_params, requires_grad = True).cuda()
        times = torch.rand(batch_size).cuda()
        out = fn(param0, param1, weight, bias, times)
        loss = ((out - 1.5) ** 2).mean()
        param0.retain_grad()
        param1.retain_grad()
        weight.retain_grad()
        bias.retain_grad()
        loss.backward()
        return out, param0.grad, param1.grad, weight.grad, bias.grad
    refs = get_outputs(partial(exported_tdp, custom = False))
    outs = get_outputs(partial(exported_tdp, custom = True))
    for r, o in zip(refs, outs):
        print(corrcoef(r, o), (r - o).abs().max())
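
# ***** micro-benchmarks: 10 warm-up runs, then 100 timed runs bracketed by cuda.synchronize *****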
def tdp_cuda_profile():
    print("***** tdp_cuda_profile *****")
    def profiler(fn, args):
        for _ in range(10):
            fn(*args)
        torch.cuda.synchronize()
        t0 = time.time()
        for _ in range(100):
            fn(*args)
        torch.cuda.synchronize()
        t1 = time.time()
        return t1 - t0
    batch_size = 16
    num_params = 1024 * 1024
    print("batch_size", batch_size, "num_params", num_params)
    param0 = torch.randn(num_params).cuda()
    param1 = torch.randn(num_params).cuda()
    weight = torch.randn(num_params).cuda()
    bias = torch.randn(num_params).cuda()
    times = torch.rand(batch_size).cuda()
    print("ref", profiler(tdp_torch, (param0, param1, weight, bias, times)))
    print("cuda", profiler(tdp_cuda, (param0, param1, weight, bias, times)))
    print("ref", profiler(tdp_torch, (param0.half(), param1.half(), weight.half(), bias.half(), times.half())))
    print("cuda", profiler(tdp_cuda, (param0.half(), param1.half(), weight.half(), bias.half(), times.half())))
def backward_tdp_cuda_profile():
    print("***** backward_tdp_cuda_profile *****")
    def profiler(fn, args):
        for _ in range(10):
            fn(*args)
        torch.cuda.synchronize()
        t0 = time.time()
        for _ in range(100):
            fn(*args)
        torch.cuda.synchronize()
        t1 = time.time()
        return t1 - t0
    batch_size = 16
    num_params = 1024 * 1024
    print("batch_size", batch_size, "num_params", num_params)
    param0 = torch.randn(num_params).cuda()
    param1 = torch.randn(num_params).cuda()
    weight = torch.randn(num_params).cuda()
    bias = torch.randn(num_params).cuda()
    times = torch.rand(batch_size).cuda()
    g_result = torch.randn(batch_size, num_params).cuda()
    print("ref", profiler(backward_tdp_torch, (param0, param1, weight, bias, times, g_result)))
    print("cuda", profiler(backward_tdp_cuda, (param0, param1, weight, bias, times, g_result)))
    print("ref", profiler(backward_tdp_torch, (param0.half(), param1.half(), weight.half(), bias.half(), times.half(), g_result.half())))
    print("cuda", profiler(backward_tdp_cuda, (param0.half(), param1.half(), weight.half(), bias.half(), times.half(), g_result.half())))
def autograd_profile():
    print("***** autograd_profile *****")
    def profiler(fn, args):
        for _ in range(10):
            fn(*args).mean().backward()
        torch.cuda.synchronize()
        t0 = time.time()
        for _ in range(100):
            fn(*args).mean().backward()
        torch.cuda.synchronize()
        t1 = time.time()
        return t1 - t0
    batch_size = 16
    num_params = 1024 * 1024
    print("batch_size", batch_size, "num_params", num_params)
    param0 = nn.Parameter(torch.randn(num_params)).cuda()
    param1 = nn.Parameter(torch.randn(num_params)).cuda()
    weight = nn.Parameter(torch.randn(num_params)).cuda()
    bias = nn.Parameter(torch.randn(num_params)).cuda()
    times = torch.rand(batch_size).cuda()
    print("ref", profiler(tdp_torch, (param0, param1, weight, bias, times)))
    print("cuda", profiler(TDP.apply, (param0, param1, weight, bias, times)))
    print("ref", profiler(tdp_torch, (param0.half(), param1.half(), weight.half(), bias.half(), times.half())))
    print("cuda", profiler(TDP.apply, (param0.half(), param1.half(), weight.half(), bias.half(), times.half())))
if __name__ == "__main__":
    tdp_cuda_unit_test()
    backward_tdp_cuda_unit_test()
    autograd_unit_test()
    exported_tdp_unit_test()
    tdp_cuda_profile()
    backward_tdp_cuda_profile()
    autograd_profile()