Skip to content

Commit a3c2207

Browse files
committed
more datasets
1 parent 0fb089a commit a3c2207

File tree

2 files changed

+82
-0
lines changed

2 files changed

+82
-0
lines changed

cars196.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
import os
2+
import torch
3+
import torch.utils.data as data
4+
import torchvision
5+
from torchvision.datasets import ImageFolder
6+
from torchvision.datasets import CIFAR10
7+
8+
class Cars196(ImageFolder, CIFAR10):
9+
base_folder = 'car_ims'
10+
url = 'http://imagenet.stanford.edu/internal/car196/car_ims.tgz'
11+
filename = 'cars_ims.tgz'
12+
tgz_md5 = 'd5c8f0aa497503f355e17dc7886c3f14'
13+
14+
base_folder_devkit = 'devkit'
15+
url_devkit = 'http://ai.stanford.edu/~jkrause/cars/car_devkit.tgz'
16+
filename_devkit = 'cars_devkit.tgz'
17+
tgz_md5_devkit = 'c3b158d763b6e2245038c8ad08e45376'
18+
19+
train_list = []
20+
test_list = []
21+
22+
def download(self):
23+
pass
24+
25+
def __init__(self, root, train=False, transform=None, target_transform=None, download=False, **kwargs):
26+
self.root = root
27+
if download:
28+
self.download()
29+
30+
if not self._check_integrity():
31+
raise RuntimeError('Dataset not found or corrupted.' +
32+
' You can use download=True to download it')
33+
ImageFolder.__init__(self, os.path.join(root, self.base_folder), transform = transform, target_transform = target_transform, **kwargs)

stanford_online_products.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
import os
2+
import torch
3+
import torch.utils.data as data
4+
import torchvision
5+
from torchvision.datasets import ImageFolder
6+
from torchvision.datasets import CIFAR10
7+
from torchvision.datasets.utils import download_url
8+
9+
class StanfordOnlineProducts(ImageFolder, CIFAR10):
10+
base_folder = 'Stanford_Online_Products'
11+
url = 'ftp://cs.stanford.edu/cs/cvgl/Stanford_Online_Products.zip'
12+
filename = 'Stanford_Online_Products.zip'
13+
zip_md5 = '7f73d41a2f44250d4779881525aea32e'
14+
15+
train_list = [
16+
['bicycle_final/111265328556_0.JPG', '77420a4db9dd9284378d7287a0729edb']
17+
['chair_final/111182689872_0.JPG', 'ce78d10ed68560f4ea5fa1bec90206ba']
18+
]
19+
test_list = [
20+
['table_final/111194782300_0.JPG', '8203e079b5c134161bbfa7ee2a43a0a1'],
21+
['toaster_final/111157129195_0.JPG', 'd6c24ee8c05d986cafffa6af82ae224e']
22+
]
23+
24+
def __init__(self, root, train=None, transform=None, target_transform=None, download=False, **kwargs):
25+
self.root = root
26+
if download:
27+
self.download()
28+
29+
if not self._check_integrity():
30+
raise RuntimeError('Dataset not found or corrupted.' +
31+
' You can use download=True to download it')
32+
33+
def download(self):
34+
import zipfile
35+
36+
if self._check_integrity():
37+
print('Files already downloaded and verified')
38+
return
39+
40+
root = self.root
41+
download_url(self.url, root, self.filename, self.zip_md5)
42+
43+
# extract file
44+
cwd = os.getcwd()
45+
zip = zipfile.open(os.path.join(root, self.filename), "r")
46+
os.chdir(root)
47+
zip.extractall()
48+
zip.close()
49+
os.chdir(cwd)

0 commit comments

Comments
 (0)