Skip to content

Commit

Permalink
more datasets
Browse files Browse the repository at this point in the history
  • Loading branch information
vadimkantorov committed Apr 14, 2017
1 parent 0fb089a commit a3c2207
Show file tree
Hide file tree
Showing 2 changed files with 82 additions and 0 deletions.
33 changes: 33 additions & 0 deletions cars196.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import os
import torch
import torch.utils.data as data
import torchvision
from torchvision.datasets import ImageFolder
from torchvision.datasets import CIFAR10

class Cars196(ImageFolder, CIFAR10):
base_folder = 'car_ims'
url = 'http://imagenet.stanford.edu/internal/car196/car_ims.tgz'
filename = 'cars_ims.tgz'
tgz_md5 = 'd5c8f0aa497503f355e17dc7886c3f14'

base_folder_devkit = 'devkit'
url_devkit = 'http://ai.stanford.edu/~jkrause/cars/car_devkit.tgz'
filename_devkit = 'cars_devkit.tgz'
tgz_md5_devkit = 'c3b158d763b6e2245038c8ad08e45376'

train_list = []
test_list = []

def download(self):
pass

def __init__(self, root, train=False, transform=None, target_transform=None, download=False, **kwargs):
self.root = root
if download:
self.download()

if not self._check_integrity():
raise RuntimeError('Dataset not found or corrupted.' +
' You can use download=True to download it')
ImageFolder.__init__(self, os.path.join(root, self.base_folder), transform = transform, target_transform = target_transform, **kwargs)
49 changes: 49 additions & 0 deletions stanford_online_products.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import os
import torch
import torch.utils.data as data
import torchvision
from torchvision.datasets import ImageFolder
from torchvision.datasets import CIFAR10
from torchvision.datasets.utils import download_url

class StanfordOnlineProducts(ImageFolder, CIFAR10):
base_folder = 'Stanford_Online_Products'
url = 'ftp://cs.stanford.edu/cs/cvgl/Stanford_Online_Products.zip'
filename = 'Stanford_Online_Products.zip'
zip_md5 = '7f73d41a2f44250d4779881525aea32e'

train_list = [
['bicycle_final/111265328556_0.JPG', '77420a4db9dd9284378d7287a0729edb']
['chair_final/111182689872_0.JPG', 'ce78d10ed68560f4ea5fa1bec90206ba']
]
test_list = [
['table_final/111194782300_0.JPG', '8203e079b5c134161bbfa7ee2a43a0a1'],
['toaster_final/111157129195_0.JPG', 'd6c24ee8c05d986cafffa6af82ae224e']
]

def __init__(self, root, train=None, transform=None, target_transform=None, download=False, **kwargs):
self.root = root
if download:
self.download()

if not self._check_integrity():
raise RuntimeError('Dataset not found or corrupted.' +
' You can use download=True to download it')

def download(self):
import zipfile

if self._check_integrity():
print('Files already downloaded and verified')
return

root = self.root
download_url(self.url, root, self.filename, self.zip_md5)

# extract file
cwd = os.getcwd()
zip = zipfile.open(os.path.join(root, self.filename), "r")
os.chdir(root)
zip.extractall()
zip.close()
os.chdir(cwd)

0 comments on commit a3c2207

Please sign in to comment.