@@ -37,6 +37,7 @@ def __init__(
37
37
class_map : dict = None ,
38
38
input_key : str = 'image' ,
39
39
target_key : str = 'label' ,
40
+ additional_features : Optional [list [str ]] = None ,
40
41
download : bool = False ,
41
42
trust_remote_code : bool = False
42
43
):
@@ -65,9 +66,18 @@ def __init__(
65
66
self .split_info = self .dataset .info .splits [split ]
66
67
self .num_samples = self .split_info .num_examples
67
68
69
+ if isinstance (additional_features , str ):
70
+ self .additional_features = [additional_features ]
71
+ elif isinstance (additional_features , list ):
72
+ self .additional_features = additional_features
73
+ else :
74
+ self .additional_features = []
75
+
68
76
def __getitem__ (self , index ):
69
77
item = self .dataset [index ]
70
78
image = item [self .image_key ]
79
+ features = [item [feat ] for feat in self .additional_features ]
80
+
71
81
if 'bytes' in image and image ['bytes' ]:
72
82
image = io .BytesIO (image ['bytes' ])
73
83
else :
@@ -76,7 +86,8 @@ def __getitem__(self, index):
76
86
label = item [self .label_key ]
77
87
if self .remap_class :
78
88
label = self .class_to_idx [label ]
79
- return image , label
89
+
90
+ return image , label , * features
80
91
81
92
def __len__(self):
    """Return the length reported by ``self.dataset``."""
    dataset = self.dataset
    return len(dataset)
0 commit comments