Skip to content

Commit 1123063

Browse files
authored
Make all CCT regularization parameters user-configurable. (#346)
1 parent f8bec5e commit 1123063

File tree

1 file changed

+6
-3
lines changed

1 file changed

+6
-3
lines changed

vit_pytorch/cct.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -316,6 +316,9 @@ def __init__(
316316
pooling_kernel_size=3,
317317
pooling_stride=2,
318318
pooling_padding=1,
319+
dropout_rate=0.,
320+
attention_dropout=0.1,
321+
stochastic_depth_rate=0.1,
319322
*args, **kwargs
320323
):
321324
super().__init__()
@@ -340,9 +343,9 @@ def __init__(
340343
width=img_width),
341344
embedding_dim=embedding_dim,
342345
seq_pool=True,
343-
dropout_rate=0.,
344-
attention_dropout=0.1,
345-
stochastic_depth=0.1,
346+
dropout_rate=dropout_rate,
347+
attention_dropout=attention_dropout,
348+
stochastic_depth_rate=stochastic_depth_rate,
346349
*args, **kwargs)
347350

348351
def forward(self, x):

0 commit comments

Comments
 (0)