File tree Expand file tree Collapse file tree 1 file changed +5
-1
lines changed Expand file tree Collapse file tree 1 file changed +5
-1
lines changed Original file line number Diff line number Diff line change @@ -33,6 +33,7 @@ def fast_collate(batch):
33
33
if isinstance(batch[0][0], tuple):
34
34
# This branch 'deinterleaves' and flattens tuples of input tensors into one tensor ordered by position
35
35
# such that all tuple of position n will end up in a torch.split(tensor, batch_size) in nth position
36
+ is_np = isinstance(batch[0][0], np.ndarray)
36
37
inner_tuple_size = len(batch[0][0])
37
38
flattened_batch_size = batch_size * inner_tuple_size
38
39
targets = torch.zeros(flattened_batch_size, dtype=torch.int64)
@@ -41,7 +42,10 @@ def fast_collate(batch):
41
42
assert len(batch[i][0]) == inner_tuple_size # all input tensor tuples must be same length
42
43
for j in range(inner_tuple_size):
43
44
targets[i + j * batch_size] = batch[i][1]
44
- tensor[i + j * batch_size] += torch.from_numpy(batch[i][0][j])
45
+ if is_np:
46
+ tensor[i + j * batch_size] += torch.from_numpy(batch[i][0][j])
47
+ else:
48
+ tensor[i + j * batch_size] += batch[i][0][j]
45
49
return tensor, targets
46
50
elif isinstance(batch[0][0], np.ndarray):
47
51
targets = torch.tensor([b[1] for b in batch], dtype=torch.int64)
You can’t perform that action at this time.
0 commit comments