
Commit d0a30b9

Assume torch in certain files

Signed-off-by: Yuanyuan Chen <[email protected]>

1 parent bb45d36 · commit d0a30b9

6 files changed: +36 -53 lines changed
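The recurring pattern across all six files, as a minimal illustrative sketch (not the literal diff; the real guard helper lives in transformers.utils):

    # Before: torch imported behind an availability guard, forcing quoted
    # "forward reference" annotations elsewhere in the module.
    from transformers.utils import is_torch_available

    if is_torch_available():
        import torch

    def resize_before(t: "torch.Tensor") -> "torch.Tensor":
        return t

    # After: torch is a hard dependency of the module, so it is imported
    # directly (failing fast when missing) and annotations can name
    # torch.Tensor unquoted.
    import torch

    def resize_after(t: torch.Tensor) -> torch.Tensor:
        return t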


src/transformers/debug_utils.py

Lines changed: 2 additions & 4 deletions

@@ -14,11 +14,9 @@
 
 import collections
 
-from .utils import ExplicitEnum, is_torch_available, logging
+import torch
 
-
-if is_torch_available():
-    import torch
+from .utils import ExplicitEnum, logging
 
 
 logger = logging.get_logger(__name__)
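For context, guards like is_torch_available are generally thin wrappers around an importability check. A minimal sketch, assuming a find_spec-based implementation (the real helper in transformers.utils also caches its result and can check versions):

    import importlib.util

    def is_torch_available_sketch() -> bool:
        # True when the torch package can be imported in this environment.
        return importlib.util.find_spec("torch") is not None

With the guard gone, an environment without torch now fails at module import time with an ordinary ModuleNotFoundError instead of deferring the failure.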

src/transformers/image_processing_utils_fast.py

Lines changed: 22 additions & 26 deletions

@@ -18,6 +18,7 @@
 from typing import Any, Optional, TypedDict, Union
 
 import numpy as np
+import torch
 
 from .image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
 from .image_transforms import (
@@ -44,7 +45,6 @@
 from .utils import (
     TensorType,
     auto_docstring,
-    is_torch_available,
     is_torchvision_available,
     is_torchvision_v2_available,
     is_vision_available,
@@ -56,8 +56,6 @@
 if is_vision_available():
     from .image_utils import PILImageResampling
 
-if is_torch_available():
-    import torch
 
 if is_torchvision_available():
     from .image_utils import pil_torch_interpolation_mapping
@@ -115,7 +113,7 @@ def validate_fast_preprocess_arguments(
         raise ValueError("Only channel first data format is currently supported.")
 
 
-def safe_squeeze(tensor: "torch.Tensor", axis: Optional[int] = None) -> "torch.Tensor":
+def safe_squeeze(tensor: torch.Tensor, axis: Optional[int] = None) -> torch.Tensor:
     """
     Squeezes a tensor, but only if the axis specified has dim 1.
     """
@@ -135,7 +133,7 @@ def max_across_indices(values: Iterable[Any]) -> list[Any]:
     return [max(values_i) for values_i in zip(*values)]
 
 
-def get_max_height_width(images: list["torch.Tensor"]) -> tuple[int]:
+def get_max_height_width(images: list[torch.Tensor]) -> tuple[int]:
     """
     Get the maximum height and width across all images in a batch.
     """
@@ -145,19 +143,17 @@ def get_max_height_width(images: list["torch.Tensor"]) -> tuple[int]:
     return (max_height, max_width)
 
 
-def divide_to_patches(
-    image: Union[np.array, "torch.Tensor"], patch_size: int
-) -> list[Union[np.array, "torch.Tensor"]]:
+def divide_to_patches(image: Union[np.array, torch.Tensor], patch_size: int) -> list[Union[np.array, torch.Tensor]]:
     """
     Divides an image into patches of a specified size.
 
     Args:
-        image (`Union[np.array, "torch.Tensor"]`):
+        image (`Union[np.array, torch.Tensor]`):
             The input image.
         patch_size (`int`):
             The size of each patch.
     Returns:
-        list: A list of Union[np.array, "torch.Tensor"] representing the patches.
+        list: A list of Union[np.array, torch.Tensor] representing the patches.
     """
     patches = []
     height, width = get_image_size(image, channel_dim=ChannelDimension.FIRST)
@@ -241,12 +237,12 @@ def is_fast(self) -> bool:
 
     def resize(
         self,
-        image: "torch.Tensor",
+        image: torch.Tensor,
         size: SizeDict,
         interpolation: Optional["F.InterpolationMode"] = None,
         antialias: bool = True,
         **kwargs,
-    ) -> "torch.Tensor":
+    ) -> torch.Tensor:
         """
         Resize an image to `(size["height"], size["width"])`.
 
@@ -295,11 +291,11 @@ def resize(
 
     @staticmethod
     def compile_friendly_resize(
-        image: "torch.Tensor",
+        image: torch.Tensor,
         new_size: tuple[int, int],
         interpolation: Optional["F.InterpolationMode"] = None,
         antialias: bool = True,
-    ) -> "torch.Tensor":
+    ) -> torch.Tensor:
         """
         A wrapper around `F.resize` so that it is compatible with torch.compile when the image is a uint8 tensor.
         """
@@ -316,10 +312,10 @@
 
     def rescale(
         self,
-        image: "torch.Tensor",
+        image: torch.Tensor,
         scale: float,
         **kwargs,
-    ) -> "torch.Tensor":
+    ) -> torch.Tensor:
         """
         Rescale an image by a scale factor. image = image * scale.
 
@@ -336,11 +332,11 @@
 
     def normalize(
         self,
-        image: "torch.Tensor",
+        image: torch.Tensor,
         mean: Union[float, Iterable[float]],
         std: Union[float, Iterable[float]],
         **kwargs,
-    ) -> "torch.Tensor":
+    ) -> torch.Tensor:
         """
         Normalize an image. image = (image - image_mean) / image_std.
 
@@ -376,13 +372,13 @@ def _fuse_mean_std_and_rescale_factor(
 
     def rescale_and_normalize(
         self,
-        images: "torch.Tensor",
+        images: torch.Tensor,
         do_rescale: bool,
         rescale_factor: float,
         do_normalize: bool,
         image_mean: Union[float, list[float]],
         image_std: Union[float, list[float]],
-    ) -> "torch.Tensor":
+    ) -> torch.Tensor:
         """
         Rescale and normalize images.
         """
@@ -404,16 +400,16 @@
 
     def center_crop(
         self,
-        image: "torch.Tensor",
+        image: torch.Tensor,
         size: dict[str, int],
         **kwargs,
-    ) -> "torch.Tensor":
+    ) -> torch.Tensor:
         """
         Center crop an image to `(size["height"], size["width"])`. If the input size is smaller than `crop_size` along
         any edge, the image is padded with 0's and then center cropped.
 
         Args:
-            image (`"torch.Tensor"`):
+            image (`torch.Tensor`):
                 Image to center crop.
             size (`dict[str, int]`):
                 Size of the output image.
@@ -479,7 +475,7 @@ def _process_image(
         do_convert_rgb: Optional[bool] = None,
         input_data_format: Optional[Union[str, ChannelDimension]] = None,
         device: Optional["torch.device"] = None,
-    ) -> "torch.Tensor":
+    ) -> torch.Tensor:
         image_type = get_image_type(image)
         if image_type not in [ImageType.PIL, ImageType.TORCH, ImageType.NUMPY]:
             raise ValueError(f"Unsupported input image type {image_type}")
@@ -518,7 +514,7 @@ def _prepare_image_like_inputs(
         input_data_format: Optional[Union[str, ChannelDimension]] = None,
         device: Optional["torch.device"] = None,
         expected_ndims: int = 3,
-    ) -> list["torch.Tensor"]:
+    ) -> list[torch.Tensor]:
         """
         Prepare image-like inputs for processing.
 
@@ -685,7 +681,7 @@ def _preprocess_image_like_inputs(
 
     def _preprocess(
        self,
-        images: list["torch.Tensor"],
+        images: list[torch.Tensor],
         do_resize: bool,
         size: SizeDict,
         interpolation: Optional["F.InterpolationMode"],
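Most of the churn in this file is dropping the quotes around torch.Tensor. A minimal sketch of why the quotes were needed before, in a hypothetical standalone module:

    from typing import Optional

    try:
        import torch
    except ImportError:
        torch = None  # quoted annotations below are never evaluated, so this is safe

    # A quoted annotation is just a string: it does not touch the torch name
    # at definition time, so the module imports even when torch is absent.
    def safe_squeeze_sketch(t: "torch.Tensor", axis: Optional[int] = None) -> "torch.Tensor":
        if axis is not None and t.shape[axis] == 1:
            return t.squeeze(axis)
        return t

Once torch is imported unconditionally, the unquoted torch.Tensor annotation always resolves, and type checkers and IDEs pick it up more reliably.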

src/transformers/integrations/higgs.py

Lines changed: 3 additions & 6 deletions

@@ -16,18 +16,15 @@
 from math import sqrt
 from typing import Optional
 
+import torch
+from torch import nn
+
 from ..utils import (
     is_flute_available,
     is_hadamard_available,
-    is_torch_available,
 )
 
 
-if is_torch_available():
-    import torch
-    from torch import nn
-
-
 if is_flute_available():
     from flute.integrations.higgs import prepare_data_transposed
     from flute.tune import TuneMetaData, qgemm_v2
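The same split appears here: hard dependencies move to direct imports while genuinely optional backends (flute, hadamard) stay guarded. A minimal sketch of that layering, using importlib.util.find_spec as an assumed stand-in for the repo's is_flute_available helper:

    import importlib.util

    import torch  # hard requirement: fail fast if missing
    from torch import nn

    if importlib.util.find_spec("flute") is not None:
        import flute  # optional quantization backend, used only when installed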

src/transformers/model_debugging_utils.py

Lines changed: 7 additions & 10 deletions

@@ -21,20 +21,17 @@
 from io import StringIO
 from typing import Optional
 
-from .utils.import_utils import is_torch_available, requires
+import torch
+from safetensors.torch import save_file
 
+from .utils.import_utils import requires
 
-if is_torch_available():
-    import torch
-    from safetensors.torch import save_file
 
-    # Note to code inspectors: this toolbox is intended for people who add models to `transformers`.
-    if torch.distributed.is_available():
-        import torch.distributed.tensor
+# Note to code inspectors: this toolbox is intended for people who add models to `transformers`.
+if torch.distributed.is_available():
+    import torch.distributed.tensor
 
-        _torch_distributed_available = True
-    else:
-        _torch_distributed_available = False
+    _torch_distributed_available = True
 from .utils import logging
 
 
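Note that torch.distributed.is_available() reports whether this torch build ships the distributed package at all; it does not mean a process group is running (that is torch.distributed.is_initialized()). A minimal sketch of the distinction:

    import torch

    if torch.distributed.is_available():
        import torch.distributed.tensor  # DTensor support exists in this build

        # Still False until init_process_group() is called somewhere.
        print("initialized:", torch.distributed.is_initialized())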

src/transformers/trainer_pt_utils.py

Lines changed: 1 addition & 3 deletions

@@ -42,7 +42,6 @@
 from .tokenization_utils_base import BatchEncoding
 from .utils import (
     is_sagemaker_mp_enabled,
-    is_torch_available,
     is_torch_xla_available,
     is_training_run_on_sagemaker,
     logging,
@@ -55,8 +54,7 @@
 if is_torch_xla_available():
     import torch_xla.runtime as xr
 
-if is_torch_available():
-    from torch.optim.lr_scheduler import LRScheduler
+from torch.optim.lr_scheduler import LRScheduler
 
 
 logger = logging.get_logger(__name__)
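LRScheduler has been the public name of the scheduler base class since PyTorch 2.0 (older releases exposed it as _LRScheduler), so the unconditional import is safe on the torch versions transformers supports. A quick usage sketch:

    import torch
    from torch.optim import SGD
    from torch.optim.lr_scheduler import LRScheduler, StepLR

    opt = SGD([torch.zeros(1, requires_grad=True)], lr=0.1)
    sched = StepLR(opt, step_size=10)
    print(isinstance(sched, LRScheduler))  # True: built-in schedulers subclass it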

src/transformers/video_processing_utils.py

Lines changed: 1 addition & 4 deletions

@@ -21,6 +21,7 @@
 from typing import Any, Callable, Optional, Union
 
 import numpy as np
+import torch
 
 from .dynamic_module_utils import custom_object_save
 from .image_processing_utils import (
@@ -44,7 +45,6 @@
     download_url,
     is_offline_mode,
     is_remote_url,
-    is_torch_available,
     is_torchcodec_available,
     is_torchvision_available,
     is_torchvision_v2_available,
@@ -65,9 +65,6 @@
 )
 
 
-if is_torch_available():
-    import torch
-
 if is_torchvision_available():
     if is_torchvision_v2_available():
         from torchvision.transforms.v2 import functional as F
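Only torch becomes unconditional here; torchvision remains optional, so its guards stay. A minimal sketch of the v2-with-fallback import this file keeps, using try/except as an assumed stand-in for the repo's is_torchvision_v2_available helper:

    try:
        # Prefer the v2 functional API (torchvision >= 0.15) when present.
        from torchvision.transforms.v2 import functional as F
    except ImportError:
        from torchvision.transforms import functional as F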
