@@ -446,6 +446,14 @@ class Artifact(Dataclass):
446
446
default = None , required = False , also_positional = False
447
447
)
448
448
449
+ def __init_subclass__ (cls , ** kwargs ):
450
+ super ().__init_subclass__ (** kwargs )
451
+ module = inspect .getmodule (cls )
452
+ # standardize module name
453
+ module_name = getattr (module , "__name__" , None )
454
+ if not is_library_module (module_name ):
455
+ cls .register_class ()
456
+
449
457
@classmethod
450
458
def is_possible_identifier (cls , obj ):
451
459
return isinstance (obj , str ) or is_artifact_dict (obj )
@@ -458,18 +466,15 @@ def get_artifact_type(cls):
458
466
if not is_library_module (module_name ):
459
467
non_library_module_warning = f"module named { module_name } is not importable. Class { cls } is thus registered into Artifact.class_register, indexed by { cls .__name__ } , accessible there as long as this class_register lives."
460
468
warnings .warn (non_library_module_warning , ImportWarning , stacklevel = 2 )
461
- cls .register_class (cls )
469
+ cls .register_class ()
462
470
return {"module" : "class_register" , "name" : cls .__name__ }
463
471
if hasattr (cls , "__qualname__" ) and "." in cls .__qualname__ :
464
472
return {"module" : module_name , "name" : cls .__qualname__ }
465
473
return {"module" : module_name , "name" : cls .__name__ }
466
474
467
475
@classmethod
468
- def register_class (cls , artifact_class ):
469
- Artifact ._class_register [artifact_class .__name__ ] = artifact_class
470
-
471
- def __init_subclass__ (cls , ** kwargs ):
472
- super ().__init_subclass__ (** kwargs )
476
+ def register_class (cls ):
477
+ Artifact ._class_register [cls .__name__ ] = cls
473
478
474
479
@classmethod
475
480
def is_artifact_file (cls , path ):
@@ -603,7 +608,7 @@ def maybe_fix_type_to_ensure_instantiation_ability(self):
603
608
not is_library_module (self .__type__ ["module" ])
604
609
or "<locals>" in self .__type__ ["name" ]
605
610
):
606
- self .__class__ .register_class (self . __class__ )
611
+ self .__class__ .register_class ()
607
612
self .__type__ = {
608
613
"module" : "class_register" ,
609
614
"name" : self .__class__ .__name__ ,
0 commit comments