Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 4 additions & 6 deletions docsrc/user_guide/saving_models.rst
Original file line number Diff line number Diff line change
Expand Up @@ -37,21 +37,19 @@ b) ExportedProgram

import torch
import torch_tensorrt
from torch_tensorrt.dynamo.export import transform, create_exported_program

model = MyModel().eval().cuda()
inputs = torch.randn((1, 3, 224, 224)).cuda()
trt_gm = torch_tensorrt.compile(model, ir="dynamo", inputs) # Output is a torch.fx.GraphModule
# Transform and create an exported program
trt_gm = transform(trt_gm, inputs)
trt_exp_program = create_exported_program(trt_gm, call_spec, trt_gm.state_dict())
torch._export.save(trt_exp_program, "trt_model.ep")
trt_exp_program = torch_tensorrt.dynamo.transform(trt_gm, inputs, call_spec)
torch.export.save(trt_exp_program, "trt_model.ep")

# Later, you can load it and run inference
model = torch._export.load("trt_model.ep")
model = torch.export.load("trt_model.ep")
model(inputs)

`torch_tensorrt.dynamo.export.transform` inlines the submodules within a GraphModule to their corresponding nodes and stiches all the nodes together.
`torch_tensorrt.dynamo.transform` inlines the submodules within a GraphModule to their corresponding nodes, stiches all the nodes together and creates an ExportedProgram.
This is needed as `torch._export` serialization cannot handle serializing and deserializing of submodules (`call_module` nodes).

NOTE: This way of saving the models using `ExportedProgram` is experimental. Here is a known issue : https://github.com/pytorch/TensorRT/issues/2341
Expand Down
5 changes: 2 additions & 3 deletions py/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
numpy
packaging
pybind11==2.6.2
--extra-index-url https://download.pytorch.org/whl/nightly/cu121
torch>=2.1.0,<2.2.0
torchvision>=0.16.0,<0.17.0
torch==2.1.0
torchvision==0.16.0
--extra-index-url https://pypi.ngc.nvidia.com
tensorrt==8.6.1
pyyaml
1 change: 1 addition & 0 deletions py/torch_tensorrt/dynamo/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,4 @@
DYNAMO_CONVERTERS,
dynamo_tensorrt_converter,
)
from .export import transform
12 changes: 10 additions & 2 deletions py/torch_tensorrt/dynamo/compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@


def compile(
exported_program: ExportedProgram,
exported_program: Union[torch.fx.GraphModule, ExportedProgram],
inputs: Any,
*,
device: Optional[Union[Device, torch.device, str]] = DEVICE,
Expand Down Expand Up @@ -86,7 +86,15 @@ def compile(
inputs = prepare_inputs(inputs)
device = to_torch_tensorrt_device(device)

gm = exported_program.module()
if isinstance(exported_program, torch.fx.GraphModule):
gm = exported_program
elif isinstance(exported_program, ExportedProgram):
gm = exported_program.module()
else:
raise AssertionError(
f"Input graph should either be an ExportedProgram or a GraphModule but got type {type(exported_program)}"
)

logger.debug("Input graph: " + str(gm.graph))

# Apply lowering on the graph module
Expand Down
18 changes: 9 additions & 9 deletions py/torch_tensorrt/dynamo/export.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import copy
import operator
from typing import Any, Dict, Sequence, Tuple, Union, cast
from typing import Any, Dict, Sequence, Tuple, cast

import torch
from torch._export.exported_program import CallSpec
Expand All @@ -11,8 +11,8 @@


def transform(
gm: torch.fx.GraphModule, inputs: Sequence[torch.Tensor]
) -> torch.fx.GraphModule:
gm: torch.fx.GraphModule, inputs: Sequence[torch.Tensor], call_spec: CallSpec
) -> ExportedProgram:
# Run shape analysis
_, outputs_map = partitioning.run_shape_analysis(gm, inputs)

Expand All @@ -31,7 +31,10 @@ def transform(
gm.graph.eliminate_dead_code()
gm.graph.lint()

return gm
# Create an exported program with the TRT GraphModule
exp_program = create_trt_exp_program(gm, call_spec)

return exp_program


def lift_constant_pass(trt_gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
Expand Down Expand Up @@ -115,7 +118,6 @@ def inline_torch_modules(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:

# Copy all nodes in the submodule into gm and
# store the output node of this submodule which is now present in gm

submodule_output = gm.graph.graph_copy(submodule.graph, val_map)

# Get their references (since we copied) in the parent graph (gm)
Expand Down Expand Up @@ -174,9 +176,7 @@ def copy_submodule_attributes(


def create_trt_exp_program(
gm: torch.fx.GraphModule,
call_spec: CallSpec,
state_dict: Dict[str, Union[torch.Tensor, torch.nn.Parameter]],
gm: torch.fx.GraphModule, call_spec: CallSpec
) -> ExportedProgram:
"""Creates a new Exported Program. This function takes an torch.fx.GraphModule which has TRT engines
and constructs an Exported Program object with the new IO node names, call_spec and state_dict
Expand Down Expand Up @@ -208,7 +208,7 @@ def create_trt_exp_program(
)

trt_exp_program = ExportedProgram(
gm, gm.graph, trt_graph_signature, call_spec, state_dict, {}, [], []
gm, gm.graph, trt_graph_signature, call_spec, gm.state_dict(), {}, [], []
)

return trt_exp_program
Expand Down
85 changes: 5 additions & 80 deletions tests/py/dynamo/models/test_export_serde.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import torch_tensorrt as torchtrt
import torchvision.models as models
from torch._export.serde.serialize import deserialize, serialize
from torch_tensorrt.dynamo.export import create_trt_exp_program, transform
from torch_tensorrt.dynamo.utils import COSINE_THRESHOLD, cosine_similarity

assertions = unittest.TestCase()
Expand Down Expand Up @@ -45,10 +44,7 @@ def forward(self, x):

exp_program = torchtrt.dynamo.trace(model, **compile_spec)
trt_gm = torchtrt.dynamo.compile(exp_program, **compile_spec)
trt_gm = transform(trt_gm, [input])
trt_exp_program = create_trt_exp_program(
trt_gm, exp_program.call_spec, trt_gm.state_dict()
)
trt_exp_program = torchtrt.dynamo.transform(trt_gm, [input], exp_program.call_spec)
serialized_prog = serialize(trt_exp_program)
deserialized_prog = deserialize(*serialized_prog)

Expand Down Expand Up @@ -100,11 +96,7 @@ def forward(self, x):

exp_program = torchtrt.dynamo.trace(model, **compile_spec)
trt_gm = torchtrt.dynamo.compile(exp_program, **compile_spec)
trt_gm = transform(trt_gm, [input])
trt_exp_program = create_trt_exp_program(
trt_gm, exp_program.call_spec, trt_gm.state_dict()
)

trt_exp_program = torchtrt.dynamo.transform(trt_gm, [input], exp_program.call_spec)
serialized_prog = serialize(trt_exp_program)
deserialized_prog = deserialize(*serialized_prog)
# Check Pyt and TRT exported program outputs
Expand Down Expand Up @@ -161,11 +153,7 @@ def forward(self, x):

exp_program = torchtrt.dynamo.trace(model, **compile_spec)
trt_gm = torchtrt.dynamo.compile(exp_program, **compile_spec)
trt_gm = transform(trt_gm, [input])
trt_exp_program = create_trt_exp_program(
trt_gm, exp_program.call_spec, trt_gm.state_dict()
)

trt_exp_program = torchtrt.dynamo.transform(trt_gm, [input], exp_program.call_spec)
torch._export.save(trt_exp_program, "/tmp/trt.ep")
deser_trt_exp_program = torch._export.load("/tmp/trt.ep")

Expand Down Expand Up @@ -224,11 +212,7 @@ def forward(self, x):

exp_program = torchtrt.dynamo.trace(model, **compile_spec)
trt_gm = torchtrt.dynamo.compile(exp_program, **compile_spec)
trt_gm = transform(trt_gm, [input])
trt_exp_program = create_trt_exp_program(
trt_gm, exp_program.call_spec, trt_gm.state_dict()
)

trt_exp_program = torchtrt.dynamo.transform(trt_gm, [input], exp_program.call_spec)
torch._export.save(trt_exp_program, "/tmp/trt.ep")
deser_trt_exp_program = torch._export.load("/tmp/trt.ep")

Expand Down Expand Up @@ -270,10 +254,7 @@ def test_resnet18_save_load(ir):

exp_program = torchtrt.dynamo.trace(model, **compile_spec)
trt_gm = torchtrt.dynamo.compile(exp_program, **compile_spec)
trt_gm = transform(trt_gm, [input])
trt_exp_program = create_trt_exp_program(
trt_gm, exp_program.call_spec, trt_gm.state_dict()
)
trt_exp_program = torchtrt.dynamo.transform(trt_gm, [input], exp_program.call_spec)
torch._export.save(trt_exp_program, "/tmp/trt.ep")
deser_trt_exp_program = torch._export.load("/tmp/trt.ep")

Expand All @@ -291,59 +272,3 @@ def test_resnet18_save_load(ir):
cos_sim > COSINE_THRESHOLD,
msg=f"test_resnet18_save_load TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
)


# Enable this test once this issue is resolved https://github.com/pytorch/TensorRT/issues/2341
# @pytest.mark.unit
# def test_hybrid_conv_fallback(ir):
# """
# This tests export save and load functionality on a hybrid
# model where a conv (a weighted layer) has been forced to fallback to Pytorch.
# """

# class MyModule(torch.nn.Module):
# def __init__(self):
# super().__init__()
# self.conv = torch.nn.Conv2d(3, 16, 3, stride=1, bias=True)
# self.relu = torch.nn.ReLU()

# def forward(self, x):
# conv = self.conv(x)
# relu = self.relu(conv)
# mul = relu * 0.5
# return mul

# model = MyModule().eval().cuda()
# input = torch.randn((1, 3, 224, 224)).to("cuda")

# compile_spec = {
# "inputs": [
# torchtrt.Input(
# input.shape, dtype=torch.float, format=torch.contiguous_format
# )
# ],
# "ir": ir,
# "min_block_size": 1,
# "torch_executed_ops": "torch.ops.aten.convolution.default",
# }

# trt_exp_program = torchtrt.compile(model, **compile_spec)
# torch._export.save(trt_exp_program, "/tmp/trt.ep")
# deser_trt_exp_program = torch._export.load("/tmp/trt.ep")

# outputs_pyt = model(input)
# outputs_trt = trt_exp_program(input)
# for idx in range(len(outputs_pyt)):
# cos_sim = cosine_similarity(outputs_pyt[idx], outputs_trt[idx])
# assertions.assertTrue(
# cos_sim > COSINE_THRESHOLD,
# msg=f"test_base_full_compile_multiple_outputs TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
# )

# outputs_trt_deser = deser_trt_exp_program(input)
# for idx in range(len(outputs_pyt)):
# cos_sim = cosine_similarity(outputs_pyt[idx], outputs_trt_deser[idx])
# assertions.assertTrue(
# cos_sim > COSINE_THRESHOLD,
# msg=f"test_base_full_compile_save_load TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
# )