Skip to content

Commit 329d7bc

Browse files
committed
Replace AVG POOL with Conv2d as per the official ResNet-RS implementation
1 parent 6f9fbe9 commit 329d7bc

File tree

1 file changed

+18
-16
lines changed

1 file changed

+18
-16
lines changed

timm/models/resnet.py

Lines changed: 18 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -442,7 +442,7 @@ def drop_blocks(drop_block_rate=0.):
442442

443443
def make_blocks(
444444
block_fn, channels, block_repeats, inplanes, reduce_first=1, output_stride=32,
445-
down_kernel_size=1, avg_down=False, drop_block_rate=0., drop_path_rate=0., first_conv_stride=1, **kwargs):
445+
down_kernel_size=1, avg_down=False, drop_block_rate=0., drop_path_rate=0., **kwargs):
446446
stages = []
447447
feature_info = []
448448
net_num_blocks = sum(block_repeats)
@@ -451,7 +451,7 @@ def make_blocks(
451451
dilation = prev_dilation = 1
452452
for stage_idx, (planes, num_blocks, db) in enumerate(zip(channels, block_repeats, drop_blocks(drop_block_rate))):
453453
stage_name = f'layer{stage_idx + 1}' # never liked this name, but weight compat requires it
454-
stride = first_conv_stride if stage_idx == 0 else 2
454+
stride = 1 if stage_idx == 0 else 2
455455
if net_stride >= output_stride:
456456
dilation *= stride
457457
stride = 1
@@ -558,12 +558,12 @@ def __init__(self, block, layers, num_classes=1000, in_chans=3,
558558
cardinality=1, base_width=64, stem_width=64, stem_type='',
559559
output_stride=32, block_reduce_first=1, down_kernel_size=1, avg_down=False,
560560
act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, aa_layer=None, drop_rate=0.0, drop_path_rate=0.,
561-
drop_block_rate=0., global_pool='avg', zero_init_last_bn=True, block_args=None, skip_stem_max_pool=False):
561+
drop_block_rate=0., global_pool='avg', zero_init_last_bn=True, block_args=None, replace_stem_max_pool=False):
562562
block_args = block_args or dict()
563563
assert output_stride in (8, 16, 32)
564564
self.num_classes = num_classes
565565
self.drop_rate = drop_rate
566-
self.skip_stem_max_pool = skip_stem_max_pool
566+
self.replace_stem_max_pool = replace_stem_max_pool
567567
super(ResNet, self).__init__()
568568

569569
# Stem
@@ -588,25 +588,27 @@ def __init__(self, block, layers, num_classes=1000, in_chans=3,
588588
self.feature_info = [dict(num_chs=inplanes, reduction=2, module='act1')]
589589

590590
# Stem Pooling
591-
if not self.skip_stem_max_pool:
592-
first_conv_stride = 1
591+
if not self.replace_stem_max_pool:
593592
if aa_layer is not None:
594593
self.maxpool = nn.Sequential(*[
595594
nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
596595
aa_layer(channels=inplanes, stride=2)])
597596
else:
598597
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
599598
else:
600-
self.maxpool = nn.Identity()
601-
first_conv_stride = 2
599+
self.maxpool = nn.Sequential(*[
600+
nn.Conv2d(inplanes, inplanes, 3, stride=2, padding=1),
601+
nn.BatchNorm2d(inplanes),
602+
nn.ReLU()
603+
])
602604

603605
# Feature Blocks
604606
channels = [64, 128, 256, 512]
605607
stage_modules, stage_feature_info = make_blocks(
606608
block, channels, layers, inplanes, cardinality=cardinality, base_width=base_width,
607609
output_stride=output_stride, reduce_first=block_reduce_first, avg_down=avg_down,
608610
down_kernel_size=down_kernel_size, act_layer=act_layer, norm_layer=norm_layer, aa_layer=aa_layer,
609-
drop_block_rate=drop_block_rate, drop_path_rate=drop_path_rate, first_conv_stride=first_conv_stride, **block_args)
611+
drop_block_rate=drop_block_rate, drop_path_rate=drop_path_rate, **block_args)
610612
for stage in stage_modules:
611613
self.add_module(*stage) # layer1, layer2, etc
612614
self.feature_info.extend(stage_feature_info)
@@ -1078,39 +1080,39 @@ def ecaresnet50d(pretrained=False, **kwargs):
10781080
@register_model
10791081
def resnetrs50(pretrained=False, **kwargs):
10801082
model_args = dict(
1081-
block=Bottleneck, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep', skip_stem_max_pool=True,
1083+
block=Bottleneck, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep', replace_stem_max_pool=True,
10821084
avg_down=True, block_args=dict(attn_layer='se'), **kwargs)
10831085
return _create_resnet('resnetrs50', pretrained, **model_args)
10841086

10851087

10861088
@register_model
10871089
def resnetrs101(pretrained=False, **kwargs):
10881090
model_args = dict(
1089-
block=Bottleneck, layers=[3, 4, 23, 3], stem_width=32, stem_type='deep', skip_stem_max_pool=True,
1091+
block=Bottleneck, layers=[3, 4, 23, 3], stem_width=32, stem_type='deep', replace_stem_max_pool=True,
10901092
avg_down=True, block_args=dict(attn_layer='se'), **kwargs)
10911093
return _create_resnet('resnetrs101', pretrained, **model_args)
10921094

10931095

10941096
@register_model
10951097
def resnetrs152(pretrained=False, **kwargs):
10961098
model_args = dict(
1097-
block=Bottleneck, layers=[3, 8, 36, 3], stem_width=32, stem_type='deep', skip_stem_max_pool=True,
1099+
block=Bottleneck, layers=[3, 8, 36, 3], stem_width=32, stem_type='deep', replace_stem_max_pool=True,
10981100
avg_down=True, block_args=dict(attn_layer='se'), **kwargs)
10991101
return _create_resnet('resnetrs152', pretrained, **model_args)
11001102

11011103

11021104
@register_model
11031105
def resnetrs200(pretrained=False, **kwargs):
11041106
model_args = dict(
1105-
block=Bottleneck, layers=[3, 24, 36, 3], stem_width=32, stem_type='deep', skip_stem_max_pool=True,
1107+
block=Bottleneck, layers=[3, 24, 36, 3], stem_width=32, stem_type='deep', replace_stem_max_pool=True,
11061108
avg_down=True, block_args=dict(attn_layer='se'), **kwargs)
11071109
return _create_resnet('resnetrs200', pretrained, **model_args)
11081110

11091111

11101112
@register_model
11111113
def resnetrs270(pretrained=False, **kwargs):
11121114
model_args = dict(
1113-
block=Bottleneck, layers=[4, 29, 53, 4], stem_width=32, stem_type='deep', skip_stem_max_pool=True,
1115+
block=Bottleneck, layers=[4, 29, 53, 4], stem_width=32, stem_type='deep', replace_stem_max_pool=True,
11141116
avg_down=True, block_args=dict(attn_layer='se'), **kwargs)
11151117
return _create_resnet('resnetrs270', pretrained, **model_args)
11161118

@@ -1119,15 +1121,15 @@ def resnetrs270(pretrained=False, **kwargs):
11191121
@register_model
11201122
def resnetrs350(pretrained=False, **kwargs):
11211123
model_args = dict(
1122-
block=Bottleneck, layers=[4, 36, 72, 4], stem_width=32, stem_type='deep', skip_stem_max_pool=True,
1124+
block=Bottleneck, layers=[4, 36, 72, 4], stem_width=32, stem_type='deep', replace_stem_max_pool=True,
11231125
avg_down=True, block_args=dict(attn_layer='se'), **kwargs)
11241126
return _create_resnet('resnetrs350', pretrained, **model_args)
11251127

11261128

11271129
@register_model
11281130
def resnetrs420(pretrained=False, **kwargs):
11291131
model_args = dict(
1130-
block=Bottleneck, layers=[4, 44, 87, 4], stem_width=32, stem_type='deep', skip_stem_max_pool=True,
1132+
block=Bottleneck, layers=[4, 44, 87, 4], stem_width=32, stem_type='deep', replace_stem_max_pool=True,
11311133
avg_down=True, block_args=dict(attn_layer='se'), **kwargs)
11321134
return _create_resnet('resnetrs420', pretrained, **model_args)
11331135

0 commit comments

Comments (0)