-import math
-from typing import Any, Optional, Tuple
+from typing import Any, Optional
 
 import numpy as np
 import tensorrt as trt
 import torch
-from torch import Tensor
 from torch.fx.node import Target
 from torch_tensorrt.dynamo._SourceIR import SourceIR
-from torch_tensorrt.fx.converters.converter_utils import (
-    get_trt_plugin,
-    mark_as_int8_layer,
-    set_layer_name,
-)
-from torch_tensorrt.fx.types import TRTNetwork, TRTPluginFieldCollection, TRTTensor
+from torch_tensorrt.fx.types import TRTNetwork, TRTTensor
 
 from .base import convert_activation
 
 
-def gelu(
-    network: TRTNetwork,
-    target: Target,
-    source_ir: Optional[SourceIR],
-    name: str,
-    input_val: TRTTensor,
-    alpha: Optional[Any] = None,
-) -> TRTTensor:
-    approximate = alpha
-    if approximate is not None:
-        raise RuntimeError("GeLU converter currently doesn't support fast gelu compute")
-    if not isinstance(input_val, TRTTensor):
-        raise RuntimeError(
-            f"GELU received input {input_val} that is not part "
-            "of the TensorRT region!"
-        )
-    if network.has_implicit_batch_dimension:
-        raise RuntimeError(
-            "GeLU converter currently doesn't support implicit batch dimension"
-        )
-    plugin_name = "CustomGeluPluginDynamic"
-    # type_id 0 for float32, 1 for float16
-    type_id = trt.PluginField(
-        "type_id", np.array(0, dtype=np.int32), trt.PluginFieldType.INT32
-    )
-    field_collection = TRTPluginFieldCollection([type_id])
-    plugin_version = "1"
-
-    plugin = get_trt_plugin(plugin_name, field_collection, plugin_version)
-
-    layer = network.add_plugin_v2([input_val], plugin)
-
-    def gelu_dyn_range_fn(
-        dyn_range: Tuple[Tensor, Tensor]
-    ) -> Tuple[Tensor, Tensor]:  # TODO: This probably will not work with fake tensor
-        return (
-            dyn_range[0] * 0.5 * (1.0 + torch.erf(dyn_range[0] / math.sqrt(2.0)))
-        ), (dyn_range[1] * 0.5 * (1.0 + torch.erf(dyn_range[0] / math.sqrt(2.0))))
-
-    if input_val.dynamic_range is not None:
-        dyn_range = gelu_dyn_range_fn(input_val.dynamic_range)
-        mark_as_int8_layer(layer, dyn_range)
-    set_layer_name(layer, target, name)
-    return layer.get_output(0)
-
-
 def relu(
     network: TRTNetwork,
     target: Target,
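For context, the removed converter lowered GELU to TensorRT's CustomGeluPluginDynamic plugin and, when an INT8 dynamic range was set on the input, pushed that range through the exact erf-based GELU expression. A minimal PyTorch sketch of that formula, for illustration only (the function name gelu_reference is hypothetical and not part of this diff):

import math

import torch


def gelu_reference(x: torch.Tensor) -> torch.Tensor:
    # Exact (erf-based) GELU, the same expression the removed
    # gelu_dyn_range_fn applied to the ends of the dynamic range:
    #   GELU(x) = 0.5 * x * (1 + erf(x / sqrt(2)))
    return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))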
|