@@ -23,44 +23,52 @@ def test_compile_traced(self):
23
23
"enabled_precisions" : {torch .float }
24
24
}
25
25
26
- trt_mod = trtorch .compile (self .traced_model , compile_spec )
26
+ trt_mod = trtorch .compile (self .traced_model , ** compile_spec )
27
27
same = (trt_mod (self .input ) - self .traced_model (self .input )).abs ().max ()
28
28
self .assertTrue (same < 2e-2 )
29
29
30
30
def test_compile_script (self ):
31
+ trt_mod = trtorch .compile (self .scripted_model , inputs = [self .input ], device = trtorch .Device (gpu_id = 0 ), enabled_precisions = {torch .float })
32
+ same = (trt_mod (self .input ) - self .scripted_model (self .input )).abs ().max ()
33
+ self .assertTrue (same < 2e-2 )
34
+
35
+ def test_from_torch_tensor (self ):
31
36
compile_spec = {
32
- "inputs" : [trtorch . Input ( shape = self .input . shape ) ],
37
+ "inputs" : [self .input ],
33
38
"device" : {
34
39
"device_type" : trtorch .DeviceType .GPU ,
35
40
"gpu_id" : 0 ,
36
41
},
37
42
"enabled_precisions" : {torch .float }
38
43
}
39
44
40
- trt_mod = trtorch .compile (self .scripted_model , compile_spec )
45
+ trt_mod = trtorch .compile (self .scripted_model , ** compile_spec )
41
46
same = (trt_mod (self .input ) - self .scripted_model (self .input )).abs ().max ()
42
47
self .assertTrue (same < 2e-2 )
43
48
44
- def test_from_torch_tensor (self ):
49
+ def test_device (self ):
50
+ compile_spec = {"inputs" : [self .input ], "device" : trtorch .Device ("gpu:0" ), "enabled_precisions" : {torch .float }}
51
+
52
+ trt_mod = trtorch .compile (self .scripted_model , ** compile_spec )
53
+ same = (trt_mod (self .input ) - self .scripted_model (self .input )).abs ().max ()
54
+ self .assertTrue (same < 2e-2 )
55
+
56
+
57
+ def test_compile_script_from_dict (self ):
45
58
compile_spec = {
46
- "inputs" : [self .input ],
59
+ "inputs" : [trtorch . Input ( shape = self .input . shape ) ],
47
60
"device" : {
48
61
"device_type" : trtorch .DeviceType .GPU ,
49
62
"gpu_id" : 0 ,
50
63
},
51
64
"enabled_precisions" : {torch .float }
52
65
}
53
66
54
- trt_mod = trtorch .compile (self .scripted_model , compile_spec )
55
- same = (trt_mod (self .input ) - self .scripted_model (self .input )).abs ().max ()
67
+ trt_mod = trtorch .compile (self .traced_model , ** compile_spec )
68
+ same = (trt_mod (self .input ) - self .traced_model (self .input )).abs ().max ()
56
69
self .assertTrue (same < 2e-2 )
57
70
58
- def test_device (self ):
59
- compile_spec = {"inputs" : [self .input ], "device" : trtorch .Device ("gpu:0" ), "enabled_precisions" : {torch .float }}
60
71
61
- trt_mod = trtorch .compile (self .scripted_model , compile_spec )
62
- same = (trt_mod (self .input ) - self .scripted_model (self .input )).abs ().max ()
63
- self .assertTrue (same < 2e-2 )
64
72
65
73
66
74
class TestCompileHalf (ModelTestCase ):
@@ -80,7 +88,7 @@ def test_compile_script_half(self):
80
88
"enabled_precisions" : {torch .half }
81
89
}
82
90
83
- trt_mod = trtorch .compile (self .scripted_model , compile_spec )
91
+ trt_mod = trtorch .compile (self .scripted_model , ** compile_spec )
84
92
same = (trt_mod (self .input .half ()) - self .scripted_model (self .input .half ())).abs ().max ()
85
93
trtorch .logging .log (trtorch .logging .Level .Debug , "Max diff: " + str (same ))
86
94
self .assertTrue (same < 3e-2 )
@@ -103,7 +111,7 @@ def test_compile_script_half_by_default(self):
103
111
"enabled_precisions" : {torch .float , torch .half }
104
112
}
105
113
106
- trt_mod = trtorch .compile (self .scripted_model , compile_spec )
114
+ trt_mod = trtorch .compile (self .scripted_model , ** compile_spec )
107
115
same = (trt_mod (self .input .half ()) - self .scripted_model (self .input .half ())).abs ().max ()
108
116
trtorch .logging .log (trtorch .logging .Level .Debug , "Max diff: " + str (same ))
109
117
self .assertTrue (same < 3e-2 )
@@ -132,7 +140,7 @@ def test_compile_script(self):
132
140
}
133
141
}
134
142
135
- trt_mod = trtorch .compile (self .scripted_model , compile_spec )
143
+ trt_mod = trtorch .compile (self .scripted_model , ** compile_spec )
136
144
same = (trt_mod (self .input ) - self .scripted_model (self .input )).abs ().max ()
137
145
self .assertTrue (same < 2e-3 )
138
146
@@ -160,7 +168,7 @@ def test_compile_script(self):
160
168
}
161
169
}
162
170
163
- trt_mod = trtorch .compile (self .scripted_model , compile_spec )
171
+ trt_mod = trtorch .compile (self .scripted_model , ** compile_spec )
164
172
same = (trt_mod (self .input ) - self .scripted_model (self .input )).abs ().max ()
165
173
self .assertTrue (same < 2e-3 )
166
174
@@ -183,7 +191,7 @@ def test_pt_to_trt_to_pt(self):
183
191
}
184
192
}
185
193
186
- trt_engine = trtorch .convert_method_to_trt_engine (self .ts_model , "forward" , compile_spec )
194
+ trt_engine = trtorch .convert_method_to_trt_engine (self .ts_model , "forward" , ** compile_spec )
187
195
trt_mod = trtorch .embed_engine_in_new_module (trt_engine , trtorch .Device ("cuda:0" ))
188
196
same = (trt_mod (self .input ) - self .ts_model (self .input )).abs ().max ()
189
197
self .assertTrue (same < 2e-3 )
0 commit comments