Spaces:
Runtime error
Runtime error
Commit
·
d65e5f6
1
Parent(s):
50f1efd
refactor(model): remove explicit device_type parameter from amp decorators
Browse files
llava/model/qlinear_te.py
CHANGED
@@ -98,7 +98,7 @@ class QLinearTE(nn.Linear):
|
|
98 |
|
99 |
class QuantLinearTE(Function):
|
100 |
@staticmethod
|
101 |
-
@amp.custom_fwd(cast_inputs=torch.bfloat16, device_type="cuda")
|
102 |
def forward(ctx, input, weight, bias, args, layer_name):
|
103 |
|
104 |
time_bench = os.getenv("TIME_BENCH")
|
@@ -149,7 +149,7 @@ class QuantLinearTE(Function):
|
|
149 |
return fc_output
|
150 |
|
151 |
@staticmethod
|
152 |
-
@amp.custom_bwd(device_type="cuda")
|
153 |
def backward(ctx, grad_output):
|
154 |
Qinput_t, Iscale, Qweight_t, Wscale, bias, args, layer_name = ctx.saved_tensors
|
155 |
|
|
|
98 |
|
99 |
class QuantLinearTE(Function):
|
100 |
@staticmethod
|
101 |
+
@amp.custom_fwd(cast_inputs=torch.bfloat16)
|
102 |
def forward(ctx, input, weight, bias, args, layer_name):
|
103 |
|
104 |
time_bench = os.getenv("TIME_BENCH")
|
|
|
149 |
return fc_output
|
150 |
|
151 |
@staticmethod
|
152 |
+
@amp.custom_bwd
|
153 |
def backward(ctx, grad_output):
|
154 |
Qinput_t, Iscale, Qweight_t, Wscale, bias, args, layer_name = ctx.saved_tensors
|
155 |
|