We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 70caefe commit c18caf0Copy full SHA for c18caf0
deepspeed/runtime/zero/linear.py
@@ -86,13 +86,13 @@ def backward(ctx, grad_output):
86
# improve efficiency. If you want to make your code simpler, you can
87
# skip them. Returning gradients for inputs that don't require it is
88
# not an error.
89
+ dim = grad_output.dim()
90
if ctx.needs_input_grad[0]:
91
#print(f"Computing grad input weight {weight.shape} grad_output {grad_output.shape}")
92
grad_input = grad_output.matmul(weight)
93
#print(f"Computed grad input {grad_input.shape}")
94
if ctx.needs_input_grad[1]:
95
#print("Computing grad weight")
- dim = grad_output.dim()
96
if dim > 2:
97
grad_weight = grad_output.reshape(-1,
98
grad_output.shape[-1]).t().matmul(input.reshape(-1, input.shape[-1]))
0 commit comments