# -*- coding: utf-8 -*-
""" Deep Residual Network.

Applying a Deep Residual Network to the CINIC-10 dataset classification task.

Based on the tflearn CIFAR-10 example:
https://github.com/tflearn/tflearn/blob/master/examples/images/residual_network_cifar10.py

References:

    - K. He, X. Zhang, S. Ren, and J. Sun. Deep Residual Learning for Image
      Recognition, 2015.
    - Darlow et al. "CINIC-10 is not ImageNet or CIFAR-10", 2018.

Links:

    - [Deep Residual Network](http://arxiv.org/pdf/1512.03385.pdf)
    - [CINIC-10 Dataset](https://github.com/BayesWatch/cinic-10)

"""

from __future__ import division, print_function, absolute_import

import os

import tflearn
from tflearn.data_utils import shuffle

# Residual blocks per stage; total depth is 6n + 2.
# 32 layers: n=5, 56 layers: n=9, 110 layers: n=18
n = 18

import cinic10

# Root directory of the CINIC-10 data. Taken from the CINIC10_DIR environment
# variable so the script is not tied to one machine; falls back to the
# previously hard-coded location so existing setups keep working unchanged.
data_dir = os.environ.get("CINIC10_DIR", "/home/altinel/Downloads/datasets")

# NOTE(review): loadData returns six values that look like
# (train, valid, test) image/label pairs. The middle pair is discarded and the
# last pair is passed to model.fit as the validation set below -- confirm
# against cinic10.loadData whether the validation split was meant here.
X, Y, _, _, testX, testY = cinic10.loadData(data_dir, oneHot=True)
X, Y = shuffle(X, Y)

# Real-time data preprocessing: zero-center each colour channel. No mean is
# passed, so tflearn derives it from the training data it is fitted on.
img_prep = tflearn.ImagePreprocessing()
img_prep.add_featurewise_zero_center(per_channel=True)

# Real-time data augmentation applied to training batches only:
#   - random left/right mirroring
#   - random 32x32 crops after padding each side by 4 pixels
img_aug = tflearn.ImageAugmentation()
img_aug.add_random_flip_leftright()
img_aug.add_random_crop(crop_shape=[32, 32], padding=4)

# Building Residual Network
# Depth is 6n + 2: one stem convolution, three stages of n residual blocks
# (two convolutions each), and the final fully connected classifier.
net = tflearn.input_data(shape=[None, 32, 32, 3],
                         data_preprocessing=img_prep,
                         data_augmentation=img_aug)
net = tflearn.conv_2d(net, 16, 3, regularizer='L2', weight_decay=0.0001)

# Stage widths grow 16 -> 32 -> 64. The first stage stacks n blocks at the
# stem's resolution; every later stage opens with one downsampling block and
# continues with n - 1 regular blocks, so each stage holds n blocks in total.
for _stage, _width in enumerate((16, 32, 64)):
    if _stage:
        net = tflearn.residual_block(net, 1, _width, downsample=True)
        net = tflearn.residual_block(net, n - 1, _width)
    else:
        net = tflearn.residual_block(net, n, _width)

# Final BN + ReLU, then global average pooling to one feature vector per image.
net = tflearn.batch_normalization(net)
net = tflearn.activation(net, 'relu')
net = tflearn.global_avg_pool(net)

# Regression: 10-way softmax trained with momentum SGD. The learning rate
# starts at 0.1 and is multiplied by 0.1 every 32000 steps (staircase decay).
# NOTE(review): decay_step=32000 appears carried over from the CIFAR-10
# example this script is based on; if the CINIC-10 training split is larger,
# 200 epochs will cross more decay boundaries -- verify the intended schedule.
net = tflearn.fully_connected(net, 10, activation='softmax')
mom = tflearn.Momentum(0.1, lr_decay=0.1, decay_step=32000, staircase=True)
net = tflearn.regression(net, optimizer=mom,
                         loss='categorical_crossentropy')

# Training
model = tflearn.DNN(
    net,
    checkpoint_path='model_resnet_cinic10',
    max_checkpoints=10,
    tensorboard_verbose=0,
    # 0. presumably turns gradient clipping off (tflearn clips only when > 0).
    clip_gradients=0.,
)

# Snapshots are taken every 500 training steps rather than at each epoch end.
model.fit(
    X, Y,
    n_epoch=200,
    validation_set=(testX, testY),
    snapshot_epoch=False,
    snapshot_step=500,
    show_metric=True,
    batch_size=128,
    shuffle=True,
    run_id='resnet_cinic10',
)