diff --git a/deepspeed/runtime/zero/linear.py b/deepspeed/runtime/zero/linear.py index 8c8db60768eb..0a86e3c389b1 100644 --- a/deepspeed/runtime/zero/linear.py +++ b/deepspeed/runtime/zero/linear.py @@ -86,13 +86,13 @@ def backward(ctx, grad_output): # improve efficiency. If you want to make your code simpler, you can # skip them. Returning gradients for inputs that don't require it is # not an error. + dim = grad_output.dim() if ctx.needs_input_grad[0]: #print(f"Computing grad input weight {weight.shape} grad_output {grad_output.shape}") grad_input = grad_output.matmul(weight) #print(f"Computed grad input {grad_input.shape}") if ctx.needs_input_grad[1]: #print("Computing grad weight") - dim = grad_output.dim() if dim > 2: grad_weight = grad_output.reshape(-1, grad_output.shape[-1]).t().matmul(input.reshape(-1, input.shape[-1]))