# CINIC-10 Tensorflow Loader

# -*- coding: utf-8 -*-
""" Deep Residual Network.
https://github.com/tflearn/tflearn/blob/master/examples/images/residual_network_cifar10.py
Applying a Deep Residual Network to CINIC-10 Dataset classification task.
References:
- K. He, X. Zhang, S. Ren, and J. Sun. Deep Residual Learning for Image
Recognition, 2015.
- CINIC-10 is not ImageNet or CIFAR-10, Darlow et al., 2018.
Links:
- [Deep Residual Network](http://arxiv.org/pdf/1512.03385.pdf)
- [CINIC-10 Dataset](https://github.com/BayesWatch/cinic-10)
"""
from __future__ import division, print_function, absolute_import
import tflearn
from tflearn.data_utils import shuffle
# Residual blocks
# Depth is 6n + 2: 32 layers: n=5, 56 layers: n=9, 110 layers: n=18
n = 18

import os

import cinic10

# Root directory handed to cinic10.loadData. Set the CINIC10_DIR environment
# variable to point elsewhere instead of editing this script; the default is
# the location the script originally hard-coded.
DATA_DIR = os.environ.get("CINIC10_DIR", "/home/altinel/Downloads/datasets")

# loadData returns six arrays; the middle pair is discarded.
# NOTE(review): presumably (trainX, trainY, validX, validY, testX, testY) —
# confirm against cinic10.loadData.
X, Y, _, _, testX, testY = cinic10.loadData(DATA_DIR, oneHot=True)
X, Y = shuffle(X, Y)

# Real-time data preprocessing: subtract the per-channel mean.
img_prep = tflearn.ImagePreprocessing()
img_prep.add_featurewise_zero_center(per_channel=True)

# Real-time data augmentation: random horizontal flips plus random 32x32
# crops taken from the image padded by 4 pixels on each side.
img_aug = tflearn.ImageAugmentation()
img_aug.add_random_flip_leftright()
img_aug.add_random_crop([32, 32], padding=4)

# Building Residual Network: three stages of n residual blocks with
# 16, 32 and 64 filters; stages 2 and 3 begin with a downsampling block.
net = tflearn.input_data(shape=[None, 32, 32, 3],
                         data_preprocessing=img_prep,
                         data_augmentation=img_aug)
net = tflearn.conv_2d(net, 16, 3, regularizer='L2', weight_decay=0.0001)
net = tflearn.residual_block(net, n, 16)
net = tflearn.residual_block(net, 1, 32, downsample=True)
net = tflearn.residual_block(net, n-1, 32)
net = tflearn.residual_block(net, 1, 64, downsample=True)
net = tflearn.residual_block(net, n-1, 64)
net = tflearn.batch_normalization(net)
net = tflearn.activation(net, 'relu')
net = tflearn.global_avg_pool(net)

# Regression: 10-way softmax classifier.
net = tflearn.fully_connected(net, 10, activation='softmax')
# Learning rate starts at 0.1 and is multiplied by 0.1 every 32000 training
# steps (staircase). NOTE(review): 32000 comes from the CIFAR-10 recipe; with a
# larger training set it corresponds to fewer epochs — confirm it is intended.
mom = tflearn.Momentum(0.1, lr_decay=0.1, decay_step=32000, staircase=True)
net = tflearn.regression(net, optimizer=mom,
                         loss='categorical_crossentropy')

# Training
# clip_gradients=0. disables TFLearn's default gradient clipping.
model = tflearn.DNN(net, checkpoint_path='model_resnet_cinic10',
                    max_checkpoints=10, tensorboard_verbose=0,
                    clip_gradients=0.)
# NOTE(review): the test split is passed as validation_set, so the
# "validation" metrics printed during training are computed on test data.
model.fit(X, Y, n_epoch=200, validation_set=(testX, testY),
          snapshot_epoch=False, snapshot_step=500,
          show_metric=True, batch_size=128, shuffle=True,
          run_id='resnet_cinic10')