CINIC-10 Tensorflow Loader
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

140 lines
5.2 KiB

import os
import sys
import numpy as np
import matplotlib.pyplot as plt
from scipy import misc
from six.moves import urllib
import tarfile
def _listImages(splitDir):
    """
    splitDir: Root folder of one dataset split (train, valid or test)
    Return: List of (directory, file name) pairs for every file below splitDir
    """
    entries = []
    # A single walk keeps each file paired with its own directory; sorting
    # makes the sample order independent of the file system's listing order.
    for dirPath, dirNames, fileNames in os.walk(os.path.expanduser(splitDir)):
        dirNames.sort()
        entries.extend((dirPath, fileName) for fileName in sorted(fileNames))
    return entries


def _readImage(path):
    """
    path: Path to a single image file
    Return: Float array of shape (height, width, 3) with values in [0, 1]
    """
    # scipy.misc.imread was removed in SciPy 1.2, so decoding is done with
    # matplotlib, which this module already imports.
    img = np.asarray(plt.imread(path))
    # matplotlib returns PNG data as floats already scaled to [0, 1] and other
    # formats as integers, so only integer data still has to be rescaled.
    if img.dtype.kind in "ui":
        img = img / 255.
    if img.ndim == 2:
        # Grayscale image: replicate the single channel into R, G and B.
        img = np.repeat(img[:, :, None], 3, axis=2)
    elif img.shape[2] == 4:
        # RGBA image: drop the alpha channel.
        img = img[:, :, :3]
    return img


def _loadSplit(splitDir, labelDict):
    """
    splitDir: Root folder of one dataset split, with one sub-folder per class
    labelDict: Mapping from class folder name to integer label
    Return: Images as float32 array (N, 32, 32, 3) and labels as int32 array (N)
    """
    entries = _listImages(splitDir)
    X = np.empty((len(entries), 32, 32, 3), dtype=np.float32)
    Y = np.empty(len(entries), dtype=np.int32)
    for i, (dirPath, fileName) in enumerate(entries):
        X[i] = _readImage(os.path.join(dirPath, fileName))
        # The class is encoded by the name of the folder holding the image.
        Y[i] = labelDict[os.path.basename(dirPath)]
    return X, Y


def loadData(pathToDatasetFolder, oneHot=False):
    """
    pathToDatasetFolder: Parent folder of CINIC-10 dataset folder or CINIC-10.tar.gz file
    oneHot: Label encoding (one hot encoding or not)
    Return: Train, validation and test sets and label numpy arrays, in the
    order XTrain, YTrain, XVal, YVal, XTest, YTest. Images are float32 arrays
    of shape (N, 32, 32, 3) scaled to [0, 1]; labels are int32 class indices,
    or one hot rows when oneHot is set.
    """
    sourceUrl = "https://datashare.is.ed.ac.uk/bitstream/handle/10283/3192/CINIC-10.tar.gz"
    pathToFile = downloadDataset(pathToDatasetFolder, "CINIC-10.tar.gz", sourceUrl)
    labelDict = {'airplane': 0, 'automobile': 1, 'bird': 2, 'cat': 3,
                 'deer': 4, 'dog': 5, 'frog': 6, 'horse': 7, 'ship': 8,
                 'truck': 9}
    print("Loading")
    XTrain, YTrain = _loadSplit(os.path.join(pathToFile, "train"), labelDict)
    XVal, YVal = _loadSplit(os.path.join(pathToFile, "valid"), labelDict)
    XTest, YTest = _loadSplit(os.path.join(pathToFile, "test"), labelDict)
    if oneHot:
        nbClasses = len(labelDict)
        YTrain = toOneHot(YTrain, nbClasses)
        YVal = toOneHot(YVal, nbClasses)
        YTest = toOneHot(YTest, nbClasses)
    print("+ Dataset loaded")
    return XTrain, YTrain, XVal, YVal, XTest, YTest
def downloadDataset(dirName, fileName, sourceUrl):
    """
    https://github.com/tflearn/tflearn/blob/master/tflearn/datasets/cifar10.py
    dirName: Folder that holds (or will hold) the archive and the "CINIC-10/" folder
    fileName: File name of the archive inside dirName
    sourceUrl: URL the archive is downloaded from when it is not on disk
    Return: Path to the "CINIC-10/" folder inside dirName
    """
    cinicDirName = os.path.join(dirName, "CINIC-10/")
    pathToFile = os.path.join(dirName, fileName)
    # A populated dataset folder needs neither a download nor an extraction,
    # even when the archive itself has been deleted.
    if os.path.isdir(cinicDirName) and os.listdir(cinicDirName):
        print("+ Dataset already downloaded")
        return cinicDirName
    if not os.path.exists(cinicDirName):
        # makedirs also creates dirName itself when it does not exist yet.
        os.makedirs(cinicDirName)
    if not os.path.exists(pathToFile):
        print("Downloading")
        # Download under a temporary name so that an interrupted transfer is
        # never mistaken for a complete archive on the next run.
        partialPath = pathToFile + ".part"
        urllib.request.urlretrieve(sourceUrl, partialPath, reporthook)
        os.rename(partialPath, pathToFile)
        print("+ Downloaded")
    else:
        print("+ Dataset already downloaded")
    # Reached only when the dataset folder is empty, so the archive (freshly
    # downloaded or supplied by the user) still has to be unpacked.
    untar(pathToFile, cinicDirName)
    return cinicDirName
def reporthook(blocknum, blocksize, totalsize):
    """
    reporthook from stackoverflow #13881092
    https://github.com/tflearn/tflearn/blob/master/tflearn/datasets/cifar10.py
    blocknum: Number of blocks transferred so far
    blocksize: Size of one block in bytes
    totalsize: Total size in bytes, or a value <= 0 when it is unknown
    Writes a progress line to stderr.
    """
    readsofar = blocknum * blocksize
    if totalsize > 0:
        # The last block is usually only partly filled, so the product can
        # exceed the real size; clamp it to avoid reporting more than 100%.
        readsofar = min(readsofar, totalsize)
        percent = readsofar * 1e2 / totalsize
        s = "\r%5.1f%% %*d / %d" % (
            percent, len(str(totalsize)), readsofar, totalsize)
        sys.stderr.write(s)
        if readsofar >= totalsize:  # transfer finished
            sys.stderr.write("\n")
    else:  # total size is unknown
        sys.stderr.write("read %d\n" % (readsofar,))
def untar(fname, path):
    """
    fname: Path to the archive; only names ending in "tar.gz" are extracted
    path: Folder the archive content is extracted into
    """
    if not fname.endswith("tar.gz"):
        print("Not a tar.gz file")
        return
    print("Extracting tar file")
    # The context manager closes the archive even when extraction fails.
    with tarfile.open(fname) as tar:
        if hasattr(tarfile, "data_filter"):
            # The archive comes from the network: where the interpreter
            # supports extraction filters, refuse members that would escape
            # the destination folder.
            tar.extractall(path=path, filter="data")
        else:
            tar.extractall(path=path)
    print("+ Extracted")
def toOneHot(y, nb_classes=None):
    """
    https://github.com/tflearn/tflearn/blob/master/tflearn/data_utils.py#L36
    y: Class labels (numpy array or any array-like such as a list)
    nb_classes: Number of classes. When omitted (or 0) the classes are the
    sorted unique values found in y.
    Return: One hot matrix with one row per label. With nb_classes the result
    is float64 with nb_classes columns; without it the result is float32 with
    one column per unique value of y.
    """
    if nb_classes:
        # Convert up front so lists and float-typed labels work as indices.
        y = np.asarray(y, dtype='int32')
        if len(y.shape) > 2:
            print("Warning: data array ndim > 2")
        if len(y.shape) > 1:
            y = y.reshape(-1)
        Y = np.zeros((len(y), nb_classes))
        # Set column y[i] of row i to one.
        Y[np.arange(len(y)), y] = 1.
        return Y
    else:
        y = np.array(y)
        # Compare every label against the sorted unique values of y.
        return (y[:, None] == np.unique(y)).astype(np.float32)