
Commit f0369e0

load npu accelerating module
1 parent 801727d commit f0369e0

9 files changed: +53 -33 lines changed


benchmark.py

Lines changed: 0 additions & 3 deletions
@@ -24,9 +24,6 @@
 from timm.optim import create_optimizer_v2
 from timm.utils import setup_default_logging, set_jit_fuser, decay_batch_step, check_batch_size_retry, ParseKwargs,\
     reparameterize_model
-from timm.utils.distributed import is_torch_npu_available
-
-has_torch_npu = is_torch_npu_available()
 
 has_apex = False
 try:

inference.py

Lines changed: 0 additions & 3 deletions
@@ -21,9 +21,6 @@
 from timm.layers import apply_test_time_pool
 from timm.models import create_model
 from timm.utils import AverageMeter, setup_default_logging, set_jit_fuser, ParseKwargs
-from timm.utils.distributed import is_torch_npu_available
-
-has_torch_npu = is_torch_npu_available()
 
 
 try:

timm/__init__.py

Lines changed: 38 additions & 0 deletions
@@ -2,3 +2,41 @@
 from .layers import is_scriptable, is_exportable, set_scriptable, set_exportable
 from .models import create_model, list_models, list_pretrained, is_model, list_modules, model_entrypoint, \
     is_model_pretrained, get_pretrained_cfg, get_pretrained_cfg_value
+import os
+import json
+import importlib
+import sys
+
+# import device specific accelerator module
+device_extension_info = open("./timm/device_extension.json", 'r')
+device_extension_info = json.load(device_extension_info)
+
+os_var = ""
+for device_key in device_extension_info.keys():
+    os_var_modules = device_extension_info[device_key]
+    os_var += device_key + ':'
+
+    for module in os_var_modules:
+        os_var += module + ':'
+    os_var = os_var[:-1]
+    os_var += ','
+
+os.environ["DEVICE_EXT"] = os_var[:-1]
+
+if os.getenv('DEVICE_EXT'):
+    this_module = sys.modules[__name__]
+    backends = os.getenv('DEVICE_EXT').split(',')
+    for backend in backends:
+        module_info = backend.split(':')
+        module_name = module_info[1].strip()
+        module_alias = list()
+        if len(module_info) > 2:
+            for i in range(2, len(module_info)):
+                module_alias.append(module_info[i].strip())
+        try:
+            extra_module = importlib.import_module(module_name)
+            for alia in module_alias:
+                setattr(this_module, alia, extra_module)
+            print(module_alias)
+        except ImportError:
+            pass
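The loader above serializes device_extension.json into a single DEVICE_EXT environment string and immediately parses it back. As a reading aid, here is a minimal, self-contained sketch of that round trip; the values mirror the shipped config, and the comment about aliases describes the device_key:module[:alias...] format the parser accepts rather than anything configured in this commit.

# Sketch of the DEVICE_EXT round trip performed in timm/__init__.py.
# The config literal below mirrors the shipped device_extension.json.
import json

device_extension_info = json.loads('{"ascend_npu_modules": ["torch_npu"]}')

# Build one comma-separated "device_key:module1:module2" entry per device key.
os_var = ""
for device_key, modules in device_extension_info.items():
    os_var += device_key + ':' + ':'.join(modules) + ','
os_var = os_var[:-1]
print(os_var)  # ascend_npu_modules:torch_npu

# Parse an entry back: index 0 is the device key, index 1 the module to import,
# and any further fields (none in the shipped file) become attribute aliases on timm.
module_info = os_var.split(',')[0].split(':')
print(module_info[1])   # torch_npu
print(module_info[2:])  # []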

timm/device_extension.json

Lines changed: 3 additions & 0 deletions
@@ -0,0 +1,3 @@
+{
+    "ascend_npu_modules": ["torch_npu"]
+}
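Because the loader runs at import time and opens ./timm/device_extension.json with a relative path, the net effect is that importing timm from the repository root also imports torch_npu when it is installed, and silently skips it otherwise. A small sketch for verifying this on a given machine (only the standard library and timm are assumed):

# Sketch: check whether the DEVICE_EXT loader pulled in torch_npu on this machine.
# Run from the repository root so the relative path to device_extension.json resolves.
import sys

import timm  # triggers the loader in timm/__init__.py

print('torch_npu' in sys.modules)
# True on an Ascend machine with torch_npu installed; False elsewhere,
# since the loader swallows the ImportError and moves on.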

timm/models/_builder.py

Lines changed: 0 additions & 2 deletions
@@ -15,7 +15,6 @@
 from timm.models._pretrained import PretrainedCfg
 from timm.models._prune import adapt_model_from_file
 from timm.models._registry import get_pretrained_cfg
-from timm.utils.distributed import is_torch_npu_available
 
 _logger = logging.getLogger(__name__)
 
@@ -24,7 +23,6 @@
 _DOWNLOAD_PROGRESS = False
 _CHECK_HASH = False
 _USE_OLD_CACHE = int(os.environ.get('TIMM_USE_OLD_CACHE', 0)) > 0
-has_torch_npu = is_torch_npu_available()
 
 __all__ = ['set_pretrained_download_progress', 'set_pretrained_check_hash', 'load_custom_pretrained', 'load_pretrained',
            'pretrained_cfg_for_features', 'resolve_pretrained_cfg', 'build_model_with_cfg']

timm/models/_factory.py

Lines changed: 0 additions & 3 deletions
@@ -7,9 +7,6 @@
 from ._hub import load_model_config_from_hf
 from ._pretrained import PretrainedCfg
 from ._registry import is_model, model_entrypoint, split_model_name_tag
-from timm.utils.distributed import is_torch_npu_available
-
-has_torch_npu = is_torch_npu_available()
 
 
 __all__ = ['parse_model_name', 'safe_model_name', 'create_model']

timm/utils/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -4,7 +4,7 @@
 from .cuda import ApexScaler, NativeScaler
 from .decay_batch import decay_batch_step, check_batch_size_retry
 from .distributed import distribute_bn, reduce_tensor, init_distributed_device,\
-    world_info_from_env, is_distributed_env, is_primary, is_torch_npu_available
+    world_info_from_env, is_distributed_env, is_primary
 from .jit import set_jit_legacy, set_jit_fuser
 from .log import setup_default_logging, FormatterNoInfo
 from .metrics import AverageMeter, accuracy

timm/utils/distributed.py

Lines changed: 5 additions & 16 deletions
@@ -2,8 +2,6 @@
 
 Hacked together by / Copyright 2020 Ross Wightman
 """
-import importlib
-import importlib.metadata as imp_meta
 import logging
 import os
 from typing import Optional
@@ -48,19 +46,6 @@ def is_primary(args, local=False):
     return is_local_primary(args) if local else is_global_primary(args)
 
 
-def is_torch_npu_available():
-    _torch_npu_available = importlib.util.find_spec("torch_npu") is not None
-    if _torch_npu_available:
-        try:
-            torch_npu_version = imp_meta.version("torch_npu")
-            import torch_npu  # noqa: F401
-            torch.npu.set_device(0)
-            _logger.info(f"torch_npu version {torch_npu_version} is available.")
-        except ImportError:
-            _torch_npu_available = False
-    return _torch_npu_available
-
-
 def is_distributed_env():
     if 'WORLD_SIZE' in os.environ:
         return int(os.environ['WORLD_SIZE']) > 1
@@ -170,7 +155,11 @@ def init_distributed_device_so(
     if 'cuda' in device:
         assert torch.cuda.is_available(), f'CUDA is not available but {device} was specified.'
     if 'npu' in device:
-        assert is_torch_npu_available(), f'NPU is not available but {device} was specified.'
+        try:
+            TORCH_NPU_AVAILABLE = torch.npu.is_available()
+            assert TORCH_NPU_AVAILABLE, f'NPU is not available but {device} was specified.'
+        except ImportError:
+            _logger.info(f"NPU is not available but {device} was specified.")
 
     if distributed and device != 'cpu':
         device, *device_idx = device.split(':', maxsplit=1)
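With is_torch_npu_available() removed, availability is now probed directly through torch.npu inside a try/except. Below is a standalone sketch of a comparable check; the function name is illustrative and not part of the diff, and it additionally catches AttributeError, since torch.npu does not exist at all until torch_npu has been imported.

# Sketch of an NPU availability probe comparable to the guard added above.
# npu_is_usable is an illustrative name, not a function introduced by this commit.
import torch

def npu_is_usable() -> bool:
    # torch.npu only appears once torch_npu has been imported (for example via
    # the timm/__init__.py loader), so guard the attribute access as well as
    # the runtime check.
    try:
        return torch.npu.is_available()
    except (AttributeError, ImportError):
        return False

print(npu_is_usable())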

validate.py

Lines changed: 6 additions & 5 deletions
@@ -27,9 +27,6 @@
 from timm.models import create_model, load_checkpoint, is_model, list_models
 from timm.utils import accuracy, AverageMeter, natural_key, setup_default_logging, set_jit_fuser, \
     decay_batch_step, check_batch_size_retry, ParseKwargs, reparameterize_model
-from timm.utils.distributed import is_torch_npu_available
-
-has_torch_npu = is_torch_npu_available()
 
 try:
     from apex import amp
@@ -399,8 +396,12 @@ def _try_run(args, initial_batch_size):
         try:
             if torch.cuda.is_available() and 'cuda' in args.device:
                 torch.cuda.empty_cache()
-            if torch.npu.is_available() and 'npu' in args.device:
-                torch.npu.empty_cache()
+            if 'npu' in args.device:
+                try:
+                    torch.npu.is_available()
+                    torch.npu.empty_cache()
+                except ImportError:
+                    _logger.info("NPU is not available but {args.device} was specified.")
             results = validate(args)
             return results
         except RuntimeError as e:
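Putting the pieces together, a hedged end-to-end sketch: once import timm has pulled in torch_npu via the loader, a model can be created and run on the NPU with an ordinary device string. The model name and input shape are arbitrary illustration choices, and the snippet falls back to CPU when no NPU is present.

# Sketch: run a timm model on an Ascend NPU loaded through the DEVICE_EXT mechanism.
import torch
import timm  # also imports torch_npu when it is installed

# Fall back to CPU when torch.npu is absent or reports no device.
device = 'npu:0' if getattr(torch, 'npu', None) and torch.npu.is_available() else 'cpu'

model = timm.create_model('resnet18', pretrained=False).to(device).eval()
x = torch.randn(1, 3, 224, 224, device=device)
with torch.no_grad():
    out = model(x)
print(device, out.shape)  # e.g. npu:0 torch.Size([1, 1000])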
