Skip to content

Commit c18caf0

Browse files
committed
Fix: UnboundLocalError for variable 'dim'
Signed-off-by: weeknan <[email protected]>
1 parent 70caefe commit c18caf0

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

deepspeed/runtime/zero/linear.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,13 +86,13 @@ def backward(ctx, grad_output):
8686
# improve efficiency. If you want to make your code simpler, you can
8787
# skip them. Returning gradients for inputs that don't require it is
8888
# not an error.
89+
dim = grad_output.dim()
8990
if ctx.needs_input_grad[0]:
9091
#print(f"Computing grad input weight {weight.shape} grad_output {grad_output.shape}")
9192
grad_input = grad_output.matmul(weight)
9293
#print(f"Computed grad input {grad_input.shape}")
9394
if ctx.needs_input_grad[1]:
9495
#print("Computing grad weight")
95-
dim = grad_output.dim()
9696
if dim > 2:
9797
grad_weight = grad_output.reshape(-1,
9898
grad_output.shape[-1]).t().matmul(input.reshape(-1, input.shape[-1]))

0 commit comments

Comments
 (0)