|
| 1 | +import os |
| 2 | +import torch |
| 3 | +import torch.utils.data as data |
| 4 | +import torchvision |
| 5 | +from torchvision.datasets import ImageFolder |
| 6 | +from torchvision.datasets import CIFAR10 |
| 7 | +from torchvision.datasets.utils import download_url |
| 8 | + |
| 9 | +class StanfordOnlineProducts(ImageFolder, CIFAR10): |
| 10 | + base_folder = 'Stanford_Online_Products' |
| 11 | + url = 'ftp://cs.stanford.edu/cs/cvgl/Stanford_Online_Products.zip' |
| 12 | + filename = 'Stanford_Online_Products.zip' |
| 13 | + zip_md5 = '7f73d41a2f44250d4779881525aea32e' |
| 14 | + |
| 15 | + train_list = [ |
| 16 | + ['bicycle_final/111265328556_0.JPG', '77420a4db9dd9284378d7287a0729edb'] |
| 17 | + ['chair_final/111182689872_0.JPG', 'ce78d10ed68560f4ea5fa1bec90206ba'] |
| 18 | + ] |
| 19 | + test_list = [ |
| 20 | + ['table_final/111194782300_0.JPG', '8203e079b5c134161bbfa7ee2a43a0a1'], |
| 21 | + ['toaster_final/111157129195_0.JPG', 'd6c24ee8c05d986cafffa6af82ae224e'] |
| 22 | + ] |
| 23 | + |
| 24 | + def __init__(self, root, train=None, transform=None, target_transform=None, download=False, **kwargs): |
| 25 | + self.root = root |
| 26 | + if download: |
| 27 | + self.download() |
| 28 | + |
| 29 | + if not self._check_integrity(): |
| 30 | + raise RuntimeError('Dataset not found or corrupted.' + |
| 31 | + ' You can use download=True to download it') |
| 32 | + |
| 33 | + def download(self): |
| 34 | + import zipfile |
| 35 | + |
| 36 | + if self._check_integrity(): |
| 37 | + print('Files already downloaded and verified') |
| 38 | + return |
| 39 | + |
| 40 | + root = self.root |
| 41 | + download_url(self.url, root, self.filename, self.zip_md5) |
| 42 | + |
| 43 | + # extract file |
| 44 | + cwd = os.getcwd() |
| 45 | + zip = zipfile.open(os.path.join(root, self.filename), "r") |
| 46 | + os.chdir(root) |
| 47 | + zip.extractall() |
| 48 | + zip.close() |
| 49 | + os.chdir(cwd) |
0 commit comments