From 234f975787df74fa31a58661dad5c45e9eb2e33d Mon Sep 17 00:00:00 2001
From: MengqingCao
Date: Wed, 16 Oct 2024 07:13:45 +0000
Subject: [PATCH 1/2] add npu support

---
 timm/data/loader.py       | 9 ++++++++-
 timm/utils/distributed.py | 3 +++
 train.py                  | 9 +++++++--
 validate.py               | 2 ++
 4 files changed, 20 insertions(+), 3 deletions(-)

diff --git a/timm/data/loader.py b/timm/data/loader.py
index ff61ad56f6..d3300ea8f5 100644
--- a/timm/data/loader.py
+++ b/timm/data/loader.py
@@ -114,12 +114,16 @@ def __init__(
         else:
             self.random_erasing = None
         self.is_cuda = torch.cuda.is_available() and device.type == 'cuda'
+        self.is_npu = torch.npu.is_available() and device.type == 'npu'

     def __iter__(self):
         first = True
         if self.is_cuda:
             stream = torch.cuda.Stream()
             stream_context = partial(torch.cuda.stream, stream=stream)
+        elif self.is_npu:
+            stream = torch.npu.Stream()
+            stream_context = partial(torch.npu.stream, stream=stream)
         else:
             stream = None
             stream_context = suppress
@@ -139,7 +143,10 @@ def __iter__(self):
                 first = False

             if stream is not None:
-                torch.cuda.current_stream().wait_stream(stream)
+                if self.is_cuda:
+                    torch.cuda.current_stream().wait_stream(stream)
+                elif self.is_npu:
+                    torch.npu.current_stream().wait_stream(stream)

             input = next_input
             target = next_target
diff --git a/timm/utils/distributed.py b/timm/utils/distributed.py
index 18f526bb83..cca2cdbb89 100644
--- a/timm/utils/distributed.py
+++ b/timm/utils/distributed.py
@@ -116,6 +116,7 @@ def init_distributed_device_so(
         "xpu": "ccl",
         "hpu": "hccl",
         "cuda": "nccl",
+        "npu": "hccl",
     }
     dist_backend = dist_backends.get(device_type, 'gloo')
     dist_url = dist_url or 'env://'
@@ -159,6 +160,8 @@ def init_distributed_device_so(

     if device_type == 'cuda':
         assert torch.cuda.is_available(), f'CUDA is not available but {device} was specified.'
+    if device_type == 'npu':
+        assert torch.npu.is_available(), f'Ascend NPU is not available but {device} was specified.'

     if distributed and device != 'cpu':
         # Ignore manually specified device index in distributed mode and
diff --git a/train.py b/train.py
index ebd9bc80b2..63b89b58e5 100755
--- a/train.py
+++ b/train.py
@@ -1054,8 +1054,11 @@ def _backward(_loss):
         if model_ema is not None:
             model_ema.update(model, step=num_updates)

-        if args.synchronize_step and device.type == 'cuda':
-            torch.cuda.synchronize()
+        if args.synchronize_step:
+            if device.type == 'cuda':
+                torch.cuda.synchronize()
+            elif device.type == 'npu':
+                torch.npu.synchronize()
         time_now = time.time()
         update_time_m.update(time.time() - update_start_time)
         update_start_time = time_now
@@ -1155,6 +1158,8 @@ def validate(

             if device.type == 'cuda':
                 torch.cuda.synchronize()
+            elif device.type == "npu":
+                torch.npu.synchronize()

             losses_m.update(reduced_loss.item(), input.size(0))
             top1_m.update(acc1.item(), output.size(0))
diff --git a/validate.py b/validate.py
index 6115de7aeb..6623453b36 100755
--- a/validate.py
+++ b/validate.py
@@ -397,6 +397,8 @@ def _try_run(args, initial_batch_size):
         try:
             if torch.cuda.is_available() and 'cuda' in args.device:
                 torch.cuda.empty_cache()
+            elif torch.npu.is_available() and "npu" in args.device:
+                torch.npu.empty_cache()
             results = validate(args)
             return results
         except RuntimeError as e:

From 37c731ca370e26e1d888f308559ee60a68951779 Mon Sep 17 00:00:00 2001
From: MengqingCao
Date: Thu, 17 Oct 2024 12:38:02 +0000
Subject: [PATCH 2/2] fix device check

---
 timm/data/loader.py | 4 ++--
 validate.py         | 4 ++--
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/timm/data/loader.py b/timm/data/loader.py
index d3300ea8f5..3b4a6d0ed6 100644
--- a/timm/data/loader.py
+++ b/timm/data/loader.py
@@ -113,8 +113,8 @@ def __init__(
             )
         else:
             self.random_erasing = None
-        self.is_cuda = torch.cuda.is_available() and device.type == 'cuda'
-        self.is_npu = torch.npu.is_available() and device.type == 'npu'
+        self.is_cuda = device.type == 'cuda' and torch.cuda.is_available()
+        self.is_npu = device.type == 'npu' and torch.npu.is_available()

     def __iter__(self):
         first = True
diff --git a/validate.py b/validate.py
index 6623453b36..ce0e4b2541 100755
--- a/validate.py
+++ b/validate.py
@@ -395,9 +395,9 @@ def _try_run(args, initial_batch_size):
     while batch_size:
         args.batch_size = batch_size * args.num_gpu  # multiply by num-gpu for DataParallel case
         try:
-            if torch.cuda.is_available() and 'cuda' in args.device:
+            if 'cuda' in args.device and torch.cuda.is_available():
                 torch.cuda.empty_cache()
-            elif torch.npu.is_available() and "npu" in args.device:
+            elif "npu" in args.device and torch.npu.is_available():
                 torch.npu.empty_cache()
             results = validate(args)
             return results
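
Note: every hunk above follows the same dispatch pattern: check the device type (or the device string) first, then the backend-specific availability, and only then call into torch.cuda.* or torch.npu.*; PATCH 2/2 exists purely to get that ordering right so hosts without the NPU stack never evaluate torch.npu. Below is a minimal standalone sketch of that pattern, not part of the patch. The helper names are made up for illustration, and it assumes the Ascend PyTorch adapter (torch_npu) is installed and registers the torch.npu namespace on import, which is what the patched code relies on.

# Sketch only, not part of the patch; helper names are illustrative.
import torch

try:
    # Ascend adapter; importing it is assumed to register the torch.npu namespace.
    import torch_npu  # noqa: F401
except ImportError:
    torch_npu = None


def device_synchronize(device: torch.device) -> None:
    # Same dispatch as the train.py hunks: only touch the backend actually in use.
    if device.type == 'cuda':
        torch.cuda.synchronize()
    elif device.type == 'npu':
        torch.npu.synchronize()


def device_empty_cache(device: torch.device) -> None:
    # Same ordering as the PATCH 2/2 fix: test the requested device type before
    # probing availability, so torch.npu is never touched unless an NPU was asked for.
    if device.type == 'cuda' and torch.cuda.is_available():
        torch.cuda.empty_cache()
    elif device.type == 'npu' and torch.npu.is_available():
        torch.npu.empty_cache()


if __name__ == '__main__':
    if torch_npu is not None and torch.npu.is_available():
        dev = torch.device('npu')
    elif torch.cuda.is_available():
        dev = torch.device('cuda')
    else:
        dev = torch.device('cpu')
    device_empty_cache(dev)
    device_synchronize(dev)  # no-op on cpu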