Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion timm/data/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,12 +114,16 @@ def __init__(
else:
self.random_erasing = None
self.is_cuda = torch.cuda.is_available() and device.type == 'cuda'
self.is_npu = torch.npu.is_available() and device.type == 'npu'

def __iter__(self):
first = True
if self.is_cuda:
stream = torch.cuda.Stream()
stream_context = partial(torch.cuda.stream, stream=stream)
elif self.is_npu:
stream = torch.npu.Stream()
stream_context = partial(torch.npu.stream, stream=stream)
else:
stream = None
stream_context = suppress
Expand All @@ -139,7 +143,10 @@ def __iter__(self):
first = False

if stream is not None:
torch.cuda.current_stream().wait_stream(stream)
if self.is_cuda:
torch.cuda.current_stream().wait_stream(stream)
elif self.is_npu:
torch.npu.current_stream().wait_stream(stream)

input = next_input
target = next_target
Expand Down
4 changes: 4 additions & 0 deletions timm/utils/distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,8 @@ def init_distributed_device_so(

if 'cuda' in device:
assert torch.cuda.is_available(), f'CUDA is not available but {device} was specified.'
elif 'npu' in device:
assert torch.npu.is_available(), f'NPU is not available but {device} was specified.'

if distributed and device != 'cpu':
device, *device_idx = device.split(':', maxsplit=1)
Expand All @@ -165,6 +167,8 @@ def init_distributed_device_so(

if device.startswith('cuda:'):
torch.cuda.set_device(device)
elif device.startswith('npu:'):
torch.npu.set_device(device)

return dict(
device=device,
Expand Down
4 changes: 4 additions & 0 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -1054,6 +1054,8 @@ def _backward(_loss):

if args.synchronize_step and device.type == 'cuda':
torch.cuda.synchronize()
elif args.synchronize_step and device.type == 'npu':
torch.npu.synchronize()
time_now = time.time()
update_time_m.update(time.time() - update_start_time)
update_start_time = time_now
Expand Down Expand Up @@ -1153,6 +1155,8 @@ def validate(

if device.type == 'cuda':
torch.cuda.synchronize()
elif device.type == 'npu':
torch.npu.synchronize()

losses_m.update(reduced_loss.item(), input.size(0))
top1_m.update(acc1.item(), output.size(0))
Expand Down