Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion onnx_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@
metavar='N', help='mini-batch size (default: 1)')
parser.add_argument('--img-size', default=None, type=int,
metavar='N', help='Input image dimension, uses model default if empty')
parser.add_argument('--input-size', default=None, nargs=3, type=int, metavar='N',
help='Input all image dimensions (d h w, e.g. --input-size 3 224 224), uses model default if empty')
parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',
help='Override mean pixel value of dataset')
parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD',
Expand Down Expand Up @@ -82,6 +84,14 @@ def main():
if args.reparam:
model = reparameterize_model(model)

if args.input_size is not None:
assert len(args.input_size) == 3, 'input-size should be N H W (channels, height, width)'
input_size = args.input_size
elif args.img_size is not None:
input_size = (3, args.img_size, args.img_size)
else:
input_size = None

onnx_export(
model,
args.output,
Expand All @@ -93,7 +103,7 @@ def main():
training=args.training,
verbose=args.verbose,
use_dynamo=args.dynamo,
input_size=(3, args.img_size, args.img_size),
input_size=input_size,
batch_size=args.batch_size,
)

Expand Down
11 changes: 3 additions & 8 deletions timm/utils/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def onnx_export(

if example_input is None:
if not input_size:
assert hasattr(model, 'default_cfg')
assert hasattr(model, 'default_cfg'), 'Cannot file model default config, input size must be provided'
input_size = model.default_cfg.get('input_size')
example_input = torch.randn((batch_size,) + input_size, requires_grad=training)

Expand Down Expand Up @@ -78,9 +78,8 @@ def onnx_export(
export_options=export_options,
)
export_output.save(output_file)
torch_out = None
else:
torch_out = torch.onnx._export(
torch.onnx.export(
model,
example_input,
output_file,
Expand All @@ -101,9 +100,5 @@ def onnx_export(
if check_forward and not training:
import numpy as np
onnx_out = onnx_forward(output_file, example_input)
if torch_out is not None:
np.testing.assert_almost_equal(torch_out.numpy(), onnx_out, decimal=3)
np.testing.assert_almost_equal(original_out.numpy(), torch_out.numpy(), decimal=5)
else:
np.testing.assert_almost_equal(original_out.numpy(), onnx_out, decimal=3)
np.testing.assert_almost_equal(original_out.numpy(), onnx_out, decimal=3)