Skip to content

Commit 3f2e559

Browse files
Bordacarmocca
authored andcommitted
simplify torch.Tensor (#16190)
1 parent 6f3e8f9 commit 3f2e559

20 files changed

+63
-47
lines changed

tests/tests_fabric/test_parity.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
from lightning_utilities.core.apply_func import apply_to_collection
2626
from tests_fabric.helpers.models import RandomDataset
2727
from tests_fabric.helpers.runif import RunIf
28-
from torch import nn
28+
from torch import nn, Tensor
2929
from torch.nn.parallel.distributed import DistributedDataParallel
3030
from torch.utils.data import DataLoader
3131
from torch.utils.data.distributed import DistributedSampler
@@ -131,7 +131,7 @@ def test_boring_lite_model_single_device(precision, strategy, devices, accelerat
131131
model.load_state_dict(state_dict)
132132
pure_state_dict = main(lite.to_device, model, train_dataloader, num_epochs=num_epochs)
133133

134-
state_dict = apply_to_collection(state_dict, torch.Tensor, lite.to_device)
134+
state_dict = apply_to_collection(state_dict, Tensor, lite.to_device)
135135
for w_pure, w_lite in zip(state_dict.values(), lite_state_dict.values()):
136136
# TODO: This should be torch.equal, but MPS does not yet support this operation (torch 1.12)
137137
assert not torch.allclose(w_pure, w_lite)

tests/tests_fabric/utilities/test_apply_func.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,15 @@
1313
# limitations under the License.
1414
import pytest
1515
import torch
16+
from torch import Tensor
1617

1718
from lightning_fabric.utilities.apply_func import move_data_to_device
1819

1920

2021
@pytest.mark.parametrize("should_return", [False, True])
2122
def test_wrongly_implemented_transferable_data_type(should_return):
2223
class TensorObject:
23-
def __init__(self, tensor: torch.Tensor, should_return: bool = True):
24+
def __init__(self, tensor: Tensor, should_return: bool = True):
2425
self.tensor = tensor
2526
self.should_return = should_return
2627

tests/tests_fabric/utilities/test_data.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import pytest
55
import torch
66
from tests_fabric.helpers.models import RandomDataset, RandomIterableDataset
7+
from torch import Tensor
78
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
89

910
from lightning_fabric.utilities.data import (
@@ -87,7 +88,7 @@ def __init__(self, attribute2, *args, **kwargs):
8788

8889

8990
class MyDataLoader(MyBaseDataLoader):
90-
def __init__(self, data: torch.Tensor, *args, **kwargs):
91+
def __init__(self, data: Tensor, *args, **kwargs):
9192
self.data = data
9293
super().__init__(range(data.size(0)), *args, **kwargs)
9394

@@ -209,7 +210,7 @@ def test_replace_dunder_methods_dataloader(cls, args, kwargs, arg_names, dataset
209210

210211
for key, value in checked_values.items():
211212
dataloader_value = getattr(dataloader, key)
212-
if isinstance(dataloader_value, torch.Tensor):
213+
if isinstance(dataloader_value, Tensor):
213214
assert dataloader_value is value
214215
else:
215216
assert dataloader_value == value
@@ -227,7 +228,7 @@ def test_replace_dunder_methods_dataloader(cls, args, kwargs, arg_names, dataset
227228

228229
for key, value in checked_values.items():
229230
dataloader_value = getattr(dataloader, key)
230-
if isinstance(dataloader_value, torch.Tensor):
231+
if isinstance(dataloader_value, Tensor):
231232
assert dataloader_value is value
232233
else:
233234
assert dataloader_value == value

tests/tests_fabric/utilities/test_optimizer.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import collections
22

33
import torch
4+
from torch import Tensor
45

56
from lightning_fabric.utilities.optimizer import _optimizer_to_device
67

@@ -22,9 +23,9 @@ def __init__(self, *args, **kwargs):
2223
def assert_opt_parameters_on_device(opt, device: str):
2324
for param in opt.state.values():
2425
# Not sure there are any global tensors in the state dict
25-
if isinstance(param, torch.Tensor):
26+
if isinstance(param, Tensor):
2627
assert param.data.device.type == device
2728
elif isinstance(param, collections.Mapping):
2829
for subparam in param.values():
29-
if isinstance(subparam, torch.Tensor):
30+
if isinstance(subparam, Tensor):
3031
assert param.data.device.type == device

tests/tests_pytorch/core/test_metric_result_integration.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import torch
2222
import torchmetrics
2323
from lightning_utilities.test.warning import no_warning_call
24+
from torch import Tensor
2425
from torch.nn import ModuleDict, ModuleList
2526
from torchmetrics import Metric, MetricCollection
2627

@@ -662,7 +663,7 @@ def test_logger_sync_dist(distributed_env, log_val):
662663
# self.log('bar', 0.5, ..., sync_dist=False)
663664
meta = _Metadata("foo", "bar")
664665
meta.sync = _Sync(_should=False)
665-
is_tensor = isinstance(log_val, torch.Tensor)
666+
is_tensor = isinstance(log_val, Tensor)
666667

667668
if not is_tensor:
668669
log_val.update(torch.tensor([0, 1]), torch.tensor([0, 0], dtype=torch.long))

tests/tests_pytorch/helpers/datasets.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from typing import Optional, Sequence, Tuple
2020

2121
import torch
22+
from torch import Tensor
2223
from torch.utils.data import Dataset
2324

2425

@@ -69,7 +70,7 @@ def __init__(
6970
data_file = self.TRAIN_FILE_NAME if self.train else self.TEST_FILE_NAME
7071
self.data, self.targets = self._try_load(os.path.join(self.cached_folder_path, data_file))
7172

72-
def __getitem__(self, idx: int) -> Tuple[torch.Tensor, int]:
73+
def __getitem__(self, idx: int) -> Tuple[Tensor, int]:
7374
img = self.data[idx].float().unsqueeze(0)
7475
target = int(self.targets[idx])
7576

@@ -125,7 +126,7 @@ def _try_load(path_data, trials: int = 30, delta: float = 1.0):
125126
return res
126127

127128
@staticmethod
128-
def normalize_tensor(tensor: torch.Tensor, mean: float = 0.0, std: float = 1.0) -> torch.Tensor:
129+
def normalize_tensor(tensor: Tensor, mean: float = 0.0, std: float = 1.0) -> Tensor:
129130
mean = torch.as_tensor(mean, dtype=tensor.dtype, device=tensor.device)
130131
std = torch.as_tensor(std, dtype=tensor.dtype, device=tensor.device)
131132
return tensor.sub(mean).div(std)
@@ -160,7 +161,7 @@ def __init__(self, root: str, num_samples: int = 100, digits: Optional[Sequence]
160161
super().__init__(root, normalize=(0.5, 1.0), **kwargs)
161162

162163
@staticmethod
163-
def _prepare_subset(full_data: torch.Tensor, full_targets: torch.Tensor, num_samples: int, digits: Sequence):
164+
def _prepare_subset(full_data: Tensor, full_targets: Tensor, num_samples: int, digits: Sequence):
164165
classes = {d: 0 for d in digits}
165166
indexes = []
166167
for idx, target in enumerate(full_targets):

tests/tests_pytorch/helpers/deterministic_model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
import torch
15-
from torch import nn
15+
from torch import nn, Tensor
1616
from torch.utils.data import DataLoader, Dataset
1717

1818
from pytorch_lightning.core.module import LightningModule
@@ -56,7 +56,7 @@ def step(self, batch, batch_idx):
5656

5757
def count_num_graphs(self, result, num_graphs=0):
5858
for k, v in result.items():
59-
if isinstance(v, torch.Tensor) and v.grad_fn is not None:
59+
if isinstance(v, Tensor) and v.grad_fn is not None:
6060
num_graphs += 1
6161
if isinstance(v, dict):
6262
num_graphs += self.count_num_graphs(v)

tests/tests_pytorch/loops/test_evaluation_loop_flow.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
"""Tests the evaluation loop."""
1515

1616
import torch
17+
from torch import Tensor
1718

1819
from pytorch_lightning import Trainer
1920
from pytorch_lightning.core.module import LightningModule
@@ -68,7 +69,7 @@ def backward(self, loss, optimizer, optimizer_idx):
6869

6970
assert len(train_step_out) == 1
7071
train_step_out = train_step_out[0][0]
71-
assert isinstance(train_step_out["loss"], torch.Tensor)
72+
assert isinstance(train_step_out["loss"], Tensor)
7273
assert train_step_out["loss"].item() == 171
7374

7475
# make sure the optimizer closure returns the correct things
@@ -129,7 +130,7 @@ def backward(self, loss, optimizer, optimizer_idx):
129130

130131
assert len(train_step_out) == 1
131132
train_step_out = train_step_out[0][0]
132-
assert isinstance(train_step_out["loss"], torch.Tensor)
133+
assert isinstance(train_step_out["loss"], Tensor)
133134
assert train_step_out["loss"].item() == 171
134135

135136
# make sure the optimizer closure returns the correct things

tests/tests_pytorch/loops/test_training_loop_flow_scalar.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
import pytest
15-
import torch
1615
from lightning_utilities.test.warning import no_warning_call
16+
from torch import Tensor
1717
from torch.utils.data import DataLoader
1818
from torch.utils.data._utils.collate import default_collate
1919

@@ -151,7 +151,7 @@ def backward(self, loss, optimizer, optimizer_idx):
151151

152152
assert len(train_step_out) == 1
153153
train_step_out = train_step_out[0][0]
154-
assert isinstance(train_step_out["loss"], torch.Tensor)
154+
assert isinstance(train_step_out["loss"], Tensor)
155155
assert train_step_out["loss"].item() == 171
156156

157157
# make sure the optimizer closure returns the correct things
@@ -172,7 +172,7 @@ def training_step(self, batch, batch_idx):
172172
return acc
173173

174174
def training_step_end(self, tr_step_output):
175-
assert isinstance(tr_step_output, torch.Tensor)
175+
assert isinstance(tr_step_output, Tensor)
176176
assert self.count_num_graphs({"loss": tr_step_output}) == 1
177177
self.training_step_end_called = True
178178
return tr_step_output
@@ -221,7 +221,7 @@ def backward(self, loss, optimizer, optimizer_idx):
221221

222222
assert len(train_step_out) == 1
223223
train_step_out = train_step_out[0][0]
224-
assert isinstance(train_step_out["loss"], torch.Tensor)
224+
assert isinstance(train_step_out["loss"], Tensor)
225225
assert train_step_out["loss"].item() == 171
226226

227227
# make sure the optimizer closure returns the correct things

tests/tests_pytorch/models/test_gpu.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,7 @@ def to(self, *args, **kwargs):
183183

184184
@RunIf(min_cuda_gpus=1)
185185
def test_non_blocking():
186-
"""Tests that non_blocking=True only gets passed on torch.Tensor.to, but not on other objects."""
186+
"""Tests that non_blocking=True only gets passed on Tensor.to, but not on other objects."""
187187
trainer = Trainer()
188188

189189
batch = torch.zeros(2, 3)

0 commit comments

Comments
 (0)