Commit 910cf06

delete gelu
1 parent e878a8c commit 910cf06

File tree

2 files changed, +2 −72 lines changed


py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 0 additions & 17 deletions
@@ -152,23 +152,6 @@ def aten_ops_fmod(
     return impl.elementwise.fmod(network, target, SourceIR.ATEN, name, args[0], args[1])
 
 
-@dynamo_tensorrt_converter(torch.ops.aten.gelu.default)  # type: ignore[misc]
-def aten_ops_gelu(
-    network: TRTNetwork,
-    target: Target,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
-    name: str,
-) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    return impl.activation.gelu(
-        network,
-        target,
-        SourceIR.ATEN,
-        name,
-        args[0],
-    )
-
-
 @dynamo_tensorrt_converter(torch.ops.aten.relu.default)
 def aten_ops_relu(
     network: TRTNetwork,
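For reference, the deleted registration targeted the exact (non-approximate) GELU overload, which can still be exercised directly from Python. A minimal illustration, not part of this commit:

import torch

x = torch.randn(2, 8)
# torch.ops.aten.gelu.default is the exact (erf-based) GELU,
# i.e. torch.nn.functional.gelu(x) with approximate="none".
y = torch.ops.aten.gelu.default(x)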

py/torch_tensorrt/dynamo/conversion/impl/activation/ops.py

Lines changed: 2 additions & 55 deletions
@@ -1,68 +1,15 @@
-import math
-from typing import Any, Optional, Tuple
+from typing import Any, Optional
 
 import numpy as np
 import tensorrt as trt
 import torch
-from torch import Tensor
 from torch.fx.node import Target
 from torch_tensorrt.dynamo._SourceIR import SourceIR
-from torch_tensorrt.fx.converters.converter_utils import (
-    get_trt_plugin,
-    mark_as_int8_layer,
-    set_layer_name,
-)
-from torch_tensorrt.fx.types import TRTNetwork, TRTPluginFieldCollection, TRTTensor
+from torch_tensorrt.fx.types import TRTNetwork, TRTTensor
 
 from .base import convert_activation
 
 
-def gelu(
-    network: TRTNetwork,
-    target: Target,
-    source_ir: Optional[SourceIR],
-    name: str,
-    input_val: TRTTensor,
-    alpha: Optional[Any] = None,
-) -> TRTTensor:
-    approximate = alpha
-    if approximate is not None:
-        raise RuntimeError("GeLU converter currently doesn't support fast gelu compute")
-    if not isinstance(input_val, TRTTensor):
-        raise RuntimeError(
-            f"GELU received input {input_val} that is not part "
-            "of the TensorRT region!"
-        )
-    if network.has_implicit_batch_dimension:
-        raise RuntimeError(
-            "GeLU converter currently doesn't support implicit batch dimension"
-        )
-    plugin_name = "CustomGeluPluginDynamic"
-    # type_id 0 for float32, 1 for float16
-    type_id = trt.PluginField(
-        "type_id", np.array(0, dtype=np.int32), trt.PluginFieldType.INT32
-    )
-    field_collection = TRTPluginFieldCollection([type_id])
-    plugin_version = "1"
-
-    plugin = get_trt_plugin(plugin_name, field_collection, plugin_version)
-
-    layer = network.add_plugin_v2([input_val], plugin)
-
-    def gelu_dyn_range_fn(
-        dyn_range: Tuple[Tensor, Tensor]
-    ) -> Tuple[Tensor, Tensor]:  # TODO: This probably will not work with fake tensor
-        return (
-            dyn_range[0] * 0.5 * (1.0 + torch.erf(dyn_range[0] / math.sqrt(2.0)))
-        ), (dyn_range[1] * 0.5 * (1.0 + torch.erf(dyn_range[0] / math.sqrt(2.0))))
-
-    if input_val.dynamic_range is not None:
-        dyn_range = gelu_dyn_range_fn(input_val.dynamic_range)
-        mark_as_int8_layer(layer, dyn_range)
-    set_layer_name(layer, target, name)
-    return layer.get_output(0)
-
-
 def relu(
     network: TRTNetwork,
     target: Target,
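The removed implementation relied on TensorRT's CustomGeluPluginDynamic plugin. As a point of comparison only (not part of this commit), the same exact-GELU math, x * 0.5 * (1 + erf(x / sqrt(2))), can be sketched with native TensorRT layers; gelu_via_native_ops and its argument names below are hypothetical:

import math

import numpy as np
import tensorrt as trt


def gelu_via_native_ops(network, name, input_val):
    # Hypothetical sketch: exact GELU, x * 0.5 * (1 + erf(x / sqrt(2))),
    # expressed with native TensorRT layers instead of the
    # CustomGeluPluginDynamic plugin deleted above.
    rank = len(input_val.shape)

    def scalar(value):
        # Broadcastable fp32 constant with the same rank as the input,
        # since TensorRT elementwise layers require equal-rank operands.
        arr = np.full((1,) * rank, value, dtype=np.float32)
        return network.add_constant(arr.shape, arr).get_output(0)

    # erf(x / sqrt(2))
    scaled = network.add_elementwise(
        input_val, scalar(1.0 / math.sqrt(2.0)), trt.ElementWiseOperation.PROD
    ).get_output(0)
    erf = network.add_unary(scaled, trt.UnaryOperation.ERF).get_output(0)

    # 1 + erf(x / sqrt(2))
    one_plus_erf = network.add_elementwise(
        erf, scalar(1.0), trt.ElementWiseOperation.SUM
    ).get_output(0)

    # 0.5 * x
    half_x = network.add_elementwise(
        input_val, scalar(0.5), trt.ElementWiseOperation.PROD
    ).get_output(0)

    # 0.5 * x * (1 + erf(x / sqrt(2)))
    out = network.add_elementwise(
        half_x, one_plus_erf, trt.ElementWiseOperation.PROD
    )
    out.name = f"{name}_gelu"
    return out.get_output(0)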

0 commit comments
