Commit 512ec4f

edpizzi, Borda, rohitgr7, justusschock, and awaelchli committed
Avoid non-blocking GPU->CPU copies. (#11288)
Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: Rohit Gupta <[email protected]>
Co-authored-by: Justus Schock <[email protected]>
Co-authored-by: Adrian Wälchli <[email protected]>
1 parent 2aeb339 commit 512ec4f

2 files changed: 10 additions, 1 deletion

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
@@ -385,6 +385,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed data fetcher selection ([#11294](https://github.com/PyTorchLightning/pytorch-lightning/pull/11294))
 
 
+- Fixed a race condition that could result in incorrect (zero) values being observed in prediction writer callbacks ([#11288](https://github.com/PyTorchLightning/pytorch-lightning/pull/11288))
+
+
 ## [1.5.7] - 2021-12-21
 
 ### Fixed
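
The race described in the changelog entry can be sketched as follows. This is a minimal illustration, not code from the commit: `copy_to_cpu_synced` is a hypothetical helper, and the assumption is that a GPU->CPU copy issued with `non_blocking=True` may return before the data has landed in host memory, so a reader can observe stale (for example zero) values unless it synchronizes first.

```python
import torch


def copy_to_cpu_synced(tensor: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: copy a tensor to the CPU and wait for the copy.

    With non_blocking=True, a GPU->CPU copy may be enqueued asynchronously,
    so the returned CPU tensor is only safe to read after synchronization.
    """
    cpu_tensor = tensor.to("cpu", non_blocking=True)
    if tensor.is_cuda:
        # Block until all queued CUDA work, including the copy, has finished.
        torch.cuda.synchronize()
    return cpu_tensor


if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    source = torch.arange(4.0, device=device)
    print(copy_to_cpu_synced(source))
```

The commit takes the simpler route of never requesting a non-blocking transfer when the target is the CPU, which removes the need for an explicit synchronization in callers.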

pytorch_lightning/utilities/apply_func.py

Lines changed: 7 additions & 1 deletion
@@ -35,6 +35,9 @@
     Batch = type(None)
 
 
+_CPU_DEVICES = ("cpu", torch.device("cpu"))
+
+
 def to_dtype_tensor(
     value: Union[int, float, List[Union[int, float]]], dtype: torch.dtype, device: Union[str, torch.device]
 ) -> torch.Tensor:
@@ -274,7 +277,10 @@ def batch_to(data: Any) -> Any:
                 setattr(device_data, field, device_field)
             return device_data
 
-        kwargs = dict(non_blocking=True) if isinstance(data, torch.Tensor) else {}
+        kwargs = {}
+        # Don't issue non-blocking transfers to CPU
+        if isinstance(data, torch.Tensor) and device not in _CPU_DEVICES:
+            kwargs["non_blocking"] = True
         data_output = data.to(device, **kwargs)
         if data_output is not None:
             return data_output
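
The decision made in `batch_to` can be isolated as a small standalone function. The sketch below is an illustration under stated assumptions, not the commit's code: `transfer_kwargs` is a hypothetical name, and instead of membership in `_CPU_DEVICES` it normalizes the target through `torch.device(...)` and checks its `type`, a variant that also treats spellings such as `"cpu:0"` as CPU.

```python
from typing import Any, Dict, Union

import torch


def transfer_kwargs(data: Any, device: Union[str, torch.device]) -> Dict[str, bool]:
    """Hypothetical helper: choose keyword arguments for ``data.to(device, ...)``.

    Non-blocking transfers are requested only for tensors whose target is
    not the CPU, mirroring the intent of the change in ``batch_to``.
    """
    target = torch.device(device)  # accepts both str and torch.device
    if isinstance(data, torch.Tensor) and target.type != "cpu":
        return {"non_blocking": True}
    return {}


if __name__ == "__main__":
    tensor = torch.zeros(2)
    # Constructing a torch.device object does not require the hardware to exist.
    print(transfer_kwargs(tensor, "cpu"))
    print(transfer_kwargs(tensor, torch.device("cuda", 0)))
    print(transfer_kwargs([1, 2], "cuda:0"))
```

Whether to compare against a fixed tuple, as the commit does, or to normalize the device first is a design choice; the tuple keeps the change minimal, while normalization covers more device spellings.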