@@ -66,28 +66,33 @@ def __init__(
66
66
self .split_info = self .dataset .info .splits [split ]
67
67
self .num_samples = self .split_info .num_examples
68
68
69
- if isinstance (additional_features , str ):
70
- self .additional_features = [additional_features ]
71
- elif isinstance (additional_features , list ):
72
- self .additional_features = additional_features
69
+ if additional_features is not None :
70
+ if isinstance (additional_features , list ):
71
+ self .additional_features = additional_features
72
+ else :
73
+ self .additional_features = [additional_features ]
73
74
else :
74
- self .additional_features = []
75
+ self .additional_features = None
75
76
76
77
def __getitem__ (self , index ):
77
78
item = self .dataset [index ]
78
79
image = item [self .image_key ]
79
- features = [item [feat ] for feat in self .additional_features ]
80
80
81
81
if 'bytes' in image and image ['bytes' ]:
82
82
image = io .BytesIO (image ['bytes' ])
83
83
else :
84
84
assert 'path' in image and image ['path' ]
85
85
image = open (image ['path' ], 'rb' )
86
+
86
87
label = item [self .label_key ]
87
88
if self .remap_class :
88
89
label = self .class_to_idx [label ]
89
90
90
- return image , label , * features
91
+ if self .additional_features is not None :
92
+ features = [item [feat ] for feat in self .additional_features ]
93
+ return image , label , * features
94
+ else :
95
+ return image , label
91
96
92
97
def __len__ (self ):
93
98
return len (self .dataset )
0 commit comments