File size: 6,244 Bytes
19a3898 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 |
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import torch
from deepspeed.accelerator import get_accelerator
from ..features.cuda_graph import CUDAGraph
class DSVAE(CUDAGraph, torch.nn.Module):
    """Inference wrapper around a VAE with optional graph capture/replay.

    Three independent graphs are managed, one per entry point: ``decode``,
    ``encode`` and ``forward``. For each entry point, the first call (when
    ``enable_cuda_graph`` is set) warms the wrapped module up, captures a
    graph, and keeps the arguments of that call as the static buffers; every
    call copies its tensor arguments into those buffers and replays the graph.

    NOTE(review): the object returned by a replayed call is the same static
    output object every time; callers that need to keep a result across calls
    presumably have to clone it -- confirm against the accelerator's graph
    semantics.
    """

    def __init__(self, vae, enable_cuda_graph=True):
        """Wrap ``vae`` and freeze its parameters.

        Args:
            vae: module exposing ``config``, ``device``, ``dtype``,
                ``requires_grad_``, ``encode``, ``decode`` and ``__call__``.
            enable_cuda_graph: forwarded to the ``CUDAGraph`` base
                initializer; the entry points read it back as
                ``self.enable_cuda_graph``.
        """
        super().__init__(enable_cuda_graph=enable_cuda_graph)
        self.vae = vae
        self.config = vae.config
        self.device = self.vae.device
        self.dtype = self.vae.dtype
        self.vae.requires_grad_(requires_grad=False)
        # One "captured" flag per graph; each is flipped by the matching
        # _create_cuda_graph* method.
        self.decoder_cuda_graph_created = False
        self.encoder_cuda_graph_created = False
        self.all_cuda_graph_created = False

    @staticmethod
    def _refresh_static_buffers(static_inputs, static_kwargs, inputs, kwargs):
        """Copy every tensor in ``inputs``/``kwargs`` into its static twin.

        Non-tensor arguments are skipped: only tensor arguments are copied
        into the static buffers.
        """
        for position, value in enumerate(inputs):
            if torch.is_tensor(value):
                static_inputs[position].copy_(value)
        for key, value in kwargs.items():
            if torch.is_tensor(value):
                static_kwargs[key].copy_(value)

    @staticmethod
    def _warmup_on_side_stream(fn, inputs, kwargs, iterations=3):
        """Run ``fn`` ``iterations`` times on a side stream before capture."""
        # warmup to create the workspace and cublas handle
        side_stream = torch.cuda.Stream()
        side_stream.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(side_stream):
            for _ in range(iterations):
                fn(*inputs, **kwargs)
        torch.cuda.current_stream().wait_stream(side_stream)

    # ------------------------------------------------------------------ #
    # decoder
    # ------------------------------------------------------------------ #
    def _graph_replay_decoder(self, *inputs, **kwargs):
        """Load the arguments into the decoder buffers and replay its graph."""
        self._refresh_static_buffers(self.static_decoder_inputs, self.static_decoder_kwargs, inputs, kwargs)
        get_accelerator().replay_graph(self._decoder_cuda_graph)
        return self.static_decoder_output

    def _decode(self, x, return_dict=True, generator=None):
        """Call the wrapped VAE's ``decode``.

        ``generator`` is accepted but not forwarded to the wrapped module.
        """
        return self.vae.decode(x, return_dict=return_dict)

    def _create_cuda_graph_decoder(self, *inputs, **kwargs):
        """Warm up ``_decode`` and capture it into the decoder graph."""
        self._warmup_on_side_stream(self._decode, inputs, kwargs)

        # create cuda_graph and assign static_inputs and static_outputs
        self._decoder_cuda_graph = get_accelerator().create_graph()
        # The arguments of this first call become the static buffers.
        self.static_decoder_inputs = inputs
        self.static_decoder_kwargs = kwargs

        with get_accelerator().capture_to_graph(self._decoder_cuda_graph):
            self.static_decoder_output = self._decode(*self.static_decoder_inputs, **self.static_decoder_kwargs)

        self.decoder_cuda_graph_created = True

    def decode(self, *inputs, **kwargs):
        """Decode, through the captured graph when graphs are enabled."""
        if not self.enable_cuda_graph:
            return self._decode(*inputs, **kwargs)
        if not self.decoder_cuda_graph_created:
            self._create_cuda_graph_decoder(*inputs, **kwargs)
        return self._graph_replay_decoder(*inputs, **kwargs)

    # ------------------------------------------------------------------ #
    # encoder
    # ------------------------------------------------------------------ #
    def _graph_replay_encoder(self, *inputs, **kwargs):
        """Load the arguments into the encoder buffers and replay its graph."""
        self._refresh_static_buffers(self.static_encoder_inputs, self.static_encoder_kwargs, inputs, kwargs)
        get_accelerator().replay_graph(self._encoder_cuda_graph)
        return self.static_encoder_output

    def _encode(self, x, return_dict=True):
        """Call the wrapped VAE's ``encode``."""
        return self.vae.encode(x, return_dict=return_dict)

    def _create_cuda_graph_encoder(self, *inputs, **kwargs):
        """Warm up ``_encode`` and capture it into the encoder graph."""
        self._warmup_on_side_stream(self._encode, inputs, kwargs)

        # create cuda_graph and assign static_inputs and static_outputs
        self._encoder_cuda_graph = get_accelerator().create_graph()
        # The arguments of this first call become the static buffers.
        self.static_encoder_inputs = inputs
        self.static_encoder_kwargs = kwargs

        with get_accelerator().capture_to_graph(self._encoder_cuda_graph):
            self.static_encoder_output = self._encode(*self.static_encoder_inputs, **self.static_encoder_kwargs)

        self.encoder_cuda_graph_created = True

    def encode(self, *inputs, **kwargs):
        """Encode, through the captured graph when graphs are enabled."""
        if not self.enable_cuda_graph:
            return self._encode(*inputs, **kwargs)
        if not self.encoder_cuda_graph_created:
            self._create_cuda_graph_encoder(*inputs, **kwargs)
        return self._graph_replay_encoder(*inputs, **kwargs)

    # ------------------------------------------------------------------ #
    # full forward pass
    # ------------------------------------------------------------------ #
    def _graph_replay(self, *inputs, **kwargs):
        """Load the arguments into the forward buffers and replay its graph."""
        self._refresh_static_buffers(self.static_inputs, self.static_kwargs, inputs, kwargs)
        get_accelerator().replay_graph(self._all_cuda_graph)
        return self.static_output

    def forward(self, *inputs, **kwargs):
        """Run the full VAE, through the captured graph when enabled."""
        if not self.enable_cuda_graph:
            return self._forward(*inputs, **kwargs)
        # The flag maintained by __init__ and _create_cuda_graph is
        # ``all_cuda_graph_created``. Nothing in this class assigns
        # ``cuda_graph_created``, so testing that name would either raise
        # AttributeError or re-capture the graph on every call.
        if not self.all_cuda_graph_created:
            self._create_cuda_graph(*inputs, **kwargs)
        return self._graph_replay(*inputs, **kwargs)

    def _create_cuda_graph(self, *inputs, **kwargs):
        """Warm up ``_forward`` and capture it into the full-model graph."""
        self._warmup_on_side_stream(self._forward, inputs, kwargs)

        # create cuda_graph and assign static_inputs and static_outputs
        self._all_cuda_graph = get_accelerator().create_graph()
        # The arguments of this first call become the static buffers.
        self.static_inputs = inputs
        self.static_kwargs = kwargs

        with get_accelerator().capture_to_graph(self._all_cuda_graph):
            self.static_output = self._forward(*self.static_inputs, **self.static_kwargs)

        self.all_cuda_graph_created = True

    def _forward(self, sample, timestamp, encoder_hidden_states, return_dict=True):
        """Call the wrapped VAE directly, passing all arguments positionally."""
        return self.vae(sample, timestamp, encoder_hidden_states, return_dict)
|