@@ -18,6 +18,7 @@
 from typing import Any, Optional, TypedDict, Union

 import numpy as np
+import torch

 from .image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
 from .image_transforms import (
@@ -44,7 +45,6 @@
 from .utils import (
     TensorType,
     auto_docstring,
-    is_torch_available,
     is_torchvision_available,
     is_torchvision_v2_available,
     is_vision_available,
@@ -56,8 +56,6 @@
 if is_vision_available():
     from .image_utils import PILImageResampling

-if is_torch_available():
-    import torch

 if is_torchvision_available():
     from .image_utils import pil_torch_interpolation_mapping
@@ -115,7 +113,7 @@ def validate_fast_preprocess_arguments(
         raise ValueError("Only channel first data format is currently supported.")


-def safe_squeeze(tensor: "torch.Tensor", axis: Optional[int] = None) -> "torch.Tensor":
+def safe_squeeze(tensor: torch.Tensor, axis: Optional[int] = None) -> torch.Tensor:
     """
     Squeezes a tensor, but only if the axis specified has dim 1.
     """
@@ -135,7 +133,7 @@ def max_across_indices(values: Iterable[Any]) -> list[Any]:
     return [max(values_i) for values_i in zip(*values)]


-def get_max_height_width(images: list["torch.Tensor"]) -> tuple[int]:
+def get_max_height_width(images: list[torch.Tensor]) -> tuple[int]:
     """
     Get the maximum height and width across all images in a batch.
     """
@@ -145,19 +143,17 @@ def get_max_height_width(images: list["torch.Tensor"]) -> tuple[int]:
     return (max_height, max_width)


-def divide_to_patches(
-    image: Union[np.array, "torch.Tensor"], patch_size: int
-) -> list[Union[np.array, "torch.Tensor"]]:
+def divide_to_patches(image: Union[np.array, torch.Tensor], patch_size: int) -> list[Union[np.array, torch.Tensor]]:
     """
     Divides an image into patches of a specified size.

     Args:
-        image (`Union[np.array, "torch.Tensor"]`):
+        image (`Union[np.array, torch.Tensor]`):
             The input image.
         patch_size (`int`):
             The size of each patch.
     Returns:
-        list: A list of Union[np.array, "torch.Tensor"] representing the patches.
+        list: A list of Union[np.array, torch.Tensor] representing the patches.
     """
     patches = []
     height, width = get_image_size(image, channel_dim=ChannelDimension.FIRST)
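A sketch of the patching contract (assuming standard non-overlapping, row-major patching, which the loop below the shown lines presumably implements):

    image = torch.zeros(3, 224, 224)
    patches = divide_to_patches(image, patch_size=112)
    assert len(patches) == 4                   # a 2x2 grid of patches
    assert patches[0].shape == (3, 112, 112)   # each patch keeps the channel dim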
@@ -241,12 +237,12 @@ def is_fast(self) -> bool:

     def resize(
         self,
-        image: "torch.Tensor",
+        image: torch.Tensor,
         size: SizeDict,
         interpolation: Optional["F.InterpolationMode"] = None,
         antialias: bool = True,
         **kwargs,
-    ) -> "torch.Tensor":
+    ) -> torch.Tensor:
         """
         Resize an image to `(size["height"], size["width"])`.

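Hedged usage sketch of the signature above; `processor` is a hypothetical fast image processor instance, and `SizeDict` is assumed to accept height/width keywords:

    resized = processor.resize(torch.rand(3, 480, 640), size=SizeDict(height=224, width=224))
    assert resized.shape == (3, 224, 224)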
@@ -295,11 +291,11 @@ def resize(

     @staticmethod
     def compile_friendly_resize(
-        image: "torch.Tensor",
+        image: torch.Tensor,
         new_size: tuple[int, int],
         interpolation: Optional["F.InterpolationMode"] = None,
         antialias: bool = True,
-    ) -> "torch.Tensor":
+    ) -> torch.Tensor:
         """
         A wrapper around `F.resize` so that it is compatible with torch.compile when the image is a uint8 tensor.
         """
@@ -316,10 +312,10 @@ def compile_friendly_resize(

     def rescale(
         self,
-        image: "torch.Tensor",
+        image: torch.Tensor,
         scale: float,
         **kwargs,
-    ) -> "torch.Tensor":
+    ) -> torch.Tensor:
         """
         Rescale an image by a scale factor. image = image * scale.

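A worked instance of the documented formula `image = image * scale`: with the conventional rescale factor of 1/255, uint8 pixel values land in [0, 1]:

    pixels = torch.tensor([0.0, 127.0, 255.0])
    assert torch.allclose(pixels * (1 / 255), torch.tensor([0.0, 127 / 255, 1.0]))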
@@ -336,11 +332,11 @@ def rescale(

     def normalize(
         self,
-        image: "torch.Tensor",
+        image: torch.Tensor,
         mean: Union[float, Iterable[float]],
         std: Union[float, Iterable[float]],
         **kwargs,
-    ) -> "torch.Tensor":
+    ) -> torch.Tensor:
         """
         Normalize an image. image = (image - image_mean) / image_std.

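A worked instance of `image = (image - image_mean) / image_std`: a value equal to the mean maps to 0, and a value one standard deviation above it maps to 1:

    x = torch.tensor([0.5, 0.726])
    out = (x - 0.5) / 0.226
    assert torch.allclose(out, torch.tensor([0.0, 1.0]))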
@@ -376,13 +372,13 @@ def _fuse_mean_std_and_rescale_factor(

     def rescale_and_normalize(
         self,
-        images: "torch.Tensor",
+        images: torch.Tensor,
         do_rescale: bool,
         rescale_factor: float,
         do_normalize: bool,
         image_mean: Union[float, list[float]],
         image_std: Union[float, list[float]],
-    ) -> "torch.Tensor":
+    ) -> torch.Tensor:
         """
         Rescale and normalize images.
         """
@@ -404,16 +400,16 @@ def rescale_and_normalize(

     def center_crop(
         self,
-        image: "torch.Tensor",
+        image: torch.Tensor,
         size: dict[str, int],
         **kwargs,
-    ) -> "torch.Tensor":
+    ) -> torch.Tensor:
         """
         Center crop an image to `(size["height"], size["width"])`. If the input size is smaller than `crop_size` along
         any edge, the image is padded with 0's and then center cropped.

         Args:
-            image (`"torch.Tensor"`):
+            image (`torch.Tensor`):
                 Image to center crop.
             size (`dict[str, int]`):
                 Size of the output image.
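An illustration of the documented pad-then-crop behavior for undersized inputs (`processor` again a hypothetical instance):

    small = torch.ones(3, 100, 100)
    out = processor.center_crop(small, size={"height": 128, "width": 128})
    assert out.shape == (3, 128, 128)
    assert out[:, 0, 0].sum() == 0  # zero-padded corner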
@@ -479,7 +475,7 @@ def _process_image(
         do_convert_rgb: Optional[bool] = None,
         input_data_format: Optional[Union[str, ChannelDimension]] = None,
         device: Optional["torch.device"] = None,
-    ) -> "torch.Tensor":
+    ) -> torch.Tensor:
         image_type = get_image_type(image)
         if image_type not in [ImageType.PIL, ImageType.TORCH, ImageType.NUMPY]:
             raise ValueError(f"Unsupported input image type {image_type}")
@@ -518,7 +514,7 @@ def _prepare_image_like_inputs(
         input_data_format: Optional[Union[str, ChannelDimension]] = None,
         device: Optional["torch.device"] = None,
         expected_ndims: int = 3,
-    ) -> list["torch.Tensor"]:
+    ) -> list[torch.Tensor]:
         """
         Prepare image-like inputs for processing.

@@ -685,7 +681,7 @@ def _preprocess_image_like_inputs(

     def _preprocess(
         self,
-        images: list["torch.Tensor"],
+        images: list[torch.Tensor],
         do_resize: bool,
         size: SizeDict,
         interpolation: Optional["F.InterpolationMode"],