
Commit 2790220

[TiledMLP]: fix for bs>1 (#7412)
It turns out my TiledMLP was only working correctly for batch_size=1. This fixes it to work with any batch size. Thanks to @winglian for detecting the problem and sending me an easy repro.

Signed-off-by: Stas Bekman <[email protected]>
1 parent 15f054d commit 2790220
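
For context, here is a minimal standalone sketch (not DeepSpeed code; shapes are made up) of why the old flattened indexing only matched `torch.chunk(..., dim=1)` shards when bs == 1, and why narrowing along the sequence dimension fixes it:

# Minimal sketch, assuming a (bs, seqlen, hidden) tensor chunked along the sequence dim.
import torch

bs, seqlen, hidden = 2, 8, 4
shards = 2
grad = torch.arange(bs * seqlen * hidden, dtype=torch.float32).reshape(bs, seqlen, hidden)
seq_shards = torch.chunk(grad, chunks=shards, dim=1)

# Old indexing: flatten and step by numel(). For bs > 1 the sequence shards of
# different batch rows interleave in memory, so this picks the wrong elements.
old_step = seq_shards[0].numel()
old_shard_1 = grad.view(-1).narrow(0, 1 * old_step, old_step).view_as(seq_shards[1])

# Fixed indexing: narrow along dim=1 (the sequence dim) by shape[1] steps.
new_step = seq_shards[0].shape[1]
new_shard_1 = grad.narrow(1, 1 * new_step, new_step)

print(torch.equal(new_shard_1, seq_shards[1]))  # True
print(torch.equal(old_shard_1, seq_shards[1]))  # False for bs > 1 (True only when bs == 1)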

2 files changed: +53 -42 lines changed

deepspeed/runtime/sequence_parallel/ulysses_sp.py

Lines changed: 10 additions & 11 deletions

@@ -706,7 +706,7 @@ def backward(ctx, *grads) -> torch.Tensor:
     }

     # if seqlen is not exactly divisible by shards the last step will be shorter than shard_step
-    shard_step = kwargs_to_shard_shards[grad_requiring_tensor_key][0].numel()
+    shard_step = kwargs_to_shard_shards[grad_requiring_tensor_key][0].shape[1]
     for i in range(shards):

         # when fn involves one or more model weights deepspeed will normally push a grad to
@@ -731,8 +731,8 @@ def backward(ctx, *grads) -> torch.Tensor:
         shard_offset = i * shard_step
         # this will enable gradual population of the pre-allocated
         # `grad_requiring_tensor_shard.grad` during `torch.autograd.backward` calls
-        grad_requiring_tensor_shard.grad = (grad_requiring_tensor_grad.view(-1).narrow(
-            0, shard_offset, grad_requiring_tensor_shard.numel()).view_as(grad_requiring_tensor_shard))
+        grad_requiring_tensor_shard.grad = (grad_requiring_tensor_grad.narrow(
+            1, shard_offset, shard_step).view_as(grad_requiring_tensor_shard))

         with torch.enable_grad():
             output = fn(**kwargs_to_shard_shard, **kwargs_to_pass)
@@ -741,8 +741,8 @@ def backward(ctx, *grads) -> torch.Tensor:
             # loss use-case
             torch.autograd.backward(output, incoming_grad)
         else:
-            incoming_grad_shard = (incoming_grad.view(-1).narrow(
-                0, shard_offset, grad_requiring_tensor_shard.numel()).view_as(grad_requiring_tensor_shard))
+            incoming_grad_shard = (incoming_grad.narrow(1, shard_offset,
+                                                        shard_step).view_as(grad_requiring_tensor_shard))
             torch.autograd.backward(output, incoming_grad_shard)

     # positional args
@@ -836,7 +836,7 @@ def backward(ctx, *grads) -> torch.Tensor:
     x_grad = torch.zeros_like(x)
     x_shards = list(torch.chunk(x, chunks=shards, dim=1))

-    shard_step = x_shards[0].numel()
+    shard_step = x_shards[0].shape[1]
     for i, x_shard in enumerate(x_shards):

         # Tell deepspeed not to add a new grad to its ipg bucket until the last shard is run
@@ -852,8 +852,8 @@ def backward(ctx, *grads) -> torch.Tensor:
         x_shard.requires_grad_(x_requires_grad)

         shard_offset = i * shard_step
-        x_shard.grad = x_grad.view(-1).narrow(0, shard_offset, x_shard.numel()).view_as(x_shard)
-        incoming_grad_shard = incoming_grad.view(-1).narrow(0, shard_offset, x_shard.numel()).view_as(x_shard)
+        x_shard.grad = x_grad.narrow(1, shard_offset, shard_step).view_as(x_shard)
+        incoming_grad_shard = incoming_grad.narrow(1, shard_offset, shard_step).view_as(x_shard)
         with torch.enable_grad():
             output = fn(self, x_shard)
         torch.autograd.backward(output, incoming_grad_shard)
@@ -1010,15 +1010,14 @@ def backward(ctx, *grads) -> torch.Tensor:
     shift_labels_shards = list(torch.chunk(shift_labels, chunks=shards, dim=1))

     # if seqlen is not exactly divisible by shards the last step will be shorter than shard_step
-    shard_step = logits_shards[0].numel()
+    shard_step = logits_shards[0].shape[1]
     for i in range(shards):
         logits_shard = logits_shards.pop(0)
         shift_labels_shard = shift_labels_shards.pop(0)

         shard_offset = i * shard_step
         # this will enable gradual population of the pre-allocated `logits_shard.grad` during `torch.autograd.backward` calls
-        logits_shard.grad = (logits_grad.view(-1).narrow(0, shard_offset,
-                                                         logits_shard.numel()).view_as(logits_shard))
+        logits_shard.grad = (logits_grad.narrow(1, shard_offset, shard_step).view_as(logits_shard))

         with torch.enable_grad():
             if all((shift_labels_shard == -100).squeeze()):
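
The hunks above all follow the same pattern: pre-allocate a full-size gradient buffer, alias each shard's .grad to a narrow view of it along the sequence dimension, and let torch.autograd.backward fill the buffer shard by shard. A simplified, self-contained sketch of that pattern follows (hypothetical tiled_backward helper, plain PyTorch, none of DeepSpeed's ZeRO/ipg-bucket bookkeeping):

# Sketch only: illustrates the narrow-along-dim-1 grad aliasing used above,
# assuming fn maps (bs, shard_len, hidden) -> a tensor of the same shape.
import torch

def tiled_backward(fn, x, incoming_grad, shards=4):
    x_grad = torch.zeros_like(x)                      # full-size grad buffer
    x_shards = list(torch.chunk(x, chunks=shards, dim=1))
    shard_step = x_shards[0].shape[1]                 # sequence-dim step, not numel()
    for i, x_shard in enumerate(x_shards):
        x_shard = x_shard.detach().requires_grad_(True)
        shard_offset = i * shard_step
        shard_len = x_shard.shape[1]                  # the last shard may be shorter
        # alias this shard's .grad to a view of the pre-allocated buffer so that
        # torch.autograd.backward accumulates straight into x_grad
        x_shard.grad = x_grad.narrow(1, shard_offset, shard_len).view_as(x_shard)
        incoming_grad_shard = incoming_grad.narrow(1, shard_offset, shard_len)
        with torch.enable_grad():
            output = fn(x_shard)
        torch.autograd.backward(output, incoming_grad_shard)
    return x_grad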

tests/unit/ulysses_alst/test_tiled_compute.py

Lines changed: 43 additions & 31 deletions

@@ -43,17 +43,20 @@ def forward(self, x):

 class MyModel(Module):

-    def __init__(self, hidden_dim):
+    def __init__(self, hidden_dim, vocab_size):
         super().__init__()
+        self.vocab_size = vocab_size
         # Critical - need to use a stack of at least 2 mlps to validate that the backward of the last mlp sends the correct gradients to the previous mlp in the stack
         self.mlp1 = SimpleMLP(hidden_dim)
         self.mlp2 = SimpleMLP(hidden_dim)
+        self.lm_head = torch.nn.Linear(hidden_dim, vocab_size, bias=False)
         self.cross_entropy_loss = torch.nn.CrossEntropyLoss()

     def forward(self, x, y):
         x = self.mlp1(x)
         x = self.mlp2(x)
-        return self.cross_entropy_loss(x, y)
+        logits = self.lm_head(x)
+        return self.cross_entropy_loss(logits.view(-1, self.vocab_size), y.view(-1))


 def mlp_forward_tiled_mlp(self, x):
@@ -121,17 +124,18 @@ def test_tiled_mlp(self, zero_stage):
         # for debug
         # torch.set_printoptions(precision=8, sci_mode=True)

+        vocab_size = 10
         seed = 42
-        hidden_dim = 100
-        bs = 1
-        seqlen = hidden_dim
+        hidden_dim = 128
+        bs = 2
+        seqlen = 64
         torch.manual_seed(seed)
         x = torch.rand((bs, seqlen, hidden_dim), dtype=dtype, requires_grad=True)
-        y = torch.empty((bs, seqlen), dtype=torch.long, requires_grad=False).random_(hidden_dim)
+        y = torch.empty((bs, seqlen), dtype=torch.long, requires_grad=False).random_(vocab_size)

         # A. Baseline: model with normal MLP
         torch.manual_seed(seed)
-        model_a = MyModel(hidden_dim=hidden_dim).to(dtype)
+        model_a = MyModel(hidden_dim=hidden_dim, vocab_size=vocab_size).to(dtype)
         model_a, _, _, _ = deepspeed.initialize(config=config_dict,
                                                 model=model_a,
                                                 model_parameters=model_a.parameters())
@@ -144,15 +148,17 @@ def test_tiled_mlp(self, zero_stage):

         loss_a = model_a(x_a, y_a)
         model_a.backward(loss_a)
-        grad_a1 = get_grad(model_a.module.mlp1.up_proj.weight, zero_stage)
-        grad_a2 = get_grad(model_a.module.mlp2.up_proj.weight, zero_stage)
-        assert grad_a1 is not None
-        assert grad_a2 is not None
+        param_grad_a1 = get_grad(model_a.module.mlp1.up_proj.weight, zero_stage)
+        param_grad_a2 = get_grad(model_a.module.mlp2.up_proj.weight, zero_stage)
+        x_grad_a = x_a.grad
+        assert param_grad_a1 is not None
+        assert param_grad_a2 is not None
+        assert x_grad_a is not None

         # B. model with tiled MLP using TiledMLP
         torch.manual_seed(seed)
         SimpleMLP.forward = mlp_forward_tiled_mlp
-        model_b = MyModel(hidden_dim=hidden_dim).to(dtype)
+        model_b = MyModel(hidden_dim=hidden_dim, vocab_size=vocab_size).to(dtype)
         model_b, _, _, _ = deepspeed.initialize(config=config_dict,
                                                 model=model_b,
                                                 model_parameters=model_b.parameters())
@@ -161,31 +167,34 @@ def test_tiled_mlp(self, zero_stage):
         y_b = y.clone().detach()
         loss_b = model_b(x_b, y_b)
         model_b.backward(loss_b)
-        grad_b1 = get_grad(model_b.module.mlp1.up_proj.weight, zero_stage)
-        grad_b2 = get_grad(model_b.module.mlp2.up_proj.weight, zero_stage)
-        assert grad_b1 is not None
-        assert grad_b2 is not None
+        param_grad_b1 = get_grad(model_b.module.mlp1.up_proj.weight, zero_stage)
+        param_grad_b2 = get_grad(model_b.module.mlp2.up_proj.weight, zero_stage)
+        x_grad_b = x_b.grad
+        assert param_grad_b1 is not None
+        assert param_grad_b2 is not None
+        assert x_grad_b is not None

         # print(f"{loss_a=}")
         # print(f"{loss_b=}")
-        # print(f"{grad_a1=}")
-        # print(f"{grad_b1=}")
-        # print(f"{grad_a2=}")
-        # print(f"{grad_b2=}")
+        # print(f"{param_grad_a1=}")
+        # print(f"{param_grad_b1=}")
+        # print(f"{param_grad_a2=}")
+        # print(f"{param_grad_b2=}")
         torch_assert_equal(loss_a, loss_b)

         # Gradient will not be exactly the same, especially under half-precision. And bf16 is
         # particularly lossy so need to lower tolerance a bit more than the default. Switch to
         # dtype torch.float or even torch.double to see that the diff is tiny - so the math is
         # correct, but accumulation error adds up. Alternatively making hidden_dim bigger makes the
         # divergence much smaller as well.
-        torch_assert_close(grad_a1, grad_b1) #, rtol=1e-03, atol=1e-04)
-        torch_assert_close(grad_a2, grad_b2) #, rtol=1e-03, atol=1e-04)
+        torch_assert_close(param_grad_a1, param_grad_b1) #, rtol=1e-03, atol=1e-04)
+        torch_assert_close(param_grad_a2, param_grad_b2) #, rtol=1e-03, atol=1e-04)
+        torch_assert_close(x_grad_a, x_grad_b)

         # C. model with tiled MLP using the generic version of the same via sequence_tiled_compute + SequenceTiledCompute
         torch.manual_seed(seed)
         SimpleMLP.forward = mlp_forward_sequence_tiled_compute
-        model_c = MyModel(hidden_dim=hidden_dim).to(dtype)
+        model_c = MyModel(hidden_dim=hidden_dim, vocab_size=vocab_size).to(dtype)
         model_c, _, _, _ = deepspeed.initialize(config=config_dict,
                                                 model=model_c,
                                                 model_parameters=model_c.parameters())
@@ -194,16 +203,19 @@ def test_tiled_mlp(self, zero_stage):
         y_c = y.clone().detach()
         loss_c = model_c(x_c, y_c)
         model_c.backward(loss_c)
-        grad_c1 = get_grad(model_c.module.mlp1.up_proj.weight, zero_stage)
-        grad_c2 = get_grad(model_c.module.mlp2.up_proj.weight, zero_stage)
-        assert grad_c1 is not None
-        assert grad_c2 is not None
+        param_grad_c1 = get_grad(model_c.module.mlp1.up_proj.weight, zero_stage)
+        param_grad_c2 = get_grad(model_c.module.mlp2.up_proj.weight, zero_stage)
+        x_grad_c = x_c.grad
+        assert param_grad_c1 is not None
+        assert param_grad_c2 is not None
+        assert x_grad_c is not None

         # print(f"{loss_a=}")
         # print(f"{loss_c=}")
-        # print(f"{grad_a1=}")
-        # print(f"{grad_c1=}")
+        # print(f"{param_grad_a1=}")
+        # print(f"{param_grad_c1=}")
         # see notes for B
         torch_assert_equal(loss_a, loss_c)
-        torch_assert_close(grad_a1, grad_c1) #, rtol=1e-03, atol=1e-04)
-        torch_assert_close(grad_a2, grad_c2) #, rtol=1e-03, atol=1e-04)
+        torch_assert_close(param_grad_a1, param_grad_c1) #, rtol=1e-03, atol=1e-04)
+        torch_assert_close(param_grad_a2, param_grad_c2) #, rtol=1e-03, atol=1e-04)
+        torch_assert_close(x_grad_a, x_grad_c)
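
As the tolerance comment in the test notes, the tiled and baseline gradients only match approximately under reduced precision. A hedged sketch (using plain torch.testing.assert_close rather than the test's torch_assert_close helper) of a comparison that relaxes tolerances only for bf16:

import torch

def compare_grads(grad_ref, grad_tiled):
    # bf16 accumulates noticeably more rounding error than fp32/fp64, so relax
    # the tolerances for it; other dtypes use assert_close's defaults.
    if grad_ref.dtype == torch.bfloat16:
        torch.testing.assert_close(grad_tiled, grad_ref, rtol=1e-3, atol=1e-4)
    else:
        torch.testing.assert_close(grad_tiled, grad_ref)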
