Skip to content

Commit 6327754

Browse files
grodino authored and rwightman committed
Allow user to specify additional features to be returned by Image Dataset when using ReaderHfds
1 parent cedba69 commit 6327754

File tree

3 files changed

+24
-4
lines changed

3 files changed

+24
-4
lines changed

timm/data/dataset.py

Lines changed: 8 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -30,6 +30,7 @@ def __init__(
3030
input_img_mode='RGB',
3131
transform=None,
3232
target_transform=None,
33+
additional_features=None,
3334
**kwargs,
3435
):
3536
if reader is None or isinstance(reader, str):
@@ -38,17 +39,19 @@ def __init__(
3839
root=root,
3940
split=split,
4041
class_map=class_map,
42+
additional_features=additional_features,
4143
**kwargs,
4244
)
4345
self.reader = reader
4446
self.load_bytes = load_bytes
4547
self.input_img_mode = input_img_mode
4648
self.transform = transform
4749
self.target_transform = target_transform
50+
self.additional_features = additional_features
4851
self._consecutive_errors = 0
4952

5053
def __getitem__(self, index):
51-
img, target = self.reader[index]
54+
img, target, *features = self.reader[index]
5255

5356
try:
5457
img = img.read() if self.load_bytes else Image.open(img)
@@ -71,7 +74,10 @@ def __getitem__(self, index):
7174
elif self.target_transform is not None:
7275
target = self.target_transform(target)
7376

74-
return img, target
77+
if self.additional_features is None:
78+
return img, target
79+
else:
80+
return img, target, *features
7581

7682
def __len__(self):
7783
return len(self.reader)

timm/data/readers/reader_factory.py

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -19,11 +19,14 @@ def create_reader(
1919
prefix = name[0]
2020
name = name[-1]
2121

22+
# FIXME the additional features are only supported by ReaderHfds for now.
23+
additional_features = kwargs.pop("additional_features", None)
24+
2225
# FIXME improve the selection right now just tfds prefix or fallback path, will need options to
2326
# explicitly select other options shortly
2427
if prefix == 'hfds':
2528
from .reader_hfds import ReaderHfds # defer Hf datasets import
26-
reader = ReaderHfds(name=name, root=root, split=split, **kwargs)
29+
reader = ReaderHfds(name=name, root=root, split=split, additional_features=additional_features, **kwargs)
2730
elif prefix == 'hfids':
2831
from .reader_hfids import ReaderHfids # defer HF datasets import
2932
reader = ReaderHfids(name=name, root=root, split=split, **kwargs)

timm/data/readers/reader_hfds.py

Lines changed: 12 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -37,6 +37,7 @@ def __init__(
3737
class_map: dict = None,
3838
input_key: str = 'image',
3939
target_key: str = 'label',
40+
additional_features: Optional[list[str]] = None,
4041
download: bool = False,
4142
trust_remote_code: bool = False
4243
):
@@ -65,9 +66,18 @@ def __init__(
6566
self.split_info = self.dataset.info.splits[split]
6667
self.num_samples = self.split_info.num_examples
6768

69+
if isinstance(additional_features, str):
70+
self.additional_features = [additional_features]
71+
elif isinstance(additional_features, list):
72+
self.additional_features = additional_features
73+
else:
74+
self.additional_features = []
75+
6876
def __getitem__(self, index):
6977
item = self.dataset[index]
7078
image = item[self.image_key]
79+
features = [item[feat] for feat in self.additional_features]
80+
7181
if 'bytes' in image and image['bytes']:
7282
image = io.BytesIO(image['bytes'])
7383
else:
@@ -76,7 +86,8 @@ def __getitem__(self, index):
7686
label = item[self.label_key]
7787
if self.remap_class:
7888
label = self.class_to_idx[label]
79-
return image, label
89+
90+
return image, label, *features
8091

8192
def __len__(self):
8293
return len(self.dataset)

0 commit comments

Comments
 (0)