PierrunoYT commited on
Commit
d65e5f6
·
1 Parent(s): 50f1efd

refactor(model): remove explicit device_type parameter from amp decorators

Browse files
Files changed (1) hide show
  1. llava/model/qlinear_te.py +2 -2
llava/model/qlinear_te.py CHANGED
@@ -98,7 +98,7 @@ class QLinearTE(nn.Linear):
98
 
99
  class QuantLinearTE(Function):
100
  @staticmethod
101
- @amp.custom_fwd(cast_inputs=torch.bfloat16, device_type='cuda')
102
  def forward(ctx, input, weight, bias, args, layer_name):
103
 
104
  time_bench = os.getenv("TIME_BENCH")
@@ -149,7 +149,7 @@ class QuantLinearTE(Function):
149
  return fc_output
150
 
151
  @staticmethod
152
- @amp.custom_bwd(device_type='cuda')
153
  def backward(ctx, grad_output):
154
  Qinput_t, Iscale, Qweight_t, Wscale, bias, args, layer_name = ctx.saved
155
 
 
98
 
99
  class QuantLinearTE(Function):
100
  @staticmethod
101
+ @amp.custom_fwd(cast_inputs=torch.bfloat16)
102
  def forward(ctx, input, weight, bias, args, layer_name):
103
 
104
  time_bench = os.getenv("TIME_BENCH")
 
149
  return fc_output
150
 
151
  @staticmethod
152
+ @amp.custom_bwd
153
  def backward(ctx, grad_output):
154
  Qinput_t, Iscale, Qweight_t, Wscale, bias, args, layer_name = ctx.saved
155