From d3a5c6c24031dff9cea4ad7124f0b998eb01e9ac Mon Sep 17 00:00:00 2001 From: cehongwang Date: Wed, 3 Jul 2024 15:42:29 -0700 Subject: [PATCH 1/9] Added refitting acceleration --- py/torch_tensorrt/dynamo/_refit.py | 156 +++++++++++++++--- .../dynamo/conversion/_TRTInterpreter.py | 136 ++++++++++++++- .../dynamo/conversion/_conversion.py | 19 +++ .../dynamo/conversion/converter_utils.py | 6 +- .../runtime/_PythonTorchTensorRTModule.py | 3 + .../dynamo/runtime/_TorchTensorRTModule.py | 35 ++-- 6 files changed, 308 insertions(+), 47 deletions(-) diff --git a/py/torch_tensorrt/dynamo/_refit.py b/py/torch_tensorrt/dynamo/_refit.py index a97cb528d4..5a6609de3f 100644 --- a/py/torch_tensorrt/dynamo/_refit.py +++ b/py/torch_tensorrt/dynamo/_refit.py @@ -3,7 +3,7 @@ import collections.abc import copy import logging -from typing import Any, Optional, Sequence, Tuple +from typing import Any, List, Optional, Sequence, Tuple import numpy as np import tensorrt as trt @@ -13,7 +13,7 @@ from torch_tensorrt._Input import Input from torch_tensorrt.dynamo import partitioning from torch_tensorrt.dynamo._exporter import inline_torch_modules -from torch_tensorrt.dynamo.conversion import CompilationSettings +from torch_tensorrt.dynamo._settings import CompilationSettings from torch_tensorrt.dynamo.conversion._conversion import infer_module_output_dtypes from torch_tensorrt.dynamo.conversion._ConverterRegistry import ( DYNAMO_CONVERTERS as CONVERTERS, @@ -108,38 +108,97 @@ def construct_refit_mapping( return weight_map +def construct_refit_mapping_from_weight_name_map( + weight_name_map: dict[Any, Any], state_dict: dict[Any, Any] +) -> dict[Any, Any]: + engine_weight_map = {} + for engine_weight_name, (sd_weight_name, np_weight_type) in weight_name_map.items(): + trt_dtype = dtype.try_from(np_weight_type).to(trt.DataType) + torch_dtype = dtype.try_from(np_weight_type).to(torch.dtype) + if engine_weight_name.split(" ")[-1] in ["SCALE", "SHIFT"]: + # Batch Norm Layer + params = {} + for w in sd_weight_name: + params[w.split(".")[-1]] = state_dict[w] + scale = params["weight"] / torch.sqrt(params["running_var"] + 1e-7) + shift = params["bias"] - params["running_mean"] * scale + # Set scale to scale or shift to shift + engine_weight_map[engine_weight_name] = eval( + engine_weight_name.split(" ")[-1].lower() + ) + + elif sd_weight_name not in state_dict: + # If weights is not in sd, we can leave it unchanged + continue + else: + engine_weight_map[engine_weight_name] = state_dict[sd_weight_name] + + engine_weight_map[engine_weight_name] = ( + engine_weight_map[engine_weight_name] + .clone() + .reshape(-1) + .contiguous() + .to(torch_dtype), + trt_dtype, + ) + + return engine_weight_map + + def _refit_single_trt_engine_with_gm( new_gm: torch.fx.GraphModule, old_engine: trt.ICudaEngine, - input_list: Tuple[Any, ...], + input_list: Sequence[Any], settings: CompilationSettings = CompilationSettings(), + weight_name_map: Optional[dict[str, List[str]]] = None, ) -> None: """ Refit a TensorRT Engine in place """ - # Get the refitting mapping - mapping = construct_refit_mapping(new_gm, input_list, settings) + refitted = set() - trt_wt_location = trt.TensorLocation.HOST refitter = trt.Refitter(old_engine, TRT_LOGGER) weight_list = refitter.get_all_weights() - for layer_name in weight_list: - if layer_name not in mapping: - raise AssertionError(f"{layer_name} is not found in weight mapping") - # Use Numpy to create weights - weight, datatype = mapping[layer_name] - trt_wt_tensor = trt.Weights(datatype, weight.ctypes.data, 
weight.size) - refitter.set_named_weights(layer_name, trt_wt_tensor, trt_wt_location) - refitted.add(layer_name) + if weight_name_map: + # Get the refitting mapping + trt_wt_location = trt.TensorLocation.DEVICE + mapping = construct_refit_mapping_from_weight_name_map( + weight_name_map, new_gm.state_dict() + ) + for layer_name in weight_list: + if layer_name not in mapping: + logger.warning(f"{layer_name} is not found in weight mapping.") + continue + # Use Numpy to create weights + weight, weight_dtype = mapping[layer_name] + trt_wt_tensor = trt.Weights( + weight_dtype, weight.data_ptr(), torch.numel(weight) + ) + refitter.set_named_weights(layer_name, trt_wt_tensor, trt_wt_location) + assert ( + len(refitter.get_missing_weights()) == 0 + ), "Fast refitting failed due to incomplete mapping" - if len(refitted) != len(weight_list): - logger.warning("Not all weights have been refitted!!!") + else: + mapping = construct_refit_mapping(new_gm, input_list, settings) + trt_wt_location = trt.TensorLocation.HOST + for layer_name in weight_list: + if layer_name not in mapping: + raise AssertionError(f"{layer_name} is not found in weight mapping") + # Use Numpy to create weights + weight, datatype = mapping[layer_name] + trt_wt_tensor = trt.Weights(datatype, weight.ctypes.data, weight.size) + refitter.set_named_weights(layer_name, trt_wt_tensor, trt_wt_location) + refitted.add(layer_name) + + if len(refitted) != len(weight_list): + logger.warning("Not all weights have been refitted!!!") if not refitter.refit_cuda_engine(): logger.error("Error: failed to refit new weights.") - exit(0) + raise AssertionError("Refitting failed.") def refit_module_weights( @@ -148,6 +207,7 @@ def refit_module_weights( arg_inputs: Optional[Tuple[Any, ...]] = None, kwarg_inputs: Optional[dict[str, Any]] = None, verify_output: bool = False, + fast_refit: bool = True, ) -> torch.fx.GraphModule: """ Refit a compiled graph module with ExportedProgram. This performs weight updates in compiled_module without recompiling the engine. @@ -182,13 +242,14 @@ def refit_module_weights( for name, engine in compiled_module.__dict__.items() if "engine" in name ] - encoded_settings = compiled_submodules[0][1].__getstate__()[0][ + # [('_run_on_acc_0', inline_module)] + encoded_metadata = compiled_submodules[0][1].__getstate__()[0][ SERIALIZED_METADATA_IDX ] assert ( - encoded_settings != "" + encoded_metadata != "" ), "Settings are not saved in the engine. Please recompile the engine with make_refitable=True." - settings = TorchTensorRTModule.decode_metadata(encoded_settings) + settings = TorchTensorRTModule.decode_metadata(encoded_metadata)["settings"] # Handle torch modules compiled_submodules_map = dict(compiled_submodules) for name, submodule in compiled_module.named_children(): @@ -287,6 +348,7 @@ def refit_module_weights( # Extract engine from the submodule try: if inline_module: + weight_name_map = None compiled_submodule = compiled_submodules_map[name] # If this is a torch module, load the old state_dict if "_run_on_acc" not in name: @@ -297,8 +359,37 @@ def refit_module_weights( engine = get_engine_from_encoded_engine( engine_info[ENGINE_IDX], runtime ) + if fast_refit: + encoded_metadata = compiled_submodule.__getstate__()[0][ + SERIALIZED_METADATA_IDX + ] + assert ( + encoded_metadata != "" + ), "Metadata are not saved in the engine. Please recompile the engine with make_refitable=True." 
+ weight_name_map = TorchTensorRTModule.decode_metadata( + encoded_metadata + )["weight_name_map"] + if not weight_name_map: + fast_refit = False + logger.warning( + "Fast refitting is not supported in this module. Use regular refitting." + ) else: compiled_submodule = getattr(compiled_module, name) + weight_name_map = None + if fast_refit: + try: + weight_name_map = compiled_submodule.weight_name_map + except AttributeError: + fast_refit = False + logger.warning( + "You are using a old version of Torch-TensorRT. Please re-compile the engine to avoid failures." + ) + if not weight_name_map: + fast_refit = False + logger.warning( + "Fast refitting is not supported in this module. Use regular refitting." + ) if isinstance(compiled_submodule, PythonTorchTensorRTModule): engine = compiled_submodule.engine elif isinstance(compiled_submodule, TorchTensorRTModule): @@ -335,13 +426,24 @@ def refit_module_weights( to_torch_device(settings.device), name, ) - - _refit_single_trt_engine_with_gm( - new_gm=new_submodule, - old_engine=engine, - input_list=submodule_inputs, - settings=settings, - ) + try: + _refit_single_trt_engine_with_gm( + new_gm=new_submodule, + old_engine=engine, + input_list=submodule_inputs, + settings=settings, + weight_name_map=weight_name_map, + ) + except AssertionError: + # If fast_refit is used and failed, we fall back to regular refit + if fast_refit and weight_name_map: + _refit_single_trt_engine_with_gm( + new_gm=new_submodule, + old_engine=engine, + input_list=submodule_inputs, + settings=settings, + weight_name_map=None, + ) if isinstance(compiled_submodule, TorchTensorRTModule): serialized_engine = bytes(engine.serialize()) diff --git a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py index 703a650c99..74dbc14dd9 100644 --- a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py +++ b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py @@ -47,6 +47,7 @@ class TRTInterpreterResult(NamedTuple): serialized_engine: bytes input_names: Sequence[str] output_names: Sequence[str] + weight_name_map: Optional[dict[Any, Any]] class TRTInterpreter(torch.fx.Interpreter): # type: ignore[misc] @@ -110,6 +111,7 @@ def __init__( # Mapping of constants to shapes and dtypes self.const_mapping: Dict[str, Tuple[Sequence[int], str]] = {} + self.weight_name_map: Optional[dict[str, Any]] = None def validate_conversion(self) -> Set[str]: missing_converters: Set[str] = set() @@ -320,6 +322,134 @@ def _construct_trt_network_def(self) -> None: f"TRT INetwork construction elapsed time: {datetime.now() - run_module_start_time}" ) + def _save_weight_mapping(self) -> None: + + def find_weight( + weight_name: str, np_map: dict[str, Any], sd: dict[str, Any] + ) -> str: + network_weight = np_map[weight_name] + for sd_w_name, sd_weight in sd.items(): + if check_weight_equal(sd_weight, network_weight): + return sd_w_name + return "" + + def check_weight_equal( + sd_weight: torch.tensor, network_weight: np.ndarray + ) -> Any: + sd_weight = sd_weight.reshape(-1).cpu().numpy() + return sd_weight.size == network_weight.size and np.allclose( + sd_weight, network_weight, 1e-1, 1e-1 + ) + + MODULE_MAP = { + "SCALE": ( + trt.IScaleLayer, + [ + ( + "scale", + "SCALE", + ("weight", "bias", "running_mean", "running_var"), + ), + ( + "shift", + "SHIFT", + ("weight", "bias", "running_mean", "running_var"), + ), + ], + ), + "CONVOLUTION": ( + trt.IConvolutionLayer, + [("kernel", "KERNEL", "weight"), ("bias", "BIAS", "bias")], + ), + "DECONVOLUTION": ( + 
trt.IDeconvolutionLayer, + [("kernel", "KERNEL", "weight"), ("bias", "BIAS", "bias")], + ), + "CONSTANT": ( + trt.IConstantLayer, + [("weights", "CONSTANT", ("weight", "bias"))], + ), + } + """ + The structure of this map is: + { + layer_type: ( + Corresponding ILayer type to cast, + [ + ( + ILayer weight attribute, + Weight name postfix in TRT Engine, + Weight name postfix in state_dict + ), + ... + ] + ) + } + """ + + sd = self.module.state_dict() + weight_name_map: dict[str, Any] = {} + np_map = {} + net = self.ctx.net + for i in range(net.num_layers): + layer = net[i] + layer_type: str = layer.type.name + if layer_type in MODULE_MAP: + layer.__class__ = MODULE_MAP[layer_type][0] + # Name mapping + for weight_type, weight_name, torch_attr in MODULE_MAP[layer_type][1]: + weight = layer.__getattribute__(weight_type).copy() + if weight.size == 0: + continue + engine_weight_name = f"{layer.name} {weight_name}" + # Infer the corresponding weight name(s) in state_dict + sd_weight_name_list = ( + layer.name.split("-")[-1] + .replace("[", "") + .replace("]", "") + .split("/") + ) + sd_weight_name: Any = ".".join( + [i for i in sd_weight_name_list[:-1] if i] + ) + suffix = sd_weight_name_list[-1] + if layer_type == "CONSTANT": + if "embedding" in suffix: + sd_weight_name = f"{sd_weight_name}.{torch_attr[0]}" + elif "weight" in suffix or "mm_other" in suffix: + # Linear layer weight + sd_weight_name = f"{sd_weight_name}.{torch_attr[0]}" + else: + sd_weight_name = f"{sd_weight_name}.{torch_attr[1]}" + elif layer_type == "SCALE": + # Batch norm needs all weights to calculate scale and shift + sd_weight_name = [f"{sd_weight_name}.{n}" for n in torch_attr] + else: + sd_weight_name = f"{sd_weight_name}.{torch_attr}" + + weight_name_map[engine_weight_name] = sd_weight_name + np_map[engine_weight_name] = weight + + # Value mapping + for engine_weight_name, sd_weight_name in weight_name_map.items(): + if "SCALE" in engine_weight_name: + # There is no direct connection in batch_norm layer. 
So skip it + pass + elif sd_weight_name not in sd or not check_weight_equal( + sd[sd_weight_name], np_map[engine_weight_name] + ): + weight_name_map[engine_weight_name] = find_weight( + engine_weight_name, np_map, sd + ) + + weight_name_map[engine_weight_name] = [ + weight_name_map[engine_weight_name], + np_map[engine_weight_name].dtype, + ] + + self.weight_name_map = weight_name_map + # check = {k:(weight_name_map[k], np_map[k]) for k, v in np_map.items()} + def run( self, strict_type_constraints: bool = False, @@ -335,6 +465,10 @@ def run( TRTInterpreterResult """ self._construct_trt_network_def() + + if self.compilation_settings.make_refitable: + self._save_weight_mapping() + build_engine_start_time = datetime.now() builder_config = self._populate_trt_builder_config( @@ -363,7 +497,7 @@ def run( engine_bytes.write(serialized_engine) engine_str = engine_bytes.getvalue() - return TRTInterpreterResult(engine_str, self._input_names, self._output_names) + return TRTInterpreterResult(engine_str, self._input_names, self._output_names, self.weight_name_map) def run_node(self, n: torch.fx.Node) -> torch.fx.Node: self._cur_node_name = get_node_name(n) diff --git a/py/torch_tensorrt/dynamo/conversion/_conversion.py b/py/torch_tensorrt/dynamo/conversion/_conversion.py index c1663ca5cd..a62a59905a 100644 --- a/py/torch_tensorrt/dynamo/conversion/_conversion.py +++ b/py/torch_tensorrt/dynamo/conversion/_conversion.py @@ -126,6 +126,24 @@ def convert_module( PythonTorchTensorRTModule or TorchTensorRTModule """ interpreter_result = interpret_module_to_result(module, inputs, settings) + # Test fast refit: + from torch_tensorrt.dynamo._refit import _refit_single_trt_engine_with_gm + from torch_tensorrt.logging import TRT_LOGGER + + runtime = trt.Runtime(TRT_LOGGER) + refit_test_engine = runtime.deserialize_cuda_engine(interpreter_result.engine) + weight_name_map: Any = interpreter_result.weight_name_map + try: + _refit_single_trt_engine_with_gm( + new_gm=module, + old_engine=refit_test_engine, + input_list=inputs, + settings=settings, + weight_name_map=weight_name_map, + ) + except AssertionError: + logger.warning("Fast refit test failed. 
Removing the weight map caching.") + weight_name_map = None rt_cls = PythonTorchTensorRTModule @@ -149,4 +167,5 @@ def convert_module( output_binding_names=list(interpreter_result.output_names), name=name, settings=settings, + weight_name_map = weight_name_map ) diff --git a/py/torch_tensorrt/dynamo/conversion/converter_utils.py b/py/torch_tensorrt/dynamo/conversion/converter_utils.py index f847091800..af0d6b720a 100644 --- a/py/torch_tensorrt/dynamo/conversion/converter_utils.py +++ b/py/torch_tensorrt/dynamo/conversion/converter_utils.py @@ -1,7 +1,6 @@ import collections import functools import logging -import re from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union, overload import numpy as np @@ -34,10 +33,7 @@ def get_node_name(node: torch.fx.Node) -> str: mod_stack = stack_item.popitem() if stack_item else "" node_name = str(node) if mod_stack: - mod_name = str(mod_stack[0]).replace("___", "/") - # Clean up the module name - mod_name = re.sub("^.*__self", "", mod_name) - mod_name = re.sub(r"_(\d+)$", r"/\g<1>", mod_name) + mod_name = mod_stack[1][0] node_name = mod_name + "/" + node_name else: # Try an alternative way to get the module info diff --git a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py index 659f18af52..e21e83aaac 100644 --- a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py @@ -38,6 +38,7 @@ def __init__( *, name: str = "", settings: CompilationSettings = CompilationSettings(), + weight_name_map: Any = None, ): """Takes a name, target device, serialized TensorRT engine, and binding names / order and constructs a PyTorch ``torch.nn.Module`` around it. Uses TensorRT Python APIs to run the engine @@ -102,6 +103,8 @@ def __init__( self.profiling_enabled = settings.debug if settings.debug is not None else False self.settings = settings self.engine = None + self.weight_name_map = weight_name_map + self._initialize() if self.serialized_engine is not None and not self.settings.lazy_engine_init: self.setup_engine() diff --git a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py index 0ab0dd49ca..d3216177db 100644 --- a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py @@ -58,6 +58,7 @@ def __init__( *, name: str = "", settings: CompilationSettings = CompilationSettings(), # Assumes engine was built with default compilation settings if object not passed + weight_name_map: Optional[dict[Any, Any]] = None, ): """Takes a name, target device, serialized TensorRT engine, and binding names / order and constructs a PyTorch ``torch.nn.Module`` around it. 
Uses the Torch-TensorRT runtime extension to run the engines @@ -107,6 +108,7 @@ def __init__( self.name = name self.hardware_compatible = settings.hardware_compatible self.settings = copy.deepcopy(settings) + self.weight_name_map = weight_name_map self.serialized_engine = serialized_engine self.engine = None @@ -130,6 +132,7 @@ def setup_engine(self) -> None: if self.settings.device is not None else Device._current_device() ) + metadata = {"settings": self.settings, "weight_name_map": self.weight_name_map} self.engine = torch.classes.tensorrt.Engine( [ torch.ops.tensorrt.ABI_VERSION(), @@ -139,25 +142,29 @@ def setup_engine(self) -> None: TorchTensorRTModule._pack_binding_names(self.input_binding_names), TorchTensorRTModule._pack_binding_names(self.output_binding_names), str(int(self.hardware_compatible)), - self.encode_metadata(self.settings), + self.encode_metadata(metadata), ] ) - - def encode_metadata(self, settings: Any) -> str: - settings = copy.deepcopy(settings) - settings.torch_executed_ops = { - f"torch.ops.{op.__str__()}" for op in settings.torch_executed_ops + + + def encode_metadata(self, metadata: Any) -> str: + metadata = copy.deepcopy(metadata) + metadata["settings"].torch_executed_ops = { + f"torch.ops.{op.__str__()}" + for op in metadata["settings"].torch_executed_ops } - dumped_settings = pickle.dumps(settings) - encoded_settings = base64.b64encode(dumped_settings).decode("utf-8") - return encoded_settings + dumped_metadata = pickle.dumps(metadata) + encoded_metadata = base64.b64encode(dumped_metadata).decode("utf-8") + return encoded_metadata @staticmethod - def decode_metadata(encoded_settings: bytes) -> Any: - dumped_settings = base64.b64decode(encoded_settings.encode("utf-8")) - settings = pickle.loads(dumped_settings) - settings.torch_executed_ops = {eval(op) for op in settings.torch_executed_ops} - return settings + def decode_metadata(encoded_metadata: bytes) -> Any: + dumped_metadata = base64.b64decode(encoded_metadata.encode("utf-8")) + metadata = pickle.loads(dumped_metadata) + metadata["settings"].torch_executed_ops = { + eval(op) for op in metadata["settings"].torch_executed_ops + } + return metadata def get_extra_state(self) -> SerializedTorchTensorRTModuleFmt: if self.engine is None and self.serialized_engine is not None: From 3d3d59dae9b83554a8c49fa271c8236c054476b9 Mon Sep 17 00:00:00 2001 From: cehongwang Date: Wed, 3 Jul 2024 15:56:20 -0700 Subject: [PATCH 2/9] Added test of fast refitting --- tests/py/dynamo/models/test_model_refit.py | 269 ++++++++++++++++++++- 1 file changed, 264 insertions(+), 5 deletions(-) diff --git a/tests/py/dynamo/models/test_model_refit.py b/tests/py/dynamo/models/test_model_refit.py index 82e655d736..b611398e61 100644 --- a/tests/py/dynamo/models/test_model_refit.py +++ b/tests/py/dynamo/models/test_model_refit.py @@ -81,6 +81,260 @@ def test_mapping(): torch._dynamo.reset() +@pytest.mark.unit +def test_fast_refit_one_engine(): + + model = models.resnet18(pretrained=False).eval().to("cuda") + model2 = models.resnet18(pretrained=True).eval().to("cuda") + inputs = [torch.randn((1, 3, 224, 224)).to("cuda")] + enabled_precisions = {torch.float} + debug = False + min_block_size = 1 + use_python_runtime = False + + exp_program = torch.export.export(model, tuple(inputs)) + exp_program2 = torch.export.export(model2, tuple(inputs)) + + trt_gm = torchtrt.dynamo.compile( + exp_program, + tuple(inputs), + use_python_runtime=use_python_runtime, + enabled_precisions=enabled_precisions, + debug=debug, + min_block_size=min_block_size, + 
make_refitable=True, + ) + + new_trt_gm = refit_module_weights( + compiled_module=trt_gm, + new_weight_module=exp_program2, + inputs=inputs, + fast_refit=True, + ) + + # Check the output + expected_outputs, refitted_outputs = exp_program2.module()(*inputs), new_trt_gm( + *inputs + ) + for expected_output, refitted_output in zip(expected_outputs, refitted_outputs): + assertions.assertTrue( + torch.allclose(expected_output, refitted_output, 1e-2, 1e-2), + "Refit Result is not correct. Refit failed", + ) + # Clean up model env + + torch._dynamo.reset() + + +@pytest.mark.unit +def test_fast_refit_one_engine_bert(): + inputs = [ + torch.randint(0, 2, (1, 14), dtype=torch.int32).to("cuda"), + ] + model = BertModel.from_pretrained("bert-base-uncased").eval().to("cuda") + model2 = BertModel.from_pretrained("bert-base-uncased").eval().to("cuda") + nn.init.xavier_normal_(model2.embeddings.word_embeddings.weight) + enabled_precisions = {torch.float} + debug = False + min_block_size = 1 + use_python_runtime = False + + exp_program = torch.export.export(model, tuple(inputs)) + exp_program2 = torch.export.export(model2, tuple(inputs)) + + trt_gm = torchtrt.dynamo.compile( + exp_program, + tuple(inputs), + use_python_runtime=use_python_runtime, + enabled_precisions=enabled_precisions, + debug=debug, + min_block_size=min_block_size, + make_refitable=True, + ) + + new_trt_gm = refit_module_weights( + compiled_module=trt_gm, + new_weight_module=exp_program2, + inputs=inputs, + fast_refit=True, + ) + + # Check the output + expected_outputs, refitted_outputs = exp_program2.module()(*inputs), new_trt_gm( + *inputs + ) + for expected_output, refitted_output in zip(expected_outputs, refitted_outputs): + if not isinstance(expected_output, torch.Tensor) or not isinstance( + refitted_output, torch.Tensor + ): + continue + assertions.assertTrue( + torch.allclose(expected_output, refitted_output, 1e-2, 1e-2), + "Refit Result is not correct. Refit failed", + ) + # Clean up model env + + torch._dynamo.reset() + + +@pytest.mark.unit +def test_fast_refit_one_engine_inline_runtime(): + trt_ep_path = os.path.join(tempfile.gettempdir(), "compiled.ep") + model = models.resnet18(pretrained=False).eval().to("cuda") + model2 = models.resnet18(pretrained=True).eval().to("cuda") + inputs = [torch.randn((1, 3, 224, 224)).to("cuda")] + enabled_precisions = {torch.float} + debug = False + min_block_size = 1 + use_python_runtime = False + + exp_program = torch.export.export(model, tuple(inputs)) + exp_program2 = torch.export.export(model2, tuple(inputs)) + + trt_gm = torchtrt.dynamo.compile( + exp_program, + tuple(inputs), + use_python_runtime=use_python_runtime, + enabled_precisions=enabled_precisions, + debug=debug, + min_block_size=min_block_size, + make_refitable=True, + ) + torchtrt.save(trt_gm, trt_ep_path, inputs=inputs) + trt_gm = torch.export.load(trt_ep_path) + new_trt_gm = refit_module_weights( + compiled_module=trt_gm, + new_weight_module=exp_program2, + inputs=inputs, + fast_refit=True, + ) + + # Check the output + expected_outputs, refitted_outputs = exp_program2.module()(*inputs), new_trt_gm( + *inputs + ) + for expected_output, refitted_output in zip(expected_outputs, refitted_outputs): + assertions.assertTrue( + torch.allclose(expected_output, refitted_output, 1e-2, 1e-2), + "Refit Result is not correct. 
Refit failed", + ) + # Clean up model env + + torch._dynamo.reset() + + +@pytest.mark.unit +def test_fast_refit_one_engine_python_runtime(): + + model = models.resnet18(pretrained=False).eval().to("cuda") + model2 = models.resnet18(pretrained=True).eval().to("cuda") + inputs = [torch.randn((1, 3, 224, 224)).to("cuda")] + enabled_precisions = {torch.float} + debug = False + min_block_size = 1 + use_python_runtime = True + + exp_program = torch.export.export(model, tuple(inputs)) + exp_program2 = torch.export.export(model2, tuple(inputs)) + + trt_gm = torchtrt.dynamo.compile( + exp_program, + tuple(inputs), + use_python_runtime=use_python_runtime, + enabled_precisions=enabled_precisions, + debug=debug, + min_block_size=min_block_size, + make_refitable=True, + ) + + new_trt_gm = refit_module_weights( + compiled_module=trt_gm, + new_weight_module=exp_program2, + inputs=inputs, + fast_refit=True, + ) + + # Check the output + expected_outputs, refitted_outputs = exp_program2.module()(*inputs), new_trt_gm( + *inputs + ) + for expected_output, refitted_output in zip(expected_outputs, refitted_outputs): + assertions.assertTrue( + torch.allclose(expected_output, refitted_output, 1e-2, 1e-2), + "Refit Result is not correct. Refit failed", + ) + + # Clean up model env + torch._dynamo.reset() + + +@pytest.mark.unit +def test_fast_refit_multiple_engine(): + + class net(nn.Module): + def __init__(self): + super().__init__() + self.conv1 = nn.Conv2d(3, 12, 3, padding=1) + self.bn = nn.BatchNorm2d(12) + self.conv2 = nn.Conv2d(12, 12, 3, padding=1) + self.fc1 = nn.Linear(12 * 56 * 56, 10) + + def forward(self, x): + x = self.conv1(x) + x = F.relu(x) + x = self.bn(x) + x = F.max_pool2d(x, (2, 2)) + x = self.conv2(x) + x = F.relu(x) + x = F.max_pool2d(x, (2, 2)) + x = torch.flatten(x, 1) + return self.fc1(x) + + model = net().eval().to("cuda") + model2 = net().eval().to("cuda") + + inputs = [torch.randn((1, 3, 224, 224)).to("cuda")] + enabled_precisions = {torch.float} + debug = False + min_block_size = 1 + use_python_runtime = False + + exp_program = torch.export.export(model, tuple(inputs)) + exp_program2 = torch.export.export(model2, tuple(inputs)) + + torch_executed_ops = {torch.ops.aten.convolution.default} + trt_gm = torchtrt.dynamo.compile( + exp_program, + tuple(inputs), + use_python_runtime=use_python_runtime, + enabled_precisions=enabled_precisions, + debug=debug, + min_block_size=min_block_size, + make_refitable=True, + torch_executed_ops=torch_executed_ops, + ) + + new_trt_gm = refit_module_weights( + compiled_module=trt_gm, + new_weight_module=exp_program2, + inputs=inputs, + fast_refit=True, + ) + + # Check the output + expected_outputs, refitted_outputs = exp_program2.module()(*inputs), new_trt_gm( + *inputs + ) + for expected_output, refitted_output in zip(expected_outputs, refitted_outputs): + assertions.assertTrue( + torch.allclose(expected_output, refitted_output, 1e-2, 1e-2), + "Refit Result is not correct. 
Refit failed", + ) + # Clean up model env + + torch._dynamo.reset() + + @pytest.mark.unit def test_refit_one_engine(): @@ -108,7 +362,8 @@ def test_refit_one_engine(): new_trt_gm = refit_module_weights( compiled_module=trt_gm, new_weight_module=exp_program2, - arg_inputs=inputs, + inputs=inputs, + fast_refit=False, ) # Check the output @@ -154,7 +409,8 @@ def test_refit_one_engine_bert(): new_trt_gm = refit_module_weights( compiled_module=trt_gm, new_weight_module=exp_program2, - arg_inputs=inputs, + inputs=inputs, + fast_refit=False, ) # Check the output @@ -203,7 +459,8 @@ def test_refit_one_engine_inline_runtime(): new_trt_gm = refit_module_weights( compiled_module=trt_gm, new_weight_module=exp_program2, - arg_inputs=inputs, + inputs=inputs, + fast_refit=False, ) # Check the output @@ -247,7 +504,8 @@ def test_refit_one_engine_python_runtime(): new_trt_gm = refit_module_weights( compiled_module=trt_gm, new_weight_module=exp_program2, - arg_inputs=inputs, + inputs=inputs, + fast_refit=False, ) # Check the output @@ -313,7 +571,8 @@ def forward(self, x): new_trt_gm = refit_module_weights( compiled_module=trt_gm, new_weight_module=exp_program2, - arg_inputs=inputs, + inputs=inputs, + fast_refit=False, ) # Check the output From 2958dbd2e6d35307a1363a6b64dd07d7a1300880 Mon Sep 17 00:00:00 2001 From: cehongwang Date: Fri, 5 Jul 2024 16:36:25 -0700 Subject: [PATCH 3/9] Added in_place flag --- py/torch_tensorrt/dynamo/_refit.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/py/torch_tensorrt/dynamo/_refit.py b/py/torch_tensorrt/dynamo/_refit.py index 5a6609de3f..67a0c3b6d4 100644 --- a/py/torch_tensorrt/dynamo/_refit.py +++ b/py/torch_tensorrt/dynamo/_refit.py @@ -208,6 +208,7 @@ def refit_module_weights( kwarg_inputs: Optional[dict[str, Any]] = None, verify_output: bool = False, fast_refit: bool = True, + in_place: bool = False, ) -> torch.fx.GraphModule: """ Refit a compiled graph module with ExportedProgram. This performs weight updates in compiled_module without recompiling the engine. @@ -230,7 +231,12 @@ def refit_module_weights( if len(list(compiled_module.named_children())) == 0: inline_module = True - compiled_module = copy.deepcopy(compiled_module) + if not in_place: + if inline_module: + logger.warning( + "Inplace has no effect on exported program. Please use the returned module as the updated module." 
+            )
+        compiled_module = copy.deepcopy(compiled_module)
 
     # Get the settings and check the setting to be uniform
     settings: CompilationSettings = None

From 96fb4290a9b48dfc586ccafe338b705abb219ddf Mon Sep 17 00:00:00 2001
From: cehongwang
Date: Tue, 9 Jul 2024 11:38:01 -0700
Subject: [PATCH 4/9] Added test cases for no map or wrong map

---
 py/torch_tensorrt/dynamo/_refit.py         |  3 +-
 tests/py/dynamo/models/test_model_refit.py | 98 ++++++++++++++++++++++
 2 files changed, 100 insertions(+), 1 deletion(-)

diff --git a/py/torch_tensorrt/dynamo/_refit.py b/py/torch_tensorrt/dynamo/_refit.py
index 67a0c3b6d4..b50b638a63 100644
--- a/py/torch_tensorrt/dynamo/_refit.py
+++ b/py/torch_tensorrt/dynamo/_refit.py
@@ -440,8 +440,9 @@ def refit_module_weights(
                     settings=settings,
                     weight_name_map=weight_name_map,
                 )
-            except AssertionError:
+            except AssertionError as e:
                 # If fast_refit is used and failed, we fall back to regular refit
+                logger.warning(e)
                 if fast_refit and weight_name_map:
                     _refit_single_trt_engine_with_gm(
                         new_gm=new_submodule,
diff --git a/tests/py/dynamo/models/test_model_refit.py b/tests/py/dynamo/models/test_model_refit.py
index b611398e61..29a1cec6bb 100644
--- a/tests/py/dynamo/models/test_model_refit.py
+++ b/tests/py/dynamo/models/test_model_refit.py
@@ -126,6 +126,104 @@ def test_fast_refit_one_engine():
     torch._dynamo.reset()
 
 
+@pytest.mark.unit
+def test_fast_refit_one_engin_no_map():
+
+    model = models.resnet18(pretrained=False).eval().to("cuda")
+    model2 = models.resnet18(pretrained=True).eval().to("cuda")
+    inputs = [torch.randn((1, 3, 224, 224)).to("cuda")]
+    enabled_precisions = {torch.float}
+    debug = False
+    min_block_size = 1
+    use_python_runtime = False
+
+    exp_program = torch.export.export(model, tuple(inputs))
+    exp_program2 = torch.export.export(model2, tuple(inputs))
+
+    trt_gm = torchtrt.dynamo.compile(
+        exp_program,
+        tuple(inputs),
+        use_python_runtime=use_python_runtime,
+        enabled_precisions=enabled_precisions,
+        debug=debug,
+        min_block_size=min_block_size,
+        make_refitable=True,
+    )
+
+    trt_gm._run_on_acc_0.weight_name_map = None
+
+    new_trt_gm = refit_module_weights(
+        compiled_module=trt_gm,
+        new_weight_module=exp_program2,
+        inputs=inputs,
+        fast_refit=True,
+    )
+
+    # Check the output
+    expected_outputs, refitted_outputs = exp_program2.module()(*inputs), new_trt_gm(
+        *inputs
+    )
+    for expected_output, refitted_output in zip(expected_outputs, refitted_outputs):
+        assertions.assertTrue(
+            torch.allclose(expected_output, refitted_output, 1e-2, 1e-2),
+            "Refit Result is not correct. Refit failed",
+        )
+    # Clean up model env
+
+    torch._dynamo.reset()
+
+
+@pytest.mark.unit
+def test_fast_refit_one_engin_wrong_map():
+
+    model = models.resnet18(pretrained=False).eval().to("cuda")
+    model2 = models.resnet18(pretrained=True).eval().to("cuda")
+    inputs = [torch.randn((1, 3, 224, 224)).to("cuda")]
+    enabled_precisions = {torch.float}
+    debug = False
+    min_block_size = 1
+    use_python_runtime = False
+
+    exp_program = torch.export.export(model, tuple(inputs))
+    exp_program2 = torch.export.export(model2, tuple(inputs))
+
+    trt_gm = torchtrt.dynamo.compile(
+        exp_program,
+        tuple(inputs),
+        use_python_runtime=use_python_runtime,
+        enabled_precisions=enabled_precisions,
+        debug=debug,
+        min_block_size=min_block_size,
+        make_refitable=True,
+    )
+    # Manually delete all batch norm layers. 
This is expected to make fast refit fail
+    trt_gm._run_on_acc_0.weight_name_map = {
+        k: v
+        for k, v in trt_gm._run_on_acc_0.weight_name_map.items()
+        if "[SCALE]" not in k
+    }
+
+    new_trt_gm = refit_module_weights(
+        compiled_module=trt_gm,
+        new_weight_module=exp_program2,
+        inputs=inputs,
+        fast_refit=True,
+    )
+
+    # Check the output
+    expected_outputs, refitted_outputs = exp_program2.module()(*inputs), new_trt_gm(
+        *inputs
+    )
+    for expected_output, refitted_output in zip(expected_outputs, refitted_outputs):
+        assertions.assertTrue(
+            torch.allclose(expected_output, refitted_output, 1e-2, 1e-2),
+            "Refit Result is not correct. Refit failed",
+        )
+    # Clean up model env
+
+    torch._dynamo.reset()
+
+
 @pytest.mark.unit
 def test_fast_refit_one_engine_bert():
     inputs = [

From 0b34982b6d9c7b9f197db696542e4a0c9d2b4f1b Mon Sep 17 00:00:00 2001
From: cehongwang
Date: Fri, 12 Jul 2024 17:02:29 -0700
Subject: [PATCH 5/9] Added comments and fixed some issue

---
 .../dynamo/conversion/_TRTInterpreter.py      | 13 ++++++++++---
 tests/py/dynamo/models/test_model_refit.py    |  4 ++--
 2 files changed, 12 insertions(+), 5 deletions(-)

diff --git a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
index 74dbc14dd9..28aa62072e 100644
--- a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
+++ b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
@@ -323,6 +323,13 @@ def _construct_trt_network_def(self) -> None:
         )
 
     def _save_weight_mapping(self) -> None:
+        """
+        Construct the weight name mapping from engine weight name to state_dict weight name.
+        Cache the weight names for future refitting use cases.
+        Two-stage weight name tracing:
+        1. Name transformation from engine weight name to state_dict weight name
+        2. Value mapping that, for each weight in the INetworkDefinition, searches for an identical weight in the state_dict
+        """
 
         def find_weight(
             weight_name: str, np_map: dict[str, Any], sd: dict[str, Any]
@@ -386,7 +393,7 @@ def check_weight_equal(
             )
         }
         """
-
+        # Stage 1: Name mapping
         sd = self.module.state_dict()
         weight_name_map: dict[str, Any] = {}
         np_map = {}
@@ -413,6 +420,7 @@ def check_weight_equal(
                 [i for i in sd_weight_name_list[:-1] if i]
             )
             suffix = sd_weight_name_list[-1]
+            # Retrieve the corresponding weight name(s) in state_dict
            if layer_type == "CONSTANT":
                 if "embedding" in suffix:
                     sd_weight_name = f"{sd_weight_name}.{torch_attr[0]}"
@@ -430,7 +438,7 @@ def check_weight_equal(
             weight_name_map[engine_weight_name] = sd_weight_name
             np_map[engine_weight_name] = weight
 
-        # Value mapping
+        # Stage 2: Value mapping
         for engine_weight_name, sd_weight_name in weight_name_map.items():
             if "SCALE" in engine_weight_name:
                 # There is no direct connection in batch_norm layer. 
So skip it @@ -448,7 +456,6 @@ def check_weight_equal( ] self.weight_name_map = weight_name_map - # check = {k:(weight_name_map[k], np_map[k]) for k, v in np_map.items()} def run( self, diff --git a/tests/py/dynamo/models/test_model_refit.py b/tests/py/dynamo/models/test_model_refit.py index 29a1cec6bb..ddcf3682c7 100644 --- a/tests/py/dynamo/models/test_model_refit.py +++ b/tests/py/dynamo/models/test_model_refit.py @@ -127,7 +127,7 @@ def test_fast_refit_one_engine(): @pytest.mark.unit -def test_fast_refit_one_engin_no_map(): +def test_fast_refit_one_engine_no_map(): model = models.resnet18(pretrained=False).eval().to("cuda") model2 = models.resnet18(pretrained=True).eval().to("cuda") @@ -174,7 +174,7 @@ def test_fast_refit_one_engin_no_map(): @pytest.mark.unit -def test_fast_refit_one_engin_wrong_map(): +def test_fast_refit_one_engine_wrong_map(): model = models.resnet18(pretrained=False).eval().to("cuda") model2 = models.resnet18(pretrained=True).eval().to("cuda") From 3defe9748f949506bf2d277391d953de4639dc2d Mon Sep 17 00:00:00 2001 From: cehongwang Date: Thu, 1 Aug 2024 13:55:21 -0700 Subject: [PATCH 6/9] Fixed issues in comments --- py/torch_tensorrt/dynamo/_refit.py | 30 ++++++++++------------ tests/py/dynamo/models/test_model_refit.py | 24 ++++++++--------- 2 files changed, 25 insertions(+), 29 deletions(-) diff --git a/py/torch_tensorrt/dynamo/_refit.py b/py/torch_tensorrt/dynamo/_refit.py index b50b638a63..e706aec677 100644 --- a/py/torch_tensorrt/dynamo/_refit.py +++ b/py/torch_tensorrt/dynamo/_refit.py @@ -207,7 +207,7 @@ def refit_module_weights( arg_inputs: Optional[Tuple[Any, ...]] = None, kwarg_inputs: Optional[dict[str, Any]] = None, verify_output: bool = False, - fast_refit: bool = True, + use_weight_map_cache: bool = True, in_place: bool = False, ) -> torch.fx.GraphModule: """ @@ -232,11 +232,11 @@ def refit_module_weights( inline_module = True if not in_place: - if inline_module: - logger.warning( - "Inplace has no effect on exported program. Please use the returned module as the updated module." - ) compiled_module = copy.deepcopy(compiled_module) + elif inline_module: + raise AssertionError( + "Exported program does not support modifying in place. Please set inplace to false and use the returned graph module." + ) # Get the settings and check the setting to be uniform settings: CompilationSettings = None @@ -254,7 +254,7 @@ def refit_module_weights( ] assert ( encoded_metadata != "" - ), "Settings are not saved in the engine. Please recompile the engine with make_refitable=True." + ), "The engine provided is either not refittable or was built with a version of Torch-TensorRT that is too old, please recompile using the latest version with make_refitable=True" settings = TorchTensorRTModule.decode_metadata(encoded_metadata)["settings"] # Handle torch modules compiled_submodules_map = dict(compiled_submodules) @@ -365,36 +365,32 @@ def refit_module_weights( engine = get_engine_from_encoded_engine( engine_info[ENGINE_IDX], runtime ) - if fast_refit: + if use_weight_map_cache: encoded_metadata = compiled_submodule.__getstate__()[0][ SERIALIZED_METADATA_IDX ] - assert ( - encoded_metadata != "" - ), "Metadata are not saved in the engine. Please recompile the engine with make_refitable=True." weight_name_map = TorchTensorRTModule.decode_metadata( encoded_metadata )["weight_name_map"] if not weight_name_map: - fast_refit = False + use_weight_map_cache = False logger.warning( "Fast refitting is not supported in this module. Use regular refitting." 
) else: compiled_submodule = getattr(compiled_module, name) weight_name_map = None - if fast_refit: + if use_weight_map_cache: try: weight_name_map = compiled_submodule.weight_name_map except AttributeError: - fast_refit = False logger.warning( - "You are using a old version of Torch-TensorRT. Please re-compile the engine to avoid failures." + "The module was compiled wit an old version of Torch-TensorRT. Rebuilding the weight map." ) if not weight_name_map: - fast_refit = False + use_weight_map_cache = False logger.warning( - "Fast refitting is not supported in this module. Use regular refitting." + "This engine does not have a weight map cache. Rebuilding the weight map" ) if isinstance(compiled_submodule, PythonTorchTensorRTModule): engine = compiled_submodule.engine @@ -443,7 +439,7 @@ def refit_module_weights( except AssertionError as e: # If fast_refit is used and failed, we fall back to regular refit logger.warning(e) - if fast_refit and weight_name_map: + if use_weight_map_cache and weight_name_map: _refit_single_trt_engine_with_gm( new_gm=new_submodule, old_engine=engine, diff --git a/tests/py/dynamo/models/test_model_refit.py b/tests/py/dynamo/models/test_model_refit.py index ddcf3682c7..4b994c6ad5 100644 --- a/tests/py/dynamo/models/test_model_refit.py +++ b/tests/py/dynamo/models/test_model_refit.py @@ -109,7 +109,7 @@ def test_fast_refit_one_engine(): compiled_module=trt_gm, new_weight_module=exp_program2, inputs=inputs, - fast_refit=True, + use_weight_map_cache=True, ) # Check the output @@ -156,7 +156,7 @@ def test_fast_refit_one_engine_no_map(): compiled_module=trt_gm, new_weight_module=exp_program2, inputs=inputs, - fast_refit=True, + use_weight_map_cache=True, ) # Check the output @@ -207,7 +207,7 @@ def test_fast_refit_one_engine_wrong_map(): compiled_module=trt_gm, new_weight_module=exp_program2, inputs=inputs, - fast_refit=True, + use_weight_map_cache=True, ) # Check the output @@ -254,7 +254,7 @@ def test_fast_refit_one_engine_bert(): compiled_module=trt_gm, new_weight_module=exp_program2, inputs=inputs, - fast_refit=True, + use_weight_map_cache=True, ) # Check the output @@ -304,7 +304,7 @@ def test_fast_refit_one_engine_inline_runtime(): compiled_module=trt_gm, new_weight_module=exp_program2, inputs=inputs, - fast_refit=True, + use_weight_map_cache=True, ) # Check the output @@ -349,7 +349,7 @@ def test_fast_refit_one_engine_python_runtime(): compiled_module=trt_gm, new_weight_module=exp_program2, inputs=inputs, - fast_refit=True, + use_weight_map_cache=True, ) # Check the output @@ -416,7 +416,7 @@ def forward(self, x): compiled_module=trt_gm, new_weight_module=exp_program2, inputs=inputs, - fast_refit=True, + use_weight_map_cache=True, ) # Check the output @@ -461,7 +461,7 @@ def test_refit_one_engine(): compiled_module=trt_gm, new_weight_module=exp_program2, inputs=inputs, - fast_refit=False, + use_weight_map_cache=False, ) # Check the output @@ -508,7 +508,7 @@ def test_refit_one_engine_bert(): compiled_module=trt_gm, new_weight_module=exp_program2, inputs=inputs, - fast_refit=False, + use_weight_map_cache=False, ) # Check the output @@ -558,7 +558,7 @@ def test_refit_one_engine_inline_runtime(): compiled_module=trt_gm, new_weight_module=exp_program2, inputs=inputs, - fast_refit=False, + use_weight_map_cache=False, ) # Check the output @@ -603,7 +603,7 @@ def test_refit_one_engine_python_runtime(): compiled_module=trt_gm, new_weight_module=exp_program2, inputs=inputs, - fast_refit=False, + use_weight_map_cache=False, ) # Check the output @@ -670,7 +670,7 @@ def 
forward(self, x): compiled_module=trt_gm, new_weight_module=exp_program2, inputs=inputs, - fast_refit=False, + use_weight_map_cache=False, ) # Check the output From 36c0a4c324a6792ad41481ad10ecde7068609109 Mon Sep 17 00:00:00 2001 From: cehongwang Date: Thu, 1 Aug 2024 14:18:19 -0700 Subject: [PATCH 7/9] Fixed a bug of regular engine compilation --- .../dynamo/conversion/_conversion.py | 26 ++++++++++--------- 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/py/torch_tensorrt/dynamo/conversion/_conversion.py b/py/torch_tensorrt/dynamo/conversion/_conversion.py index a62a59905a..8316704c6f 100644 --- a/py/torch_tensorrt/dynamo/conversion/_conversion.py +++ b/py/torch_tensorrt/dynamo/conversion/_conversion.py @@ -132,18 +132,20 @@ def convert_module( runtime = trt.Runtime(TRT_LOGGER) refit_test_engine = runtime.deserialize_cuda_engine(interpreter_result.engine) - weight_name_map: Any = interpreter_result.weight_name_map - try: - _refit_single_trt_engine_with_gm( - new_gm=module, - old_engine=refit_test_engine, - input_list=inputs, - settings=settings, - weight_name_map=weight_name_map, - ) - except AssertionError: - logger.warning("Fast refit test failed. Removing the weight map caching.") - weight_name_map = None + weight_name_map: Any = None + # Do the test refit with cached map if make_refitable is enabled + if settings.make_refitable: + weight_name_map = interpreter_result.weight_name_map + try: + _refit_single_trt_engine_with_gm( + new_gm=module, + old_engine=refit_test_engine, + input_list=inputs, + settings=settings, + weight_name_map=interpreter_result.weight_name_map, + ) + except AssertionError: + logger.warning("Fast refit test failed. Removing the weight map caching.") rt_cls = PythonTorchTensorRTModule From 7235de682028eacbe65ae3acdb54d741f36014ef Mon Sep 17 00:00:00 2001 From: cehongwang Date: Thu, 8 Aug 2024 13:59:32 -0700 Subject: [PATCH 8/9] Fixed bugs after rebase --- py/torch_tensorrt/dynamo/_refit.py | 4 ++-- .../dynamo/conversion/_TRTInterpreter.py | 6 +++-- .../dynamo/conversion/_conversion.py | 6 +++-- .../runtime/_PythonTorchTensorRTModule.py | 4 +--- .../dynamo/runtime/_TorchTensorRTModule.py | 3 +-- tests/py/dynamo/models/test_model_refit.py | 24 +++++++++---------- 6 files changed, 24 insertions(+), 23 deletions(-) diff --git a/py/torch_tensorrt/dynamo/_refit.py b/py/torch_tensorrt/dynamo/_refit.py index e706aec677..660cb8a875 100644 --- a/py/torch_tensorrt/dynamo/_refit.py +++ b/py/torch_tensorrt/dynamo/_refit.py @@ -375,7 +375,7 @@ def refit_module_weights( if not weight_name_map: use_weight_map_cache = False logger.warning( - "Fast refitting is not supported in this module. Use regular refitting." + "This engine does not have a weight map cache. Rebuilding the weight map" ) else: compiled_submodule = getattr(compiled_module, name) @@ -385,7 +385,7 @@ def refit_module_weights( weight_name_map = compiled_submodule.weight_name_map except AttributeError: logger.warning( - "The module was compiled wit an old version of Torch-TensorRT. Rebuilding the weight map." + "The module was compiled with an old version of Torch-TensorRT. Rebuilding the weight map." 
) if not weight_name_map: use_weight_map_cache = False diff --git a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py index 28aa62072e..9a3cace599 100644 --- a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py +++ b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py @@ -6,6 +6,7 @@ from typing import Any, Callable, Dict, List, NamedTuple, Optional, Sequence, Set, Tuple import numpy as np +import tensorrt as trt import torch import torch.fx from torch.fx.node import _get_qualified_name @@ -29,7 +30,6 @@ from torch_tensorrt.fx.observer import Observer from torch_tensorrt.logging import TRT_LOGGER -import tensorrt as trt from packaging import version _LOGGER: logging.Logger = logging.getLogger(__name__) @@ -504,7 +504,9 @@ def run( engine_bytes.write(serialized_engine) engine_str = engine_bytes.getvalue() - return TRTInterpreterResult(engine_str, self._input_names, self._output_names, self.weight_name_map) + return TRTInterpreterResult( + engine_str, self._input_names, self._output_names, self.weight_name_map + ) def run_node(self, n: torch.fx.Node) -> torch.fx.Node: self._cur_node_name = get_node_name(n) diff --git a/py/torch_tensorrt/dynamo/conversion/_conversion.py b/py/torch_tensorrt/dynamo/conversion/_conversion.py index 8316704c6f..57fa1749bf 100644 --- a/py/torch_tensorrt/dynamo/conversion/_conversion.py +++ b/py/torch_tensorrt/dynamo/conversion/_conversion.py @@ -131,7 +131,9 @@ def convert_module( from torch_tensorrt.logging import TRT_LOGGER runtime = trt.Runtime(TRT_LOGGER) - refit_test_engine = runtime.deserialize_cuda_engine(interpreter_result.engine) + refit_test_engine = runtime.deserialize_cuda_engine( + interpreter_result.serialized_engine + ) weight_name_map: Any = None # Do the test refit with cached map if make_refitable is enabled if settings.make_refitable: @@ -169,5 +171,5 @@ def convert_module( output_binding_names=list(interpreter_result.output_names), name=name, settings=settings, - weight_name_map = weight_name_map + weight_name_map=weight_name_map, ) diff --git a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py index e21e83aaac..d5da83488a 100644 --- a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py @@ -4,6 +4,7 @@ from contextlib import nullcontext from typing import Any, Dict, List, Optional, Sequence, Tuple +import tensorrt as trt import torch import torch_tensorrt from torch.nn import Module @@ -18,8 +19,6 @@ from torch_tensorrt.dynamo.utils import DYNAMIC_DIM from torch_tensorrt.logging import TRT_LOGGER -import tensorrt as trt - logger = logging.getLogger(__name__) @@ -104,7 +103,6 @@ def __init__( self.settings = settings self.engine = None self.weight_name_map = weight_name_map - self._initialize() if self.serialized_engine is not None and not self.settings.lazy_engine_init: self.setup_engine() diff --git a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py index d3216177db..fe3974ff96 100644 --- a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py @@ -145,8 +145,7 @@ def setup_engine(self) -> None: self.encode_metadata(metadata), ] ) - - + def encode_metadata(self, metadata: Any) -> str: metadata = copy.deepcopy(metadata) metadata["settings"].torch_executed_ops = { diff --git 
a/tests/py/dynamo/models/test_model_refit.py b/tests/py/dynamo/models/test_model_refit.py index 4b994c6ad5..e803c7fad6 100644 --- a/tests/py/dynamo/models/test_model_refit.py +++ b/tests/py/dynamo/models/test_model_refit.py @@ -108,7 +108,7 @@ def test_fast_refit_one_engine(): new_trt_gm = refit_module_weights( compiled_module=trt_gm, new_weight_module=exp_program2, - inputs=inputs, + arg_inputs=inputs, use_weight_map_cache=True, ) @@ -155,7 +155,7 @@ def test_fast_refit_one_engine_no_map(): new_trt_gm = refit_module_weights( compiled_module=trt_gm, new_weight_module=exp_program2, - inputs=inputs, + arg_inputs=inputs, use_weight_map_cache=True, ) @@ -206,7 +206,7 @@ def test_fast_refit_one_engine_wrong_map(): new_trt_gm = refit_module_weights( compiled_module=trt_gm, new_weight_module=exp_program2, - inputs=inputs, + arg_inputs=inputs, use_weight_map_cache=True, ) @@ -253,7 +253,7 @@ def test_fast_refit_one_engine_bert(): new_trt_gm = refit_module_weights( compiled_module=trt_gm, new_weight_module=exp_program2, - inputs=inputs, + arg_inputs=inputs, use_weight_map_cache=True, ) @@ -303,7 +303,7 @@ def test_fast_refit_one_engine_inline_runtime(): new_trt_gm = refit_module_weights( compiled_module=trt_gm, new_weight_module=exp_program2, - inputs=inputs, + arg_inputs=inputs, use_weight_map_cache=True, ) @@ -348,7 +348,7 @@ def test_fast_refit_one_engine_python_runtime(): new_trt_gm = refit_module_weights( compiled_module=trt_gm, new_weight_module=exp_program2, - inputs=inputs, + arg_inputs=inputs, use_weight_map_cache=True, ) @@ -415,7 +415,7 @@ def forward(self, x): new_trt_gm = refit_module_weights( compiled_module=trt_gm, new_weight_module=exp_program2, - inputs=inputs, + arg_inputs=inputs, use_weight_map_cache=True, ) @@ -460,7 +460,7 @@ def test_refit_one_engine(): new_trt_gm = refit_module_weights( compiled_module=trt_gm, new_weight_module=exp_program2, - inputs=inputs, + arg_inputs=inputs, use_weight_map_cache=False, ) @@ -507,7 +507,7 @@ def test_refit_one_engine_bert(): new_trt_gm = refit_module_weights( compiled_module=trt_gm, new_weight_module=exp_program2, - inputs=inputs, + arg_inputs=inputs, use_weight_map_cache=False, ) @@ -557,7 +557,7 @@ def test_refit_one_engine_inline_runtime(): new_trt_gm = refit_module_weights( compiled_module=trt_gm, new_weight_module=exp_program2, - inputs=inputs, + arg_inputs=inputs, use_weight_map_cache=False, ) @@ -602,7 +602,7 @@ def test_refit_one_engine_python_runtime(): new_trt_gm = refit_module_weights( compiled_module=trt_gm, new_weight_module=exp_program2, - inputs=inputs, + arg_inputs=inputs, use_weight_map_cache=False, ) @@ -669,7 +669,7 @@ def forward(self, x): new_trt_gm = refit_module_weights( compiled_module=trt_gm, new_weight_module=exp_program2, - inputs=inputs, + arg_inputs=inputs, use_weight_map_cache=False, ) From b3aa04f4b8d53d855091eb620480b1ace69afe04 Mon Sep 17 00:00:00 2001 From: cehongwang Date: Fri, 9 Aug 2024 15:45:59 -0700 Subject: [PATCH 9/9] Renamed tests and added test skip --- tests/py/dynamo/models/test_model_refit.py | 71 +++++++++++++++++----- 1 file changed, 57 insertions(+), 14 deletions(-) diff --git a/tests/py/dynamo/models/test_model_refit.py b/tests/py/dynamo/models/test_model_refit.py index e803c7fad6..c642ae0675 100644 --- a/tests/py/dynamo/models/test_model_refit.py +++ b/tests/py/dynamo/models/test_model_refit.py @@ -9,10 +9,9 @@ import torch import torch.nn.functional as F import torch_tensorrt as torchtrt +import torch_tensorrt as torch_trt import torchvision.models as models from torch import nn - -# from 
torch import nn from torch_tensorrt.dynamo import refit_module_weights from torch_tensorrt.dynamo._refit import ( construct_refit_mapping, @@ -29,6 +28,10 @@ assertions = unittest.TestCase() +@unittest.skipIf( + not torch_trt.ENABLED_FEATURES.torch_tensorrt_runtime, + "TorchScript Frontend is not available", +) @pytest.mark.unit def test_mapping(): @@ -81,8 +84,12 @@ def test_mapping(): torch._dynamo.reset() +@unittest.skipIf( + not torch_trt.ENABLED_FEATURES.torch_tensorrt_runtime, + "TorchScript Frontend is not available", +) @pytest.mark.unit -def test_fast_refit_one_engine(): +def test_refit_one_engine_with_weightmap(): model = models.resnet18(pretrained=False).eval().to("cuda") model2 = models.resnet18(pretrained=True).eval().to("cuda") @@ -126,8 +133,12 @@ def test_fast_refit_one_engine(): torch._dynamo.reset() +@unittest.skipIf( + not torch_trt.ENABLED_FEATURES.torch_tensorrt_runtime, + "TorchScript Frontend is not available", +) @pytest.mark.unit -def test_fast_refit_one_engine_no_map(): +def test_refit_one_engine_no_map_with_weightmap(): model = models.resnet18(pretrained=False).eval().to("cuda") model2 = models.resnet18(pretrained=True).eval().to("cuda") @@ -173,8 +184,12 @@ def test_fast_refit_one_engine_no_map(): torch._dynamo.reset() +@unittest.skipIf( + not torch_trt.ENABLED_FEATURES.torch_tensorrt_runtime, + "TorchScript Frontend is not available", +) @pytest.mark.unit -def test_fast_refit_one_engine_wrong_map(): +def test_refit_one_engine_with_wrong_weightmap(): model = models.resnet18(pretrained=False).eval().to("cuda") model2 = models.resnet18(pretrained=True).eval().to("cuda") @@ -224,8 +239,12 @@ def test_fast_refit_one_engine_wrong_map(): torch._dynamo.reset() +@unittest.skipIf( + not torch_trt.ENABLED_FEATURES.torch_tensorrt_runtime, + "TorchScript Frontend is not available", +) @pytest.mark.unit -def test_fast_refit_one_engine_bert(): +def test_refit_one_engine_bert_with_weightmap(): inputs = [ torch.randint(0, 2, (1, 14), dtype=torch.int32).to("cuda"), ] @@ -275,8 +294,12 @@ def test_fast_refit_one_engine_bert(): torch._dynamo.reset() +@unittest.skipIf( + not torch_trt.ENABLED_FEATURES.torch_tensorrt_runtime, + "TorchScript Frontend is not available", +) @pytest.mark.unit -def test_fast_refit_one_engine_inline_runtime(): +def test_refit_one_engine_inline_runtime__with_weightmap(): trt_ep_path = os.path.join(tempfile.gettempdir(), "compiled.ep") model = models.resnet18(pretrained=False).eval().to("cuda") model2 = models.resnet18(pretrained=True).eval().to("cuda") @@ -322,7 +345,7 @@ def test_fast_refit_one_engine_inline_runtime(): @pytest.mark.unit -def test_fast_refit_one_engine_python_runtime(): +def test_refit_one_engine_python_runtime_with_weightmap(): model = models.resnet18(pretrained=False).eval().to("cuda") model2 = models.resnet18(pretrained=True).eval().to("cuda") @@ -366,8 +389,12 @@ def test_fast_refit_one_engine_python_runtime(): torch._dynamo.reset() +@unittest.skipIf( + not torch_trt.ENABLED_FEATURES.torch_tensorrt_runtime, + "TorchScript Frontend is not available", +) @pytest.mark.unit -def test_fast_refit_multiple_engine(): +def test_refit_multiple_engine_with_weightmap(): class net(nn.Module): def __init__(self): @@ -433,8 +460,12 @@ def forward(self, x): torch._dynamo.reset() +@unittest.skipIf( + not torch_trt.ENABLED_FEATURES.torch_tensorrt_runtime, + "TorchScript Frontend is not available", +) @pytest.mark.unit -def test_refit_one_engine(): +def test_refit_one_engine_without_weightmap(): model = models.resnet18(pretrained=False).eval().to("cuda") 
model2 = models.resnet18(pretrained=True).eval().to("cuda") @@ -478,8 +509,12 @@ def test_refit_one_engine(): torch._dynamo.reset() +@unittest.skipIf( + not torch_trt.ENABLED_FEATURES.torch_tensorrt_runtime, + "TorchScript Frontend is not available", +) @pytest.mark.unit -def test_refit_one_engine_bert(): +def test_refit_one_engine_bert_without_weightmap(): inputs = [ torch.randint(0, 2, (1, 14), dtype=torch.int32).to("cuda"), ] @@ -529,8 +564,12 @@ def test_refit_one_engine_bert(): torch._dynamo.reset() +@unittest.skipIf( + not torch_trt.ENABLED_FEATURES.torch_tensorrt_runtime, + "TorchScript Frontend is not available", +) @pytest.mark.unit -def test_refit_one_engine_inline_runtime(): +def test_refit_one_engine_inline_runtime_without_weightmap(): trt_ep_path = os.path.join(tempfile.gettempdir(), "compiled.ep") model = models.resnet18(pretrained=False).eval().to("cuda") model2 = models.resnet18(pretrained=True).eval().to("cuda") @@ -576,7 +615,7 @@ def test_refit_one_engine_inline_runtime(): @pytest.mark.unit -def test_refit_one_engine_python_runtime(): +def test_refit_one_engine_python_runtime_without_weightmap(): model = models.resnet18(pretrained=False).eval().to("cuda") model2 = models.resnet18(pretrained=True).eval().to("cuda") @@ -620,8 +659,12 @@ def test_refit_one_engine_python_runtime(): torch._dynamo.reset() +@unittest.skipIf( + not torch_trt.ENABLED_FEATURES.torch_tensorrt_runtime, + "TorchScript Frontend is not available", +) @pytest.mark.unit -def test_refit_multiple_engine(): +def test_refit_multiple_engine_without_weightmap(): class net(nn.Module): def __init__(self):
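

Usage sketch (illustrative, not part of the patch series): the workflow the tests above
exercise is compile once with make_refitable=True, then push new weights into the existing
engine instead of recompiling. A minimal end-to-end sketch against the API as it stands
after patch 9 — the argument names arg_inputs and use_weight_map_cache come from the diffs
above; the ResNet18 models mirror the tests and are otherwise arbitrary:

    import torch
    import torch_tensorrt as torchtrt
    import torchvision.models as models
    from torch_tensorrt.dynamo import refit_module_weights

    # Compile once. make_refitable=True also builds and caches the
    # engine-weight-name -> state_dict-name map that fast refit relies on.
    model = models.resnet18(pretrained=False).eval().to("cuda")
    inputs = [torch.randn((1, 3, 224, 224)).to("cuda")]
    exp_program = torch.export.export(model, tuple(inputs))
    trt_gm = torchtrt.dynamo.compile(
        exp_program,
        tuple(inputs),
        min_block_size=1,
        make_refitable=True,
    )

    # Refit with new weights instead of recompiling. With
    # use_weight_map_cache=True the cached map drives the fast path; if the
    # map is missing or fails validation, the refitter falls back to
    # rebuilding the mapping from the new graph module.
    model2 = models.resnet18(pretrained=True).eval().to("cuda")
    exp_program2 = torch.export.export(model2, tuple(inputs))
    new_trt_gm = refit_module_weights(
        compiled_module=trt_gm,
        new_weight_module=exp_program2,
        arg_inputs=inputs,
        use_weight_map_cache=True,
    )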