@@ -43,17 +43,20 @@ def forward(self, x):

 class MyModel(Module):

-    def __init__(self, hidden_dim):
+    def __init__(self, hidden_dim, vocab_size):
         super().__init__()
+        self.vocab_size = vocab_size
         # Critical - need to use a stack of at least 2 mlps to validate that the backward of the last mlp sends the correct gradients to the previous mlp in the stack
         self.mlp1 = SimpleMLP(hidden_dim)
         self.mlp2 = SimpleMLP(hidden_dim)
+        self.lm_head = torch.nn.Linear(hidden_dim, vocab_size, bias=False)
         self.cross_entropy_loss = torch.nn.CrossEntropyLoss()

     def forward(self, x, y):
         x = self.mlp1(x)
         x = self.mlp2(x)
-        return self.cross_entropy_loss(x, y)
+        logits = self.lm_head(x)
+        return self.cross_entropy_loss(logits.view(-1, self.vocab_size), y.view(-1))


 def mlp_forward_tiled_mlp(self, x):
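The reshaping added above follows the standard pattern for token-level cross-entropy: torch.nn.CrossEntropyLoss accepts 2-D logits of shape (N, C) with 1-D integer targets of shape (N,), so the (bs, seqlen, vocab_size) logits and (bs, seqlen) labels are flattened first. A minimal standalone sketch of that shape contract (the tensor names are illustrative, not taken from the test):

import torch

bs, seqlen, vocab_size = 2, 64, 10
logits = torch.randn(bs, seqlen, vocab_size)      # what lm_head would produce
labels = torch.randint(vocab_size, (bs, seqlen))  # integer class id per token

loss_fn = torch.nn.CrossEntropyLoss()
# flatten batch and sequence dims so every token becomes one classification sample
loss = loss_fn(logits.view(-1, vocab_size), labels.view(-1))
print(loss.shape)  # torch.Size([]) -- a scalar
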
@@ -121,17 +124,18 @@ def test_tiled_mlp(self, zero_stage):
         # for debug
         # torch.set_printoptions(precision=8, sci_mode=True)

+        vocab_size = 10
         seed = 42
-        hidden_dim = 100
-        bs = 1
-        seqlen = hidden_dim
+        hidden_dim = 128
+        bs = 2
+        seqlen = 64
         torch.manual_seed(seed)
         x = torch.rand((bs, seqlen, hidden_dim), dtype=dtype, requires_grad=True)
-        y = torch.empty((bs, seqlen), dtype=torch.long, requires_grad=False).random_(hidden_dim)
+        y = torch.empty((bs, seqlen), dtype=torch.long, requires_grad=False).random_(vocab_size)

         # A. Baseline: model with normal MLP
         torch.manual_seed(seed)
-        model_a = MyModel(hidden_dim=hidden_dim).to(dtype)
+        model_a = MyModel(hidden_dim=hidden_dim, vocab_size=vocab_size).to(dtype)
         model_a, _, _, _ = deepspeed.initialize(config=config_dict,
                                                 model=model_a,
                                                 model_parameters=model_a.parameters())
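With the new hyperparameters, seqlen (64) no longer coincides with hidden_dim (128), and the labels are drawn from vocab_size instead of hidden_dim, presumably so a mixed-up dimension can no longer go unnoticed. Since the variants under test shard work along the sequence dimension (as the names TiledMLP and sequence_tiled_compute suggest), here is a throwaway sketch of how such sharding keeps shapes intact; it is not part of the test itself:

import torch

bs, seqlen, hidden_dim, num_tiles = 2, 64, 128, 4
x = torch.rand(bs, seqlen, hidden_dim)

# shard along dim=1, the sequence dimension
tiles = torch.chunk(x, num_tiles, dim=1)
assert all(t.shape == (bs, seqlen // num_tiles, hidden_dim) for t in tiles)

# reassembling the shards recovers the original tensor exactly
assert torch.equal(torch.cat(tiles, dim=1), x)
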
@@ -144,15 +148,17 @@ def test_tiled_mlp(self, zero_stage):

         loss_a = model_a(x_a, y_a)
         model_a.backward(loss_a)
-        grad_a1 = get_grad(model_a.module.mlp1.up_proj.weight, zero_stage)
-        grad_a2 = get_grad(model_a.module.mlp2.up_proj.weight, zero_stage)
-        assert grad_a1 is not None
-        assert grad_a2 is not None
+        param_grad_a1 = get_grad(model_a.module.mlp1.up_proj.weight, zero_stage)
+        param_grad_a2 = get_grad(model_a.module.mlp2.up_proj.weight, zero_stage)
+        x_grad_a = x_a.grad
+        assert param_grad_a1 is not None
+        assert param_grad_a2 is not None
+        assert x_grad_a is not None

         # B. model with tiled MLP using TiledMLP
         torch.manual_seed(seed)
         SimpleMLP.forward = mlp_forward_tiled_mlp
-        model_b = MyModel(hidden_dim=hidden_dim).to(dtype)
+        model_b = MyModel(hidden_dim=hidden_dim, vocab_size=vocab_size).to(dtype)
         model_b, _, _, _ = deepspeed.initialize(config=config_dict,
                                                 model=model_b,
                                                 model_parameters=model_b.parameters())
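get_grad is a helper defined elsewhere in this test file; the renamed param_grad_* variables now distinguish parameter gradients (fetched through that helper, since ZeRO may partition them) from the gradient flowing back to the input, which is read directly from x_a.grad. A rough sketch of what such a helper might look like, assuming DeepSpeed's safe_get_full_grad utility; this is an illustration, not the test's actual implementation:

from deepspeed.utils import safe_get_full_grad


def get_grad(param, zero_stage):
    # zero_stage is kept only to mirror the call sites above; safe_get_full_grad
    # assembles and returns the full gradient even when ZeRO has partitioned it
    return safe_get_full_grad(param)
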
@@ -161,31 +167,34 @@ def test_tiled_mlp(self, zero_stage):
         y_b = y.clone().detach()
         loss_b = model_b(x_b, y_b)
         model_b.backward(loss_b)
-        grad_b1 = get_grad(model_b.module.mlp1.up_proj.weight, zero_stage)
-        grad_b2 = get_grad(model_b.module.mlp2.up_proj.weight, zero_stage)
-        assert grad_b1 is not None
-        assert grad_b2 is not None
+        param_grad_b1 = get_grad(model_b.module.mlp1.up_proj.weight, zero_stage)
+        param_grad_b2 = get_grad(model_b.module.mlp2.up_proj.weight, zero_stage)
+        x_grad_b = x_b.grad
+        assert param_grad_b1 is not None
+        assert param_grad_b2 is not None
+        assert x_grad_b is not None

         # print(f"{loss_a=}")
         # print(f"{loss_b=}")
-        # print(f"{grad_a1 =}")
-        # print(f"{grad_b1 =}")
-        # print(f"{grad_a2 =}")
-        # print(f"{grad_b2 =}")
+        # print(f"{param_grad_a1 =}")
+        # print(f"{param_grad_b1 =}")
+        # print(f"{param_grad_a2 =}")
+        # print(f"{param_grad_b2 =}")
         torch_assert_equal(loss_a, loss_b)

         # Gradient will not be exactly the same, especially under half-precision. And bf16 is
         # particularly lossy so need to lower tolerance a bit more than the default. Switch to
         # dtype torch.float or even torch.double to see that the diff is tiny - so the math is
         # correct, but accumulation error adds up. Alternatively making hidden_dim bigger makes the
         # divergence much smaller as well.
-        torch_assert_close(grad_a1, grad_b1)  #, rtol=1e-03, atol=1e-04)
-        torch_assert_close(grad_a2, grad_b2)  #, rtol=1e-03, atol=1e-04)
+        torch_assert_close(param_grad_a1, param_grad_b1)  #, rtol=1e-03, atol=1e-04)
+        torch_assert_close(param_grad_a2, param_grad_b2)  #, rtol=1e-03, atol=1e-04)
+        torch_assert_close(x_grad_a, x_grad_b)

         # C. model with tiled MLP using the generic version of the same via sequence_tiled_compute + SequenceTiledCompute
         torch.manual_seed(seed)
         SimpleMLP.forward = mlp_forward_sequence_tiled_compute
-        model_c = MyModel(hidden_dim=hidden_dim).to(dtype)
+        model_c = MyModel(hidden_dim=hidden_dim, vocab_size=vocab_size).to(dtype)
         model_c, _, _, _ = deepspeed.initialize(config=config_dict,
                                                 model=model_c,
                                                 model_parameters=model_c.parameters())
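The commented-out rtol/atol arguments are kept because, as the comment block above explains, bf16 accumulation error can push the tiled and untiled gradients slightly apart. Assuming torch_assert_close is a thin wrapper over torch.testing.assert_close (an assumption; the wrapper lives elsewhere in the test utilities), loosening the tolerance would look like this:

import torch

a = torch.tensor([1.0000, 2.0000])
b = torch.tensor([1.0009, 1.9992])

# the float32 defaults (rtol=1.3e-6, atol=1e-5) would reject this pair
# torch.testing.assert_close(a, b)  # raises AssertionError

# explicitly loosened tolerances, matching the commented-out arguments above
torch.testing.assert_close(a, b, rtol=1e-3, atol=1e-4)
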
@@ -194,16 +203,19 @@ def test_tiled_mlp(self, zero_stage):
         y_c = y.clone().detach()
         loss_c = model_c(x_c, y_c)
         model_c.backward(loss_c)
-        grad_c1 = get_grad(model_c.module.mlp1.up_proj.weight, zero_stage)
-        grad_c2 = get_grad(model_c.module.mlp2.up_proj.weight, zero_stage)
-        assert grad_c1 is not None
-        assert grad_c2 is not None
+        param_grad_c1 = get_grad(model_c.module.mlp1.up_proj.weight, zero_stage)
+        param_grad_c2 = get_grad(model_c.module.mlp2.up_proj.weight, zero_stage)
+        x_grad_c = x_c.grad
+        assert param_grad_c1 is not None
+        assert param_grad_c2 is not None
+        assert x_grad_c is not None

         # print(f"{loss_a=}")
         # print(f"{loss_c=}")
-        # print(f"{grad_a1 =}")
-        # print(f"{grad_c1 =}")
+        # print(f"{param_grad_a1 =}")
+        # print(f"{param_grad_c1 =}")
         # see notes for B
         torch_assert_equal(loss_a, loss_c)
-        torch_assert_close(grad_a1, grad_c1)  #, rtol=1e-03, atol=1e-04)
-        torch_assert_close(grad_a2, grad_c2)  #, rtol=1e-03, atol=1e-04)
+        torch_assert_close(param_grad_a1, param_grad_c1)  #, rtol=1e-03, atol=1e-04)
+        torch_assert_close(param_grad_a2, param_grad_c2)  #, rtol=1e-03, atol=1e-04)
+        torch_assert_close(x_grad_a, x_grad_c)
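
All three models are seeded identically, so the test's claim is that tiling the MLP over the sequence dimension changes nothing numerically (up to accumulation error) in the forward pass, while the gradient assertions cover the backward pass. For a position-wise module the forward half is easy to see in isolation; a self-contained sketch of the idea, independent of DeepSpeed and of the actual TiledMLP implementation:

import torch

torch.manual_seed(0)
mlp = torch.nn.Sequential(torch.nn.Linear(8, 32), torch.nn.ReLU(), torch.nn.Linear(32, 8)).double()
x = torch.rand(2, 16, 8, dtype=torch.float64)

# untiled: one pass over the full sequence
out_full = mlp(x)

# tiled: process sequence shards independently and concatenate; an MLP acts on
# each position independently, so the result must match the untiled pass
out_tiled = torch.cat([mlp(shard) for shard in torch.chunk(x, 4, dim=1)], dim=1)

torch.testing.assert_close(out_full, out_tiled)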