BinaryNet on CIFAR10 (Advanced)

Run this notebook here: Binder

In this example we demonstrate how to use Larq to build and train BinaryNet on the CIFAR10 dataset, reaching a validation accuracy of around 90%. Training takes approximately 250 minutes on an Nvidia V100 GPU. In contrast to the original papers, BinaryConnect: Training Deep Neural Networks with binary weights during propagations and Binarized Neural Networks: Training Deep Neural Networks with Weights and Activations Constrained to +1 or -1, we do not apply image whitening, but we do use image augmentation and a stepped learning rate schedule.

import tensorflow as tf
import larq as lq
import numpy as np

Import CIFAR10 Dataset

Here we download the CIFAR10 dataset:

train_data, test_data = tf.keras.datasets.cifar10.load_data()

Next, we define our image augmentation techniques and create the dataset:

def resize_and_flip(image, labels, training):
    # Scale pixel values from [0, 255] to [-1, 1]
    image = tf.cast(image, tf.float32) / (255. / 2.) - 1.
    if training:
        # Pad to 40x40, take a random 32x32 crop and randomly flip horizontally
        image = tf.image.resize_image_with_crop_or_pad(image, 40, 40)
        image = tf.random_crop(image, [32, 32, 3])
        image = tf.image.random_flip_left_right(image)
    return image, labels

def create_dataset(data, batch_size, training):
    images, labels = data
    # Convert integer class labels to one-hot vectors for categorical crossentropy
    labels = tf.one_hot(np.squeeze(labels), 10)
    dataset = tf.data.Dataset.from_tensor_slices((images, labels))
    dataset = dataset.repeat()
    if training:
        dataset = dataset.shuffle(1000)
    dataset = dataset.map(lambda x, y: resize_and_flip(x, y, training))
    dataset = dataset.batch(batch_size)
    return dataset

batch_size = 50

train_dataset = create_dataset(train_data, batch_size, True)
test_dataset = create_dataset(test_data, batch_size, False)
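As a quick sanity check (not part of the original example) you can pull a single batch from the training dataset and inspect its shapes, assuming eager execution is available:

# Sanity check of the input pipeline (assumes eager execution, e.g. TensorFlow 2.x)
images, labels = next(iter(train_dataset))
print(images.shape)  # (50, 32, 32, 3) -- batch of augmented images scaled to [-1, 1]
print(labels.shape)  # (50, 10)        -- one-hot encoded labels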

Build BinaryNet

Here we build the BinaryNet model layer by layer using a Keras Sequential model:

# All quantized layers except the first will use the same options
kwargs = dict(input_quantizer="ste_sign",
              kernel_quantizer="ste_sign",
              kernel_constraint="weight_clip",
              use_bias=False)

model = tf.keras.models.Sequential([
    # In the first layer we only quantize the weights and not the input
    lq.layers.QuantConv2D(128, 3,
                          kernel_quantizer="ste_sign",
                          kernel_constraint="weight_clip",
                          use_bias=False,
                          input_shape=(32, 32, 3)),
    tf.keras.layers.BatchNormalization(momentum=0.999, scale=False),

    lq.layers.QuantConv2D(128, 3, padding="same", **kwargs),
    tf.keras.layers.MaxPool2D(pool_size=(2, 2), strides=(2, 2)),
    tf.keras.layers.BatchNormalization(momentum=0.999, scale=False),

    lq.layers.QuantConv2D(256, 3, padding="same", **kwargs),
    tf.keras.layers.BatchNormalization(momentum=0.999, scale=False),

    lq.layers.QuantConv2D(256, 3, padding="same", **kwargs),
    tf.keras.layers.MaxPool2D(pool_size=(2, 2), strides=(2, 2)),
    tf.keras.layers.BatchNormalization(momentum=0.999, scale=False),

    lq.layers.QuantConv2D(512, 3, padding="same", **kwargs),
    tf.keras.layers.BatchNormalization(momentum=0.999, scale=False),

    lq.layers.QuantConv2D(512, 3, padding="same", **kwargs),
    tf.keras.layers.MaxPool2D(pool_size=(2, 2), strides=(2, 2)),
    tf.keras.layers.BatchNormalization(momentum=0.999, scale=False),
    tf.keras.layers.Flatten(),

    lq.layers.QuantDense(1024, **kwargs),
    tf.keras.layers.BatchNormalization(momentum=0.999, scale=False),

    lq.layers.QuantDense(1024, **kwargs),
    tf.keras.layers.BatchNormalization(momentum=0.999, scale=False),

    lq.layers.QuantDense(10, **kwargs),
    tf.keras.layers.BatchNormalization(momentum=0.999, scale=False),
    tf.keras.layers.Activation("softmax")
])
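The "ste_sign" quantizer binarizes its input to -1 and +1 in the forward pass and uses a straight-through estimator for the gradient, while the "weight_clip" constraint keeps the latent full-precision weights in [-1, 1]. A minimal sketch of the quantizer's forward behaviour (assuming eager execution so the result prints directly):

# Illustration of the ste_sign forward pass: negative values map to -1, the rest to +1
x = tf.constant([-1.7, -0.3, 0.4, 2.1])
print(lq.quantizers.ste_sign(x))  # -> [-1., -1., 1., 1.]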

Larq allows you to print a summary of the model that includes bit-precision information:

lq.models.summary(model)
Layer                     Outputs              # 1-bit    # 32-bit
------------------------  -----------------  ---------  ----------
quant_conv2d              (-1, 30, 30, 128)       3456           0
batch_normalization_v1    (-1, 30, 30, 128)          0         384
quant_conv2d_1            (-1, 30, 30, 128)     147456           0
max_pooling2d             (-1, 15, 15, 128)          0           0
batch_normalization_v1_1  (-1, 15, 15, 128)          0         384
quant_conv2d_2            (-1, 15, 15, 256)     294912           0
batch_normalization_v1_2  (-1, 15, 15, 256)          0         768
quant_conv2d_3            (-1, 15, 15, 256)     589824           0
max_pooling2d_1           (-1, 7, 7, 256)            0           0
batch_normalization_v1_3  (-1, 7, 7, 256)            0         768
quant_conv2d_4            (-1, 7, 7, 512)      1179648           0
batch_normalization_v1_4  (-1, 7, 7, 512)            0        1536
quant_conv2d_5            (-1, 7, 7, 512)      2359296           0
max_pooling2d_2           (-1, 3, 3, 512)            0           0
batch_normalization_v1_5  (-1, 3, 3, 512)            0        1536
flatten                   (-1, 4608)                 0           0
quant_dense               (-1, 1024)           4718592           0
batch_normalization_v1_6  (-1, 1024)                 0        3072
quant_dense_1             (-1, 1024)           1048576           0
batch_normalization_v1_7  (-1, 1024)                 0        3072
quant_dense_2             (-1, 10)               10240           0
batch_normalization_v1_8  (-1, 10)                   0          30
activation                (-1, 10)                   0           0
Total                                         10352000       11550

Total params: 10363550
Trainable params: 10355850
Non-trainable params: 7700
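From the totals above we can make a rough back-of-the-envelope estimate of the weight storage: each 1-bit parameter needs one bit and each 32-bit parameter needs four bytes (this calculation is illustrative and ignores any packing overhead):

binary_mib = 10352000 / 8 / 1024 ** 2         # 1-bit parameters packed into bytes
float_mib = 11550 * 4 / 1024 ** 2             # remaining 32-bit parameters
full_precision_mib = 10363550 * 4 / 1024 ** 2 # same model stored entirely in 32-bit
print(f"binarized: {binary_mib + float_mib:.2f} MiB "
      f"vs. full precision: {full_precision_mib:.2f} MiB")
# roughly 1.28 MiB vs. 39.53 MiB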

Model Training

We compile and train the model just as you would with any other Keras model:

initial_lr = 1e-3
var_decay = 1e-5

# Wrap Adam with Larq's XavierLearningRateScaling, which scales the learning rate
# of each binarized layer according to its Xavier initialization scale
optimizer = tf.keras.optimizers.Adam(lr=initial_lr, decay=var_decay)
model.compile(
    optimizer=lq.optimizers.XavierLearningRateScaling(optimizer, model),
    loss="categorical_crossentropy",
    metrics=["accuracy"],
)

def lr_schedule(epoch):
    # Stepped schedule: divide the learning rate by 10 every 100 epochs
    return initial_lr * 0.1 ** (epoch // 100)

trained_model = model.fit(
    train_dataset,
    epochs=500,
    steps_per_epoch=train_data[1].shape[0] // batch_size,
    validation_data=test_dataset,
    validation_steps=test_data[1].shape[0] // batch_size,
    verbose=1,
    callbacks=[tf.keras.callbacks.LearningRateScheduler(lr_schedule)]
)
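After training you can evaluate the model on the test set and store the weights using standard Keras calls (a short sketch; the weights file name below is just an example):

# Evaluate on the repeating test dataset for exactly one pass over the test labels
test_loss, test_acc = model.evaluate(
    test_dataset, steps=test_data[1].shape[0] // batch_size
)
print(f"Test accuracy: {test_acc:.4f}")

# Optionally save the trained weights for later use
model.save_weights("binarynet_cifar10.h5")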