Source code for deepobs.mnist.mnist_input

# -*- coding: utf-8 -*-
"""
This module contains TensorFlow data loading functionality for MNIST.
"""

import tensorflow as tf
import numpy as np
import gzip
import os

from .. import dataset_utils


[docs]class data_loading: """Class providing the data loading functionality for the MNIST data set. Args: batch_size (int): Batch size of the input-output pairs. No default value is given. Attributes: batch_size (int): Batch size of the input-output pairs. train_eval_size (int): Number of data points to evaluate during the `train eval` phase. Currently set to ``10000`` the size of the test set. D_train (tf.data.Dataset): The training data set. D_train_eval (tf.data.Dataset): The training evaluation data set. It is the same data as `D_train` but we go through it separately. D_test (tf.data.Dataset): The test data set. phase (tf.Variable): Variable to describe which phase we are currently in. Can be "train", "train_eval" or "test". The phase variable can determine the behaviour of the network, for example deactivate dropout during evaluation. iterator (tf.data.Iterator): A single iterator for all three data sets. We us the initialization operators (see below) to switch this iterator to the data sets. X (tf.Tensor): Tensor holding the MNIST images. It has dimension `batch_size` x ``28`` (image size) x ``28`` (image size) x ``1`` (rgb). y (tf.Tensor): Label of the MNIST images. It has dimension `batch_size` x ``10`` (number of classes). train_init_op (tf.Operation): A TensorFlow operation to be performed before starting every training epoch. It sets the `phase` variable to "train" and initializes the iterator to the training data set. train_eval_init_op (tf.Operation): A TensorFlow operation to be performed before starting every training eval phase. It sets the `phase` variable to "train_eval" and initializes the iterator to the training eval data set. test_init_op (tf.Operation): A TensorFlow operation to be performed before starting every test evaluation phase. It sets the `phase` variable to "test" and initializes the iterator to the test data set. """ def __init__(self, batch_size): """Initializes the data loading class. Args: batch_size (int): Batch size of the input-output pairs. No default value is given. """ self.train_eval_size = 10000 # The size of the test set self.batch_size = batch_size self.D_train = self.train_dataset(batch_size) self.D_train_eval = self.train_eval_dataset(batch_size) self.D_test = self.test_dataset(batch_size) self.phase = tf.Variable("train", name="phase", trainable=False) # Reinitializable iterator given types and shapes of the outputs (needs to be the same for train and test of course) self.iterator = tf.data.Iterator.from_structure( self.D_train.output_types, self.D_train.output_shapes) self.X, self.y = self.iterator.get_next() # Operations to do when switching the phase (initialize iterator and assign phase to phase variable) self.train_init_op = tf.group([self.iterator.make_initializer( self.D_train), tf.assign(self.phase, "train")], name="train_init_op") self.train_eval_init_op = tf.group([self.iterator.make_initializer( self.D_train_eval), tf.assign(self.phase, "train_eval")], name="train_eval_init_op") self.test_init_op = tf.group([self.iterator.make_initializer( self.D_test), tf.assign(self.phase, "test")], name="test_init_op")
[docs] def load(self): """Returns the data (`X` (images) and `y` (labels)) and the phase variable. Returns: tupel: Tupel consisting of the images (`X`), the label (`y`) and the phase variable (`phase`). """ return self.X, self.y, self.phase
[docs] def train_dataset(self, batch_size): """Creates the training data set. Args: batch_size (int): Batch size of the input-output pairs. Returns: tf.data.Dataset: The training data set. """ data_dir = dataset_utils.get_data_dir() train_images_file = os.path.join( data_dir, "mnist", "train-images-idx3-ubyte.gz") train_labels_file = os.path.join( data_dir, "mnist", "train-labels-idx1-ubyte.gz") return self.make_dataset(train_images_file, train_labels_file, batch_size, one_hot=True, shuffle=True, shuffle_buffer_size=10000, num_prefetched_batches=10)
[docs] def train_eval_dataset(self, batch_size): """Creates the train eval data set. Args: batch_size (int): Batch size of the input-output pairs. Returns: tf.data.Dataset: The train eval data set. """ data_dir = dataset_utils.get_data_dir() train_images_file = os.path.join( data_dir, "mnist", "train-images-idx3-ubyte.gz") train_labels_file = os.path.join( data_dir, "mnist", "train-labels-idx1-ubyte.gz") return self.make_dataset(train_images_file, train_labels_file, batch_size, one_hot=True, shuffle=True, shuffle_buffer_size=60000, num_prefetched_batches=5, data_set_size=self.train_eval_size)
[docs] def test_dataset(self, batch_size): """Creates the test data set. Args: batch_size (int): Batch size of the input-output pairs. Returns: tf.data.Dataset: The test data set. """ data_dir = dataset_utils.get_data_dir() test_images_file = os.path.join( data_dir, "mnist", "t10k-images-idx3-ubyte.gz") test_labels_file = os.path.join( data_dir, "mnist", "t10k-labels-idx1-ubyte.gz") return self.make_dataset(test_images_file, test_labels_file, batch_size, one_hot=True, shuffle=False, shuffle_buffer_size=-1, num_prefetched_batches=5)
[docs] def make_dataset(self, images_file, labels_file, batch_size, one_hot=True, shuffle=True, shuffle_buffer_size=10000, num_prefetched_batches=10, data_set_size=-1): """Creates a data set from given images and label files. Args: images_file (str): Path to the images in compressed ``.gz`` files. labels_file (str): Path to the labels in compressed ``.gz`` files. batch_size (int): Batch size of the input-output pairs. one_hot (bool): Switch to turn on or off one-hot encoding of the labels. Defaults to ``True``. shuffle (bool): Switch to turn on or off shuffling of the data set. Defaults to ``True``. shuffle_buffer_size (int): Size of the shuffle buffer. Defaults to ``10000`` the size of the `test` and `train eval` data set, meaning that they will be completely shuffled. num_prefetched_batches (int): Number of prefeteched batches, defaults to ``10``. data_set_size (int): Size of the data set to extract from the images and label files. Defaults to ``-1`` meaning that the full data set is used. Returns: tf.data.Dataset: Data set object created from the images and label files. """ X, y = self.read_mnist_data(images_file, labels_file, one_hot) with tf.name_scope("mnist"): with tf.device('/cpu:0'): D = tf.data.Dataset.from_tensor_slices((X, y)) if shuffle: D = D.shuffle(buffer_size=shuffle_buffer_size) D = D.take(data_set_size) D = D.batch(batch_size) D = D.prefetch(buffer_size=num_prefetched_batches) return D
[docs] def read32(self, bytestream): """Helper function to read a bytestream. Args: bytestream (bytestream): Input bytestream. Returns: np.array: Bytestream as a np array. """ dt = np.dtype(np.uint32).newbyteorder('>') return np.frombuffer(bytestream.read(4), dtype=dt)[0]
[docs] def read_mnist_data(self, images_file, labels_file, one_hot=True): """Read the MNIST images and labels from the downloaded files. Args: images_file (str): Path to the images in compressed ``.gz`` files. labels_file (str): Path to the labels in compressed ``.gz`` files. one_hot (bool): Switch to turn on or off one-hot encoding of the labels. Defaults to ``True``. Returns: tupel: Tupel consisting of all the images (`X`) and the labels (`y`). """ # Load images from images_file with tf.gfile.Open(images_file, 'rb') as f: print('Extracting %s' % f.name) with gzip.GzipFile(fileobj=f) as bytestream: magic = self.read32(bytestream) if magic != 2051: raise ValueError('Invalid magic number %d in MNIST image file: %s' % (magic, f.name)) num_images = self.read32(bytestream) rows = self.read32(bytestream) cols = self.read32(bytestream) buf = bytestream.read(rows * cols * num_images) data = np.frombuffer(buf, dtype=np.uint8) X = data.reshape(num_images, rows, cols, 1) X = X.astype(np.float32) / 255.0 # Load labels from labels file with tf.gfile.Open(labels_file, 'rb') as f: print('Extracting %s' % f.name) with gzip.GzipFile(fileobj=f) as bytestream: magic = self.read32(bytestream) if magic != 2049: raise ValueError('Invalid magic number %d in MNIST label file: %s' % (magic, f.name)) num_items = self.read32(bytestream) buf = bytestream.read(num_items) y = np.frombuffer(buf, dtype=np.uint8) if one_hot: y = dataset_utils.dense_to_one_hot(y, 10) y = y.astype(np.int32) return X, y