Commit dbe7531

Update scripts to support torch.compile(). Make --results_file arg more consistent across benchmark/validate/inference. Fix #1570
1 parent 05637a4 commit dbe7531

4 files changed: +104, -110 lines
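
In short, the commit replaces the ad-hoc torch._dynamo import guard and the --dynamo / --dynamo-backend flags with a has_compile feature check and a single --torchcompile flag that feeds torch.compile(). A minimal sketch of the new code path, assuming a torch build that already ships torch.compile() (a 2.0 nightly at the time of this commit); the model name and backend below are only illustrative:

import torch
import timm

# Feature check used by the updated scripts: torch.compile() only exists in newer torch builds.
has_compile = hasattr(torch, 'compile')

model = timm.create_model('resnet50')  # illustrative model choice

if has_compile:
    torch._dynamo.reset()  # clear any prior dynamo/compile state, as the scripts do
    # the backend comes from the new --torchcompile flag; 'inductor' is its default
    model = torch.compile(model, backend='inductor')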

benchmark.py

Lines changed: 32 additions & 32 deletions
@@ -56,13 +56,7 @@
 except ImportError as e:
     has_functorch = False

-try:
-    import torch._dynamo
-    has_dynamo = True
-except ImportError:
-    has_dynamo = False
-    pass
-
+has_compile = hasattr(torch, 'compile')

 if torch.cuda.is_available():
     torch.backends.cuda.matmul.allow_tf32 = True
@@ -81,8 +75,10 @@
                     help='Provide train fwd/bwd/opt breakdown detail if True. Defaults to False')
 parser.add_argument('--no-retry', action='store_true', default=False,
                     help='Do not decay batch size and retry on error.')
-parser.add_argument('--results-file', default='', type=str, metavar='FILENAME',
+parser.add_argument('--results-file', default='', type=str,
                     help='Output csv file for validation results (summary)')
+parser.add_argument('--results-format', default='csv', type=str,
+                    help='Format for results file one of (csv, json) (default: csv).')
 parser.add_argument('--num-warm-iter', default=10, type=int,
                     metavar='N', help='Number of warmup iterations (default: 10)')
 parser.add_argument('--num-bench-iter', default=40, type=int,
@@ -113,19 +109,18 @@
                     help='Numeric precision. One of (amp, float32, float16, bfloat16, tf32)')
 parser.add_argument('--fuser', default='', type=str,
                     help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')")
-parser.add_argument('--dynamo-backend', default=None, type=str,
-                    help="Select dynamo backend. Default: None")
 parser.add_argument('--fast-norm', default=False, action='store_true',
                     help='enable experimental fast-norm')

 # codegen (model compilation) options
 scripting_group = parser.add_mutually_exclusive_group()
 scripting_group.add_argument('--torchscript', dest='torchscript', action='store_true',
                              help='convert model torchscript for inference')
+scripting_group.add_argument('--torchcompile', nargs='?', type=str, default=None, const='inductor',
+                             help="Enable compilation w/ specified backend (default: inductor).")
 scripting_group.add_argument('--aot-autograd', default=False, action='store_true',
                              help="Enable AOT Autograd optimization.")
-scripting_group.add_argument('--dynamo', default=False, action='store_true',
-                             help="Enable Dynamo optimization.")
+

 # train optimizer parameters
 parser.add_argument('--opt', default='sgd', type=str, metavar='OPTIMIZER',
@@ -218,9 +213,8 @@ def __init__(
             detail=False,
             device='cuda',
             torchscript=False,
+            torchcompile=None,
             aot_autograd=False,
-            dynamo=False,
-            dynamo_backend=None,
             precision='float32',
             fuser='',
             num_warm_iter=10,
@@ -259,20 +253,19 @@ def __init__(
         self.input_size = data_config['input_size']
         self.batch_size = kwargs.pop('batch_size', 256)

-        self.scripted = False
+        self.compiled = False
         if torchscript:
             self.model = torch.jit.script(self.model)
-            self.scripted = True
-        elif dynamo:
-            assert has_dynamo, "torch._dynamo is needed for --dynamo"
+            self.compiled = True
+        elif torchcompile:
+            assert has_compile, 'A version of torch w/ torch.compile() is required, possibly a nightly.'
             torch._dynamo.reset()
-            if dynamo_backend is not None:
-                self.model = torch._dynamo.optimize(dynamo_backend)(self.model)
-            else:
-                self.model = torch._dynamo.optimize()(self.model)
+            self.model = torch.compile(self.model, backend=torchcompile)
+            self.compiled = True
         elif aot_autograd:
             assert has_functorch, "functorch is needed for --aot-autograd"
             self.model = memory_efficient_fusion(self.model)
+            self.compiled = True

         self.example_inputs = None
         self.num_warm_iter = num_warm_iter
@@ -344,7 +337,7 @@ def _step():
             param_count=round(self.param_count / 1e6, 2),
         )

-        retries = 0 if self.scripted else 2  # skip profiling if model is scripted
+        retries = 0 if self.compiled else 2  # skip profiling if model is scripted
         while retries:
             retries -= 1
             try:
@@ -642,7 +635,6 @@ def main():
         model_cfgs = [(n, None) for n in model_names]

     if len(model_cfgs):
-        results_file = args.results_file or './benchmark.csv'
         _logger.info('Running bulk validation on these pretrained models: {}'.format(', '.join(model_names)))
         results = []
         try:
@@ -663,22 +655,30 @@ def main():
             sort_key = 'infer_gmacs'
         results = filter(lambda x: sort_key in x, results)
         results = sorted(results, key=lambda x: x[sort_key], reverse=True)
-        if len(results):
-            write_results(results_file, results)
     else:
         results = benchmark(args)

+    if args.results_file:
+        write_results(args.results_file, results, format=args.results_format)
+
     # output results in JSON to stdout w/ delimiter for runner script
     print(f'--result\n{json.dumps(results, indent=4)}')


-def write_results(results_file, results):
+def write_results(results_file, results, format='csv'):
     with open(results_file, mode='w') as cf:
-        dw = csv.DictWriter(cf, fieldnames=results[0].keys())
-        dw.writeheader()
-        for r in results:
-            dw.writerow(r)
-        cf.flush()
+        if format == 'json':
+            json.dump(results, cf, indent=4)
+        else:
+            if not isinstance(results, (list, tuple)):
+                results = [results]
+            if not results:
+                return
+            dw = csv.DictWriter(cf, fieldnames=results[0].keys())
+            dw.writeheader()
+            for r in results:
+                dw.writerow(r)
+            cf.flush()


 if __name__ == '__main__':
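
For reference, a small usage sketch of the reworked write_results() and the new --results-format option; the field names and values below are placeholders, not real benchmark output:

# Illustrative only: write_results() now accepts either a single result dict
# (single-model run) or a list of dicts (bulk run), and can emit CSV or JSON.
single = dict(model='resnet50', param_count=0.0, infer_gmacs=0.0)  # placeholder values
bulk = [single, dict(model='resnet18', param_count=0.0, infer_gmacs=0.0)]

write_results('single.csv', single)               # lone dict is wrapped into a one-row CSV
write_results('bulk.json', bulk, format='json')   # same call with format='json' dumps JSON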

inference.py

Lines changed: 21 additions & 24 deletions
@@ -8,6 +8,7 @@
 import os
 import time
 import argparse
+import json
 import logging
 from contextlib import suppress
 from functools import partial
@@ -41,11 +42,7 @@
 except ImportError as e:
     has_functorch = False

-try:
-    import torch._dynamo
-    has_dynamo = True
-except ImportError:
-    has_dynamo = False
+has_compile = hasattr(torch, 'compile')


 _FMT_EXT = {
@@ -60,14 +57,16 @@


 parser = argparse.ArgumentParser(description='PyTorch ImageNet Inference')
-parser.add_argument('data', metavar='DIR',
-                    help='path to dataset')
-parser.add_argument('--dataset', '-d', metavar='NAME', default='',
-                    help='dataset type (default: ImageFolder/ImageTar if empty)')
+parser.add_argument('data', nargs='?', metavar='DIR', const=None,
+                    help='path to dataset (*deprecated*, use --data-dir)')
+parser.add_argument('--data-dir', metavar='DIR',
+                    help='path to dataset (root dir)')
+parser.add_argument('--dataset', metavar='NAME', default='',
+                    help='dataset type + name ("<type>/<name>") (default: ImageFolder or ImageTar if empty)')
 parser.add_argument('--split', metavar='NAME', default='validation',
                     help='dataset split (default: validation)')
-parser.add_argument('--model', '-m', metavar='MODEL', default='dpn92',
-                    help='model architecture (default: dpn92)')
+parser.add_argument('--model', '-m', metavar='MODEL', default='resnet50',
+                    help='model architecture (default: resnet50)')
 parser.add_argument('-j', '--workers', default=2, type=int, metavar='N',
                     help='number of data loading workers (default: 2)')
 parser.add_argument('-b', '--batch-size', default=256, type=int,
@@ -112,16 +111,14 @@
                     help='lower precision AMP dtype (default: float16)')
 parser.add_argument('--fuser', default='', type=str,
                     help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')")
-parser.add_argument('--dynamo-backend', default=None, type=str,
-                    help="Select dynamo backend. Default: None")

 scripting_group = parser.add_mutually_exclusive_group()
 scripting_group.add_argument('--torchscript', default=False, action='store_true',
                              help='torch.jit.script the full model')
+scripting_group.add_argument('--torchcompile', nargs='?', type=str, default=None, const='inductor',
+                             help="Enable compilation w/ specified backend (default: inductor).")
 scripting_group.add_argument('--aot-autograd', default=False, action='store_true',
                              help="Enable AOT Autograd support.")
-scripting_group.add_argument('--dynamo', default=False, action='store_true',
-                             help="Enable Dynamo optimization.")

 parser.add_argument('--results-dir', type=str, default=None,
                     help='folder for output results')
@@ -160,7 +157,6 @@ def main():
     device = torch.device(args.device)

     # resolve AMP arguments based on PyTorch / Apex availability
-    use_amp = None
     amp_autocast = suppress
     if args.amp:
         assert has_native_amp, 'Please update PyTorch to a version with native AMP (or use APEX).'
@@ -201,22 +197,20 @@ def main():

     if args.torchscript:
         model = torch.jit.script(model)
+    elif args.torchcompile:
+        assert has_compile, 'A version of torch w/ torch.compile() is required for --compile, possibly a nightly.'
+        torch._dynamo.reset()
+        model = torch.compile(model, backend=args.torchcompile)
     elif args.aot_autograd:
         assert has_functorch, "functorch is needed for --aot-autograd"
         model = memory_efficient_fusion(model)
-    elif args.dynamo:
-        assert has_dynamo, "torch._dynamo is needed for --dynamo"
-        torch._dynamo.reset()
-        if args.dynamo_backend is not None:
-            model = torch._dynamo.optimize(args.dynamo_backend)(model)
-        else:
-            model = torch._dynamo.optimize()(model)

     if args.num_gpu > 1:
         model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu)))

+    root_dir = args.data or args.data_dir
     dataset = create_dataset(
-        root=args.data,
+        root=root_dir,
         name=args.dataset,
         split=args.split,
         class_map=args.class_map,
@@ -304,6 +298,9 @@ def main():
     for fmt in args.results_format:
         save_results(df, results_filename, fmt)

+    print(f'--result')
+    print(json.dumps(dict(filename=results_filename)))
+

 def save_results(df, results_filename, results_format='csv', filename_col='filename'):
     results_filename += _FMT_EXT[results_format]
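
The positional data argument is now optional and kept only for backward compatibility; --data-dir is the preferred way to point at the dataset root. A sketch of the resolution logic, with a placeholder path:

import argparse

# Sketch of the deprecation shim used in inference.py.
parser = argparse.ArgumentParser()
parser.add_argument('data', nargs='?', metavar='DIR',
                    help='path to dataset (*deprecated*, use --data-dir)')
parser.add_argument('--data-dir', metavar='DIR', help='path to dataset (root dir)')

args = parser.parse_args(['--data-dir', '/path/to/imagenet'])  # placeholder path
root_dir = args.data or args.data_dir  # positional wins when given, else --data-dir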

train.py

Lines changed: 17 additions & 22 deletions
@@ -66,12 +66,7 @@
 except ImportError as e:
     has_functorch = False

-try:
-    import torch._dynamo
-    has_dynamo = True
-except ImportError:
-    has_dynamo = False
-    pass
+has_compile = hasattr(torch, 'compile')


 _logger = logging.getLogger('train')
@@ -88,10 +83,12 @@
 # Dataset parameters
 group = parser.add_argument_group('Dataset parameters')
 # Keep this argument outside of the dataset group because it is positional.
-parser.add_argument('data_dir', metavar='DIR',
-                    help='path to dataset')
-group.add_argument('--dataset', '-d', metavar='NAME', default='',
-                   help='dataset type (default: ImageFolder/ImageTar if empty)')
+parser.add_argument('data', nargs='?', metavar='DIR', const=None,
+                    help='path to dataset (positional is *deprecated*, use --data-dir)')
+parser.add_argument('--data-dir', metavar='DIR',
+                    help='path to dataset (root dir)')
+parser.add_argument('--dataset', metavar='NAME', default='',
+                    help='dataset type + name ("<type>/<name>") (default: ImageFolder or ImageTar if empty)')
 group.add_argument('--train-split', metavar='NAME', default='train',
                    help='dataset train split (default: train)')
 group.add_argument('--val-split', metavar='NAME', default='validation',
@@ -143,16 +140,14 @@
                    help='Enable gradient checkpointing through model blocks/stages')
 group.add_argument('--fast-norm', default=False, action='store_true',
                    help='enable experimental fast-norm')
-parser.add_argument('--dynamo-backend', default=None, type=str,
-                    help="Select dynamo backend. Default: None")

 scripting_group = group.add_mutually_exclusive_group()
 scripting_group.add_argument('--torchscript', dest='torchscript', action='store_true',
                              help='torch.jit.script the full model')
+scripting_group.add_argument('--torchcompile', nargs='?', type=str, default=None, const='inductor',
+                             help="Enable compilation w/ specified backend (default: inductor).")
 scripting_group.add_argument('--aot-autograd', default=False, action='store_true',
                              help="Enable AOT Autograd support.")
-scripting_group.add_argument('--dynamo', default=False, action='store_true',
-                             help="Enable Dynamo optimization.")

 # Optimizer parameters
 group = parser.add_argument_group('Optimizer parameters')
@@ -377,6 +372,8 @@ def main():
         torch.backends.cuda.matmul.allow_tf32 = True
         torch.backends.cudnn.benchmark = True

+    if args.data and not args.data_dir:
+        args.data_dir = args.data
     args.prefetcher = not args.no_prefetcher
     device = utils.init_distributed_device(args)
     if args.distributed:
@@ -485,18 +482,16 @@ def main():
         assert not use_amp == 'apex', 'Cannot use APEX AMP with torchscripted model'
         assert not args.sync_bn, 'Cannot use SyncBatchNorm with torchscripted model'
         model = torch.jit.script(model)
+    elif args.torchcompile:
+        # FIXME dynamo might need move below DDP wrapping? TBD
+        assert has_compile, 'A version of torch w/ torch.compile() is required for --compile, possibly a nightly.'
+        torch._dynamo.reset()
+        model = torch.compile(model, backend=args.torchcompile)
     elif args.aot_autograd:
         assert has_functorch, "functorch is needed for --aot-autograd"
         model = memory_efficient_fusion(model)
-    elif args.dynamo:
-        # FIXME dynamo might need move below DDP wrapping? TBD
-        assert has_dynamo, "torch._dynamo is needed for --dynamo"
-        if args.dynamo_backend is not None:
-            model = torch._dynamo.optimize(args.dynamo_backend)(model)
-        else:
-            model = torch._dynamo.optimize()(model)

-    if args.lr is None:
+    if not args.lr:
         global_batch_size = args.batch_size * args.world_size
         batch_ratio = global_batch_size / args.lr_base_size
         if not args.lr_base_scale:
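
The last hunk also loosens the auto-LR condition from args.lr is None to a plain falsy check, so an explicit --lr 0 now falls through to the derived learning rate as well. A rough sketch of that derivation; the lr_base / lr_base_size values and the linear rule are assumptions here (the script can also apply a sqrt rule via --lr-base-scale):

# Rough sketch of the base-LR derivation train.py performs when --lr is unset.
# The reference values and the linear scaling rule are assumptions for illustration.
batch_size = 128
world_size = 4          # number of distributed processes
lr_base = 0.1           # assumed reference LR at the reference batch size
lr_base_size = 256      # assumed reference global batch size

global_batch_size = batch_size * world_size      # 512
batch_ratio = global_batch_size / lr_base_size   # 2.0
lr = lr_base * batch_ratio                       # 0.2 under linear scaling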

0 commit comments
