@@ -34,7 +34,8 @@ std::unordered_map<const torch::jit::Value*, torch::jit::IValue> generateRandomI
34
34
35
35
void getSegmentsOutputByRunning (
36
36
SegmentedBlock& seg_block,
37
- std::unordered_map<const torch::jit::Value*, torch::jit::IValue>& ivalues_maps) {
37
+ std::unordered_map<const torch::jit::Value*, torch::jit::IValue>& ivalues_maps,
38
+ const PartitionInfo& partition_info) {
38
39
// create a module to run the graph
39
40
auto g = seg_block.g ();
40
41
auto copy_g = g->copy ();
@@ -108,7 +109,28 @@ void getSegmentsOutputByRunning(
108
109
std::vector<at::ScalarType> input_types;
109
110
for (auto & i : seg_block.raw_inputs ()) {
110
111
if (ivalues_maps[i].isTensor ()) {
111
- input_shapes.push_back (util::toVec (util::toDims (ivalues_maps[i].toTensor ().sizes ())));
112
+ // set the input_shape and data_type
113
+ at::ScalarType t = ivalues_maps[i].toTensor ().scalar_type ();
114
+ if (!partition_info.truncate_long_and_double && (t == at::kLong || t == at::kDouble )) {
115
+ TORCHTRT_THROW_ERROR (
116
+ " Unable to process subgraph input type of at::kLong/at::kDouble, try to compile model with truncate_long_and_double enabled" );
117
+ } else if (partition_info.truncate_long_and_double && t == at::kLong ) {
118
+ ivalues_maps[i] = ivalues_maps[i].toTensor ().to (at::kInt );
119
+ LOG_WARNING (" Truncating graph input type from at::kLong to at::kInt" );
120
+ } else if (partition_info.truncate_long_and_double && t == at::kDouble ) {
121
+ ivalues_maps[i] = ivalues_maps[i].toTensor ().to (at::kFloat );
122
+ LOG_WARNING (" Truncating graph input type from at::kDouble to at::kFloat" );
123
+ }
124
+ c10::optional<nvinfer1::DataType> dtype = util::optTypeMetaToTRTDataType (ivalues_maps[i].toTensor ().dtype ());
125
+ if (dtype == c10::nullopt ) {
126
+ TORCHTRT_THROW_ERROR (" Unsupported input data type " << ivalues_maps[i].toTensor ().dtype ());
127
+ }
128
+ if (ivalues_maps[i].toTensor ().sizes ().size () == 0 ) {
129
+ // handle scalar (0-dim) tensors, whose sizes() is empty ([])
130
+ input_shapes.push_back (util::toVec (util::toDims (c10::List<long int >({1 }))));
131
+ } else {
132
+ input_shapes.push_back (util::toVec (util::toDims (ivalues_maps[i].toTensor ().sizes ())));
133
+ }
112
134
input_types.push_back (ivalues_maps[i].toTensor ().scalar_type ());
113
135
}
114
136
}
@@ -119,11 +141,12 @@ void getSegmentsOutputByRunning(
119
141
120
142
// Drives shape analysis over every segmented block: for each block it applies
// torch::jit::ConstantPooling to the block's graph and then calls
// getSegmentsOutputByRunning with the shared example-tensor map.
//
// segmented_blocks   - blocks to analyze; iterated in order by mutable reference.
// example_tensor_map - Value* -> IValue map passed by mutable reference to
//                      getSegmentsOutputByRunning for every block.
// partition_info     - partitioning settings, forwarded unchanged to
//                      getSegmentsOutputByRunning (parameter added in this change).
//
// NOTE(review): this span is a rendered diff, not plain source. Bare numbers are
// old/new line numbers; lines starting with "-" are the removed revision and
// lines starting with "+" are the added revision.
void runShapeAnalysis (
121
143
std::vector<SegmentedBlock>& segmented_blocks,
122
- std::unordered_map<const torch::jit::Value*, torch::jit::IValue>& example_tensor_map) {
144
+ std::unordered_map<const torch::jit::Value*, torch::jit::IValue>& example_tensor_map,
145
+ const PartitionInfo& partition_info) {
123
146
// register every segment's input shapes and the output IValues produced by running it
124
147
for (auto & seg_block : segmented_blocks) {
125
148
// pool constants in the segment's graph before the segment is run
torch::jit::ConstantPooling (seg_block.g ());
126
- getSegmentsOutputByRunning (seg_block, example_tensor_map);
149
// added revision: also forwards partition_info to the per-segment run
+ getSegmentsOutputByRunning (seg_block, example_tensor_map, partition_info );
127
150
}
128
151
return ;
129
152
}
0 commit comments