Skip to content

Commit 8449932

Browse files
committed
Add ResNet-RS models
1 parent 779107b commit 8449932

File tree

1 file changed

+89
-10
lines changed

1 file changed

+89
-10
lines changed

timm/models/resnet.py

Lines changed: 89 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -233,7 +233,23 @@ def _cfg(url='', **kwargs):
233233
interpolation='bicubic'),
234234
'resnetblur50': _cfg(
235235
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnetblur50-84f4748f.pth',
236-
interpolation='bicubic')
236+
interpolation='bicubic'),
237+
238+
# ResNet-RS models
239+
'resnetrs50': _cfg(
240+
interpolation='bicubic'),
241+
'resnetrs101': _cfg(
242+
interpolation='bicubic'),
243+
'resnetrs152': _cfg(
244+
interpolation='bicubic'),
245+
'resnetrs200': _cfg(
246+
interpolation='bicubic'),
247+
'resnetrs270': _cfg(
248+
interpolation='bicubic'),
249+
'resnetrs350': _cfg(
250+
interpolation='bicubic'),
251+
'resnetrs420': _cfg(
252+
interpolation='bicubic'),
237253
}
238254

239255

@@ -426,7 +442,7 @@ def drop_blocks(drop_block_rate=0.):
426442

427443
def make_blocks(
428444
block_fn, channels, block_repeats, inplanes, reduce_first=1, output_stride=32,
429-
down_kernel_size=1, avg_down=False, drop_block_rate=0., drop_path_rate=0., **kwargs):
445+
down_kernel_size=1, avg_down=False, drop_block_rate=0., drop_path_rate=0., first_conv_stride=1, **kwargs):
430446
stages = []
431447
feature_info = []
432448
net_num_blocks = sum(block_repeats)
@@ -435,7 +451,7 @@ def make_blocks(
435451
dilation = prev_dilation = 1
436452
for stage_idx, (planes, num_blocks, db) in enumerate(zip(channels, block_repeats, drop_blocks(drop_block_rate))):
437453
stage_name = f'layer{stage_idx + 1}' # never liked this name, but weight compat requires it
438-
stride = 1 if stage_idx == 0 else 2
454+
stride = first_conv_stride if stage_idx == 0 else 2
439455
if net_stride >= output_stride:
440456
dilation *= stride
441457
stride = 1
@@ -542,11 +558,12 @@ def __init__(self, block, layers, num_classes=1000, in_chans=3,
542558
cardinality=1, base_width=64, stem_width=64, stem_type='',
543559
output_stride=32, block_reduce_first=1, down_kernel_size=1, avg_down=False,
544560
act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, aa_layer=None, drop_rate=0.0, drop_path_rate=0.,
545-
drop_block_rate=0., global_pool='avg', zero_init_last_bn=True, block_args=None):
561+
drop_block_rate=0., global_pool='avg', zero_init_last_bn=True, block_args=None, skip_stem_max_pool=False):
546562
block_args = block_args or dict()
547563
assert output_stride in (8, 16, 32)
548564
self.num_classes = num_classes
549565
self.drop_rate = drop_rate
566+
self.skip_stem_max_pool = skip_stem_max_pool
550567
super(ResNet, self).__init__()
551568

552569
# Stem
@@ -571,20 +588,25 @@ def __init__(self, block, layers, num_classes=1000, in_chans=3,
571588
self.feature_info = [dict(num_chs=inplanes, reduction=2, module='act1')]
572589

573590
# Stem Pooling
574-
if aa_layer is not None:
575-
self.maxpool = nn.Sequential(*[
576-
nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
577-
aa_layer(channels=inplanes, stride=2)])
591+
if not self.skip_stem_max_pool:
592+
first_conv_stride = 1
593+
if aa_layer is not None:
594+
self.maxpool = nn.Sequential(*[
595+
nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
596+
aa_layer(channels=inplanes, stride=2)])
597+
else:
598+
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
578599
else:
579-
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
600+
self.maxpool = nn.Identity()
601+
first_conv_stride = 2
580602

581603
# Feature Blocks
582604
channels = [64, 128, 256, 512]
583605
stage_modules, stage_feature_info = make_blocks(
584606
block, channels, layers, inplanes, cardinality=cardinality, base_width=base_width,
585607
output_stride=output_stride, reduce_first=block_reduce_first, avg_down=avg_down,
586608
down_kernel_size=down_kernel_size, act_layer=act_layer, norm_layer=norm_layer, aa_layer=aa_layer,
587-
drop_block_rate=drop_block_rate, drop_path_rate=drop_path_rate, **block_args)
609+
drop_block_rate=drop_block_rate, drop_path_rate=drop_path_rate, first_conv_stride=first_conv_stride, **block_args)
588610
for stage in stage_modules:
589611
self.add_module(*stage) # layer1, layer2, etc
590612
self.feature_info.extend(stage_feature_info)
@@ -1053,6 +1075,63 @@ def ecaresnet50d(pretrained=False, **kwargs):
10531075
return _create_resnet('ecaresnet50d', pretrained, **model_args)
10541076

10551077

1078+
@register_model
1079+
def resnetrs50(pretrained=False, **kwargs):
1080+
model_args = dict(
1081+
block=Bottleneck, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep', skip_stem_max_pool=True,
1082+
avg_down=True, block_args=dict(attn_layer='se'), **kwargs)
1083+
return _create_resnet('resnetrs50', pretrained, **model_args)
1084+
1085+
1086+
@register_model
1087+
def resnetrs101(pretrained=False, **kwargs):
1088+
model_args = dict(
1089+
block=Bottleneck, layers=[3, 4, 23, 3], stem_width=32, stem_type='deep', skip_stem_max_pool=True,
1090+
avg_down=True, block_args=dict(attn_layer='se'), **kwargs)
1091+
return _create_resnet('resnetrs101', pretrained, **model_args)
1092+
1093+
1094+
@register_model
1095+
def resnetrs152(pretrained=False, **kwargs):
1096+
model_args = dict(
1097+
block=Bottleneck, layers=[3, 8, 36, 3], stem_width=32, stem_type='deep', skip_stem_max_pool=True,
1098+
avg_down=True, block_args=dict(attn_layer='se'), **kwargs)
1099+
return _create_resnet('resnetrs152', pretrained, **model_args)
1100+
1101+
1102+
@register_model
1103+
def resnetrs200(pretrained=False, **kwargs):
1104+
model_args = dict(
1105+
block=Bottleneck, layers=[3, 24, 36, 3], stem_width=32, stem_type='deep', skip_stem_max_pool=True,
1106+
avg_down=True, block_args=dict(attn_layer='se'), **kwargs)
1107+
return _create_resnet('resnetrs200', pretrained, **model_args)
1108+
1109+
1110+
@register_model
1111+
def resnetrs270(pretrained=False, **kwargs):
1112+
model_args = dict(
1113+
block=Bottleneck, layers=[4, 29, 53, 4], stem_width=32, stem_type='deep', skip_stem_max_pool=True,
1114+
avg_down=True, block_args=dict(attn_layer='se'), **kwargs)
1115+
return _create_resnet('resnetrs270', pretrained, **model_args)
1116+
1117+
1118+
1119+
@register_model
1120+
def resnetrs350(pretrained=False, **kwargs):
1121+
model_args = dict(
1122+
block=Bottleneck, layers=[4, 36, 72, 4], stem_width=32, stem_type='deep', skip_stem_max_pool=True,
1123+
avg_down=True, block_args=dict(attn_layer='se'), **kwargs)
1124+
return _create_resnet('resnetrs350', pretrained, **model_args)
1125+
1126+
1127+
@register_model
1128+
def resnetrs420(pretrained=False, **kwargs):
1129+
model_args = dict(
1130+
block=Bottleneck, layers=[4, 44, 87, 4], stem_width=32, stem_type='deep', skip_stem_max_pool=True,
1131+
avg_down=True, block_args=dict(attn_layer='se'), **kwargs)
1132+
return _create_resnet('resnetrs420', pretrained, **model_args)
1133+
1134+
10561135
@register_model
10571136
def ecaresnet50d_pruned(pretrained=False, **kwargs):
10581137
"""Constructs a ResNet-50-D model pruned with eca.

0 commit comments

Comments
 (0)