4
4
import pytest
5
5
import torch
6
6
from tests_fabric .helpers .models import RandomDataset , RandomIterableDataset
7
+ from torch import Tensor
7
8
from torch .utils .data import BatchSampler , DataLoader , RandomSampler , SequentialSampler
8
9
9
10
from lightning_fabric .utilities .data import (
@@ -87,7 +88,7 @@ def __init__(self, attribute2, *args, **kwargs):
87
88
88
89
89
90
class MyDataLoader (MyBaseDataLoader ):
90
- def __init__ (self , data : torch . Tensor , * args , ** kwargs ):
91
+ def __init__ (self , data : Tensor , * args , ** kwargs ):
91
92
self .data = data
92
93
super ().__init__ (range (data .size (0 )), * args , ** kwargs )
93
94
@@ -209,7 +210,7 @@ def test_replace_dunder_methods_dataloader(cls, args, kwargs, arg_names, dataset
209
210
210
211
for key , value in checked_values .items ():
211
212
dataloader_value = getattr (dataloader , key )
212
- if isinstance (dataloader_value , torch . Tensor ):
213
+ if isinstance (dataloader_value , Tensor ):
213
214
assert dataloader_value is value
214
215
else :
215
216
assert dataloader_value == value
@@ -227,7 +228,7 @@ def test_replace_dunder_methods_dataloader(cls, args, kwargs, arg_names, dataset
227
228
228
229
for key , value in checked_values .items ():
229
230
dataloader_value = getattr (dataloader , key )
230
- if isinstance (dataloader_value , torch . Tensor ):
231
+ if isinstance (dataloader_value , Tensor ):
231
232
assert dataloader_value is value
232
233
else :
233
234
assert dataloader_value == value
0 commit comments