You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
140 lines
5.2 KiB
140 lines
5.2 KiB
import os
|
|
import sys
|
|
import numpy as np
|
|
import matplotlib.pyplot as plt
|
|
from scipy import misc
|
|
from six.moves import urllib
|
|
import tarfile
|
|
|
|
|
|
def loadData(pathToDatasetFolder, oneHot=False):
|
|
"""
|
|
pathToDatasetFolder: Parent folder of CINIC-10 dataset folder or CINIC-10.tar.gz file
|
|
oneHot: Label encoding (one hot encoding or not)
|
|
|
|
Return: Train, validation and test sets and label numpy arrays
|
|
"""
|
|
sourceUrl = "https://datashare.is.ed.ac.uk/bitstream/handle/10283/3192/CINIC-10.tar.gz"
|
|
pathToFile = downloadDataset(pathToDatasetFolder, "CINIC-10.tar.gz", sourceUrl)
|
|
|
|
labelDict = {'airplane': 0, 'automobile': 1, 'bird': 2, 'cat': 3,
|
|
'deer': 4, 'dog': 5, 'frog': 6, 'horse': 7, 'ship': 8,
|
|
'truck': 9}
|
|
|
|
pathToTrain = os.path.join(pathToFile, "train")
|
|
pathToVal = os.path.join(pathToFile, "valid")
|
|
pathToTest = os.path.join(pathToFile, "test")
|
|
|
|
imgNamesTrain = [f for dp, dn, fn in os.walk(os.path.expanduser(pathToTrain)) for f in fn]
|
|
imgDirsTrain = [dp for dp, dn, fn in os.walk(os.path.expanduser(pathToTrain)) for f in fn]
|
|
imgNamesVal = [f for dp, dn, fn in os.walk(os.path.expanduser(pathToVal)) for f in fn]
|
|
imgDirsVal = [dp for dp, dn, fn in os.walk(os.path.expanduser(pathToVal)) for f in fn]
|
|
imgNamesTest = [f for dp, dn, fn in os.walk(os.path.expanduser(pathToTest)) for f in fn]
|
|
imgDirsTest = [dp for dp, dn, fn in os.walk(os.path.expanduser(pathToTest)) for f in fn]
|
|
|
|
XTrain = np.empty((len(imgNamesTrain), 32, 32, 3), dtype=np.float32)
|
|
YTrain = np.empty((len(imgNamesTrain)), dtype=np.int32)
|
|
XVal = np.empty((len(imgNamesVal), 32, 32, 3), dtype=np.float32)
|
|
YVal = np.empty((len(imgNamesVal)), dtype=np.int32)
|
|
XTest = np.empty((len(imgNamesTest), 32, 32, 3), dtype=np.float32)
|
|
YTest = np.empty((len(imgNamesTest)), dtype=np.int32)
|
|
|
|
print("Loading")
|
|
|
|
for i in range(len(imgNamesTrain)):
|
|
# img = plt.imread(os.path.join(imgDirsTrain[i], imgNamesTrain[i]))
|
|
img = misc.imread(os.path.join(imgDirsTrain[i], imgNamesTrain[i]))
|
|
if len(img.shape) == 2:
|
|
XTrain[i, :, :, 2] = XTrain[i, :, :, 1] = XTrain[i, :, :, 0] = img/255.
|
|
else:
|
|
XTrain[i] = img/255.
|
|
YTrain[i] = labelDict[os.path.basename(imgDirsTrain[i])]
|
|
for i in range(len(imgNamesVal)):
|
|
# img = plt.imread(os.path.join(imgDirsVal[i], imgNamesVal[i]))
|
|
img = misc.imread(os.path.join(imgDirsVal[i], imgNamesVal[i]))
|
|
if len(img.shape) == 2:
|
|
XVal[i, :, :, 2] = XVal[i, :, :, 1] = XVal[i, :, :, 0] = img/255.
|
|
else:
|
|
XVal[i] = img/255.
|
|
YVal[i] = labelDict[os.path.basename(imgDirsVal[i])]
|
|
for i in range(len(imgNamesTest)):
|
|
# img = plt.imread(os.path.join(imgDirsTest[i], imgNamesTest[i]))
|
|
img = misc.imread(os.path.join(imgDirsTest[i], imgNamesTest[i]))
|
|
if len(img.shape) == 2:
|
|
XTest[i, :, :, 2] = XTest[i, :, :, 1] = XTest[i, :, :, 0] = img/255.
|
|
else:
|
|
XTest[i] = img/255.
|
|
YTest[i] = labelDict[os.path.basename(imgDirsTest[i])]
|
|
|
|
if oneHot:
|
|
YTrain = toOneHot(YTrain, 10)
|
|
YVal = toOneHot(YVal, 10)
|
|
YTest = toOneHot(YTest, 10)
|
|
|
|
print("+ Dataset loaded")
|
|
|
|
return XTrain, YTrain, XVal, YVal, XTest, YTest
|
|
|
|
|
|
def downloadDataset(dirName, fileName, sourceUrl):
|
|
"""
|
|
https://github.com/tflearn/tflearn/blob/master/tflearn/datasets/cifar10.py
|
|
"""
|
|
cinicDirName = os.path.join(dirName, "CINIC-10/")
|
|
if not os.path.exists(cinicDirName):
|
|
os.mkdir(cinicDirName)
|
|
pathToFile = os.path.join(dirName, fileName)
|
|
if not os.path.exists(pathToFile):
|
|
print("Downloading")
|
|
pathToFile, _ = urllib.request.urlretrieve(sourceUrl, pathToFile, reporthook)
|
|
print("+ Downloaded")
|
|
untar(pathToFile, cinicDirName)
|
|
else:
|
|
print("+ Dataset already downloaded")
|
|
return cinicDirName
|
|
|
|
|
|
def reporthook(blocknum, blocksize, totalsize):
|
|
"""
|
|
reporthook from stackoverflow #13881092
|
|
https://github.com/tflearn/tflearn/blob/master/tflearn/datasets/cifar10.py
|
|
"""
|
|
readsofar = blocknum * blocksize
|
|
if totalsize > 0:
|
|
percent = readsofar * 1e2 / totalsize
|
|
s = "\r%5.1f%% %*d / %d" % (
|
|
percent, len(str(totalsize)), readsofar, totalsize)
|
|
sys.stderr.write(s)
|
|
if readsofar >= totalsize: # near the end
|
|
sys.stderr.write("\n")
|
|
else: # total size is unknown
|
|
sys.stderr.write("read %d\n" % (readsofar,))
|
|
|
|
|
|
def untar(fname, path):
|
|
if (fname.endswith("tar.gz")):
|
|
print("Extracting tar file")
|
|
tar = tarfile.open(fname)
|
|
tar.extractall(path=path)
|
|
tar.close()
|
|
print("+ Extracted")
|
|
else:
|
|
print("Not a tar.gz file")
|
|
|
|
|
|
def toOneHot(y, nb_classes=None):
|
|
"""
|
|
https://github.com/tflearn/tflearn/blob/master/tflearn/data_utils.py#L36
|
|
"""
|
|
if nb_classes:
|
|
# y = np.asarray(y, dtype='int32')
|
|
if len(y.shape) > 2:
|
|
print("Warning: data array ndim > 2")
|
|
if len(y.shape) > 1:
|
|
y = y.reshape(-1)
|
|
Y = np.zeros((len(y), nb_classes))
|
|
Y[np.arange(len(y)), y] = 1.
|
|
return Y
|
|
else:
|
|
y = np.array(y)
|
|
return (y[:, None] == np.unique(y)).astype(np.float32)
|
|
|