|
2 | 2 | from .layers import is_scriptable, is_exportable, set_scriptable, set_exportable
|
3 | 3 | from .models import create_model, list_models, list_pretrained, is_model, list_modules, model_entrypoint, \
|
4 | 4 | is_model_pretrained, get_pretrained_cfg, get_pretrained_cfg_value
|
| 5 | +import os |
| 6 | +import json |
| 7 | +import importlib |
| 8 | +import sys |
| 9 | + |
| 10 | +# import device specific accelerator module |
| 11 | +device_extension_info = open("./timm/device_extension.json", 'r') |
| 12 | +device_extension_info = json.load(device_extension_info) |
| 13 | + |
| 14 | +os_var = "" |
| 15 | +for device_key in device_extension_info.keys(): |
| 16 | + os_var_modules = device_extension_info[device_key] |
| 17 | + os_var += device_key + ':' |
| 18 | + |
| 19 | + for module in os_var_modules: |
| 20 | + os_var += module + ':' |
| 21 | + os_var = os_var[:-1] |
| 22 | + os_var += ',' |
| 23 | + |
| 24 | +os.environ["DEVICE_EXT"] = os_var[:-1] |
| 25 | + |
| 26 | +if os.getenv('DEVICE_EXT'): |
| 27 | + this_module = sys.modules[__name__] |
| 28 | + backends = os.getenv('DEVICE_EXT').split(',') |
| 29 | + for backend in backends: |
| 30 | + module_info = backend.split(':') |
| 31 | + module_name = module_info[1].strip() |
| 32 | + module_alias = list() |
| 33 | + if len(module_info) > 2: |
| 34 | + for i in range(2, len(module_info)): |
| 35 | + module_alias.append(module_info[i].strip()) |
| 36 | + try: |
| 37 | + extra_module = importlib.import_module(module_name) |
| 38 | + for alia in module_alias: |
| 39 | + setattr(this_module, alia, extra_module) |
| 40 | + print(module_alias) |
| 41 | + except ImportError: |
| 42 | + pass |
0 commit comments