
Commit 9988027

weeknan, tjruwase, and loadams authored and committed
Fix: UnboundLocalError for variable 'dim' (deepspeedai#7449)
## Fix `UnboundLocalError` in `ZeroLinear.backward()` when training only bias parameters, as mentioned in deepspeedai#7435

This PR addresses an issue in the `ZeroLinear.backward()` method, where the local variable `dim` could be referenced before assignment. This happens specifically when:

- Only the bias parameters are set to `requires_grad=True`, and
- The training setup uses **ZeRO Stage 3**, **AMP**, and **gradient checkpointing**.

### Problem

When only the bias requires gradients, the branch that sets `dim = grad_output.dim()` is skipped, but `dim` is still used later in the computation, leading to `UnboundLocalError: local variable 'dim' referenced before assignment`.

### Fix

Move the assignment `dim = grad_output.dim()` so that it runs unconditionally, ensuring `dim` is always defined before it is used in any branch of the gradient computation logic.

### Impact

This makes the backward pass more robust across different training setups.

Signed-off-by: weeknan <[email protected]>
Co-authored-by: Olatunji Ruwase <[email protected]>
Co-authored-by: Logan Adams <[email protected]>
Signed-off-by: qimcis <[email protected]>
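To make the failure mode concrete, here is a minimal, self-contained sketch of the pre-fix control flow. This is a simplified stand-in, not the actual DeepSpeed implementation; the function and placeholder computations are illustrative only:

```python
import torch


def backward_sketch(grad_output, needs_input_grad):
    """Simplified stand-in for the pre-fix ZeroLinear.backward() control flow."""
    grad_input = grad_weight = grad_bias = None
    if needs_input_grad[1]:
        # Before the fix, `dim` was only assigned inside the weight branch.
        dim = grad_output.dim()
        grad_weight = grad_output.t()  # placeholder for the real computation
    if needs_input_grad[2]:
        # With only the bias trainable, needs_input_grad[1] is False, so the
        # condition below raises UnboundLocalError: `dim` was never assigned.
        grad_bias = grad_output.sum([i for i in range(dim - 1)]) if dim > 2 else grad_output.sum(0)
    return grad_input, grad_weight, grad_bias


# Only the bias requires a gradient, which reproduces the error.
try:
    backward_sketch(torch.ones(4, 8), needs_input_grad=(False, False, True))
except UnboundLocalError as exc:
    print(exc)  # e.g. "local variable 'dim' referenced before assignment"
```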
1 parent d79704e · commit 9988027

File tree

1 file changed: +1 −1 lines changed


deepspeed/runtime/zero/linear.py

Lines changed: 1 addition & 1 deletion
@@ -86,13 +86,13 @@ def backward(ctx, grad_output):
         # improve efficiency. If you want to make your code simpler, you can
         # skip them. Returning gradients for inputs that don't require it is
         # not an error.
+        dim = grad_output.dim()
         if ctx.needs_input_grad[0]:
             #print(f"Computing grad input weight {weight.shape} grad_output {grad_output.shape}")
             grad_input = grad_output.matmul(weight)
             #print(f"Computed grad input {grad_input.shape}")
         if ctx.needs_input_grad[1]:
             #print("Computing grad weight")
-            dim = grad_output.dim()
             if dim > 2:
                 grad_weight = grad_output.reshape(-1,
                                                   grad_output.shape[-1]).t().matmul(input.reshape(-1, input.shape[-1]))
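For reference, the training configuration described in the issue looks roughly like the snippet below. It uses stock `torch.nn.Linear` purely to illustrate the only-bias-trainable setup; the original error additionally requires ZeRO Stage 3, AMP, and gradient checkpointing, which are omitted here:

```python
import torch

# Illustrative setup: freeze the weight so only the bias trains. In the
# DeepSpeed backward above, this makes ctx.needs_input_grad[1] False,
# which is exactly the case that used to skip the `dim` assignment.
linear = torch.nn.Linear(8, 4)
linear.weight.requires_grad_(False)
linear.bias.requires_grad_(True)

out = linear(torch.randn(2, 3, 8))  # 3-D input, so grad_output.dim() > 2
out.sum().backward()
print(linear.bias.grad.shape)  # torch.Size([4])
```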
