Basic Autoencoder

Anomaly Detection Using Reconstruction Error

Why use an autoencoder? In practice, autoencoders are often applied to data denoising and dimensionality reduction. They work well for representation learning but less well for general-purpose data compression.

In deep learning, an autoencoder is a neural network that “attempts” to reconstruct its input. It can serve as a form of feature extraction, and autoencoders can be stacked to create “deep” networks. Features generated by an autoencoder can be fed into other algorithms for classification, clustering, and anomaly detection.

Autoencoders are also useful for data visualization when the raw input data has high dimensionality and cannot easily be plotted. By reducing the dimensionality, the encoded representation can sometimes fit into a 2D or 3D space for easier data exploration.

How do autoencoders work?

Autoencoders consist of:

  1. Encoding function (the “encoder”)

  2. Decoding function (the “decoder”)

  3. Distance function (a “loss function”)

An input is fed into the encoder and turned into a compressed representation. The decoder then learns to reconstruct the original input from that compressed representation. During unsupervised training, the loss function measures the distance between the input and its reconstruction, and the network's weights are adjusted to reduce that error. Because the input serves as its own training target, no labels or other human intervention are required (hence “auto”-encoder, “auto” meaning “self”).
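The following deliberately tiny sketch makes the three components concrete. It is written in plain Scala with no Deeplearning4j. It is a linear autoencoder that squeezes 2D points into a single number and back, trained with hand-written gradient descent on mean squared error. The data, initial weights, learning rate, and helper names (`encode`, `decode`, `mse`) are all illustrative assumptions. The real network in this tutorial uses DL4J layers and the AdaGrad updater instead.

```scala
// Toy linear autoencoder 2 -> 1 -> 2, trained with plain gradient descent.
// Hypothetical illustration only; not how Deeplearning4j implements training.
val data: Seq[Array[Double]] = Seq(
  Array(1.0, 1.0), Array(-1.0, -1.0), Array(2.0, 2.0), Array(-2.0, -2.0), Array(0.5, 0.5))

// Parameters: encoder weights a (2 -> 1) and decoder weights b (1 -> 2), no biases
val a = Array(0.1, 0.2)
val b = Array(0.3, 0.1)

// 1. Encoding function: compress a 2D input to a single number
def encode(x: Array[Double]): Double = a(0) * x(0) + a(1) * x(1)

// 2. Decoding function: expand the code back to 2D
def decode(z: Double): Array[Double] = Array(b(0) * z, b(1) * z)

// 3. Distance function: mean squared error between input and reconstruction
def mse(x: Array[Double], xHat: Array[Double]): Double =
  x.indices.map(k => math.pow(xHat(k) - x(k), 2)).sum / x.length

// Average reconstruction loss over the whole toy dataset
def loss(): Double = data.map(x => mse(x, decode(encode(x)))).sum / data.size

val initialLoss = loss()
val learningRate = 0.05

for (_ <- 1 to 1000) {
  val gradA = Array(0.0, 0.0)
  val gradB = Array(0.0, 0.0)
  for (x <- data) {
    val z = encode(x)
    val err = decode(z).zip(x).map { case (xh, xi) => xh - xi }
    // d(loss)/d(b_k) = err_k * z ; d(loss)/dz = sum_k err_k * b_k
    val gradZ = err(0) * b(0) + err(1) * b(1)
    for (k <- 0 to 1) {
      gradB(k) += err(k) * z / data.size
      gradA(k) += gradZ * x(k) / data.size
    }
  }
  // The input is its own target, so no labels appear anywhere in this loop
  for (k <- 0 to 1) { a(k) -= learningRate * gradA(k); b(k) -= learningRate * gradB(k) }
}

val finalLoss = loss()
println(f"loss before: $initialLoss%.4f, after: $finalLoss%.2e")
```

Every training point lies on the line y = x, so one number is enough to describe each point and the loss can approach zero. A point off that line could not be reconstructed exactly, which is the property anomaly detection relies on.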

What does this tutorial teach?

Now that you know how to create different network configurations with MultiLayerNetwork and ComputationGraph, we will construct a “stacked” autoencoder that performs anomaly detection on MNIST digits without pretraining. The goal is to identify outlier digits; i.e. digits that are unusual and atypical. Identification of items, events or observations that “stand out” from the norm of a given dataset is broadly known as anomaly detection. Anomaly detection does not require a labeled dataset, and can be undertaken with unsupervised learning, which is helpful because most of the world’s data is not labeled.

This type of anomaly detection uses reconstruction error to measure how well the decoder is performing. Stereotypical examples should have low reconstruction error, whereas outliers should have high reconstruction error.
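The scoring idea can be sketched without any neural network. Assume the trained model maps every input to its closest point on the line y = x; the `reconstruct` function below is a hypothetical stand-in for the encoder plus decoder. Each input is scored by its mean squared reconstruction error, and scores above a threshold are flagged. The threshold of 0.1 is an arbitrary assumption for this toy data; the tutorial itself ranks examples by score rather than using a fixed threshold.

```scala
// Hypothetical "trained" reconstruction: project a 2D point onto the line y = x.
// Points near the line are "typical"; points far from it reconstruct poorly.
def reconstruct(x: Array[Double]): Array[Double] = {
  val m = (x(0) + x(1)) / 2.0
  Array(m, m)
}

// Reconstruction error (mean squared error) used as the anomaly score
def reconstructionError(x: Array[Double]): Double = {
  val r = reconstruct(x)
  x.indices.map(k => math.pow(r(k) - x(k), 2)).sum / x.length
}

val samples = Seq(
  "typical-1" -> Array(1.0, 1.1),
  "typical-2" -> Array(-2.0, -1.9),
  "outlier"   -> Array(1.0, -1.0))

// Threshold is an assumption; in practice it would be chosen from the score distribution
val threshold = 0.1
val scored = samples.map { case (name, x) => (name, reconstructionError(x)) }
val anomalies = scored.filter(_._2 > threshold).map(_._1)

scored.foreach { case (name, s) => println(f"$name%-10s error = $s%.4f") }
println(s"flagged as anomalies: ${anomalies.mkString(", ")}")
```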

What is anomaly detection good for?

Network intrusion, fraud detection, systems monitoring, sensor network event detection (IoT), and unusual trajectory sensing are examples of anomaly detection applications.

Imports

import org.apache.commons.lang3.tuple.ImmutablePair
import org.apache.commons.lang3.tuple.Pair
import org.nd4j.linalg.activations.Activation
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator
import org.deeplearning4j.nn.api.OptimizationAlgorithm
import org.deeplearning4j.nn.conf.NeuralNetConfiguration
import org.nd4j.linalg.learning.config.AdaGrad
import org.deeplearning4j.nn.conf.layers.DenseLayer
import org.deeplearning4j.nn.conf.layers.OutputLayer
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork
import org.deeplearning4j.nn.weights.WeightInit
import org.deeplearning4j.optimize.api.IterationListener
import org.deeplearning4j.optimize.listeners.ScoreIterationListener
import org.nd4j.linalg.api.ndarray.INDArray
import org.nd4j.linalg.dataset.DataSet
import org.nd4j.linalg.factory.Nd4j
import org.nd4j.linalg.lossfunctions.LossFunctions
import javax.swing._
import java.awt._
import java.awt.image.BufferedImage
import java.util._
import java.util

import scala.collection.JavaConversions._

The stacked autoencoder

The following autoencoder uses two stacked dense layers for encoding. The MNIST digits are transformed into a flat 1D array of length 784 (MNIST images are 28x28 pixels, which equals 784 when you lay them end to end).

784 → 250 → 10 → 250 → 784
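As a quick check on this architecture, the parameter count of each dense layer can be worked out by hand. The sketch assumes every layer has a bias vector, so a layer has nIn × nOut weights plus nOut biases. Under that assumption, the total should match what `net.numParams()` reports for the network built below.

```scala
// Layer sizes of the stacked autoencoder: 784 -> 250 -> 10 -> 250 -> 784
val layerSizes = Seq(784, 250, 10, 250, 784)

// Assumption: each dense layer has nIn * nOut weights plus nOut biases
val paramsPerLayer = layerSizes.zip(layerSizes.tail).map { case (nIn, nOut) => nIn * nOut + nOut }
val totalParams = paramsPerLayer.sum

paramsPerLayer.zipWithIndex.foreach { case (p, i) => println(s"layer $i: $p parameters") }
println(s"total: $totalParams parameters")
```

Almost all of the parameters sit in the first and last layers. The 10-unit bottleneck is what forces the network to compress each 784-pixel image.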

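val conf = new NeuralNetConfiguration.Builder()
    .seed(12345)                                   // fixed seed for reproducible weight init
    .weightInit(WeightInit.XAVIER)
    .updater(new AdaGrad(0.05))                    // AdaGrad with learning rate 0.05
    .activation(Activation.RELU)                   // default activation for every layer, including the output layer
    .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
    .l2(0.0001)                                    // small L2 weight penalty
    .list()
    // Encoder: 784 -> 250 -> 10
    .layer(0, new DenseLayer.Builder().nIn(784).nOut(250)
            .build())
    .layer(1, new DenseLayer.Builder().nIn(250).nOut(10)
            .build())
    // Decoder: 10 -> 250 -> 784
    .layer(2, new DenseLayer.Builder().nIn(10).nOut(250)
            .build())
    .layer(3, new OutputLayer.Builder().nIn(250).nOut(784)
            .lossFunction(LossFunctions.LossFunction.MSE)  // reconstruction error as mean squared error
            .build())
    .build()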

val net = new MultiLayerNetwork(conf)
net.setListeners(new ScoreIterationListener(1))

Using the MNIST iterator

The MNIST iterator, like most of Deeplearning4j’s built-in iterators, implements the DataSetIterator interface. This API allows for simple instantiation of datasets and automatic downloading of the data in the background.

//Load data and split into training and testing sets. 40000 train, 10000 test
val iter = new MnistDataSetIterator(100,50000,false)

val featuresTrain = new util.ArrayList[INDArray]
val featuresTest = new util.ArrayList[INDArray]
val labelsTest = new util.ArrayList[INDArray]

val rand = new util.Random(12345)

while(iter.hasNext()){
    val next = iter.next()
    val split = next.splitTestAndTrain(80, rand)  //80/20 split (from miniBatch = 100)
    featuresTrain.add(split.getTrain().getFeatures())
    val dsTest = split.getTest()
    featuresTest.add(dsTest.getFeatures())
    val indexes = Nd4j.argMax(dsTest.getLabels(),1) //Convert from one-hot representation -> index
    labelsTest.add(indexes)
}
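The comment in the loading code claims 40,000 training and 10,000 test examples. The sketch below, in plain Scala, shows the arithmetic behind those figures and what the `Nd4j.argMax(..., 1)` call does to each one-hot label row. The `argMax` helper here is a hypothetical stand-in written for illustration, not the ND4J function.

```scala
// Split arithmetic: 50,000 examples in minibatches of 100, each split 80 train / 20 test
val numExamples = 50000
val batchSize = 100
val trainPerBatch = 80
val numBatches = numExamples / batchSize                 // 500 minibatches
val numTrain = numBatches * trainPerBatch                // 40,000 training examples
val numTest = numBatches * (batchSize - trainPerBatch)   // 10,000 test examples

// One-hot -> class index for a single row, the same idea as Nd4j.argMax(labels, 1) per row
def argMax(row: Array[Double]): Int = row.indices.maxBy(i => row(i))

val oneHotSeven = Array.tabulate(10)(i => if (i == 7) 1.0 else 0.0)
println(s"$numBatches batches, $numTrain train, $numTest test; label index = ${argMax(oneHotSeven)}")
```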

Unsupervised training

Now that the network configuration is set up and instantiated, and the MNIST data has been split into training and test sets, training takes just a few lines of code. This is where the fun begins.

Earlier, we attached a ScoreIterationListener to the model by using the setListeners() method. Deeplearning4j logs through SLF4J, and Zeppelin redirects this log output to the console rather than displaying it in the notebook, which helps reduce clutter. Depending on the browser used to run this notebook, you can open the debugger/inspector to view the listener output.

// the "simple" way to do multiple epochs is to wrap fit() in a loop
val nEpochs = 30
(1 to nEpochs).foreach { epoch =>
    // each element of featuresTrain is one minibatch; the input is also the target
    featuresTrain.forEach(data => net.fit(data, data))
    println(s"Epoch $epoch complete")
}
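With ScoreIterationListener(1), a score is logged at every iteration, which is why keeping that output out of the notebook matters. The rough count below assumes that each `net.fit(data, data)` call on one minibatch performs a single parameter update.

```scala
// Assumption: one net.fit(features, features) call = one parameter update ("iteration"),
// and ScoreIterationListener(1) logs a score for every iteration.
val trainMinibatches = 50000 / 100   // one 80-example training minibatch per iterator batch
val trainPerMinibatch = 80
val epochs = 30
val totalIterations = trainMinibatches * epochs
val examplesSeen = totalIterations * trainPerMinibatch
println(s"$totalIterations iterations (score log lines), $examplesSeen training examples processed")
```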

Evaluating the model

Now that the autoencoder has been trained, we’ll evaluate the model on the test data. Each example will be scored individually, and a map will be composed that relates each digit to a list of (score, example) pairs.

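Finally, we will select the N best and N worst examples per digit (N = 5). These are the examples with the lowest reconstruction error (the most typical) and the highest (the most anomalous).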

//Evaluate the model on the test data
//Score each example in the test set separately
//Compose a map that relates each digit to a list of (score, example) pairs
//Then find N best and N worst scores per digit
val listsByDigit = new util.HashMap[Integer, ArrayList[Pair[Double, INDArray]]]

(0 to 9).foreach{ i => listsByDigit.put(i, new util.ArrayList[Pair[Double, INDArray]]) }

(0 until featuresTest.size).foreach{ i =>
    val testData = featuresTest.get(i)
    val labels = labelsTest.get(i)

    (0 until testData.rows).foreach{ j =>
        val example = testData.getRow(j, true)
        val digit = labels.getDouble(j).toInt
        val score = net.score(new DataSet(example, example))
        // Add (score, example) pair to the appropriate list
        val digitAllPairs = listsByDigit.get(digit)
        digitAllPairs.add(new ImmutablePair[Double, INDArray](score, example))
    }
}

//Sort each list in the map by score
val c = new Comparator[Pair[Double, INDArray]]() {
  override def compare(o1: Pair[Double, INDArray],
                       o2: Pair[Double, INDArray]): Int =
    java.lang.Double.compare(o1.getLeft, o2.getLeft)
}

listsByDigit.values().forEach(digitAllPairs => Collections.sort(digitAllPairs, c))

//After sorting, select N best and N worst scores (by reconstruction error) for each digit, where N=5
val best = new util.ArrayList[INDArray](50)
val worst = new util.ArrayList[INDArray](50)

(0 to 9).foreach{ i => 
    val list = listsByDigit.get(i)
    
    (0 to 4).foreach{ j=>
        best.add(list.get(j).getRight)
        worst.add(list.get(list.size - j - 1).getRight)
    }
}
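The Java-collections code above (HashMap, Comparator, Collections.sort) can also be expressed with Scala's own collections. The sketch below runs on made-up (digit, score) pairs rather than real network scores. It also keeps only the scores, not the INDArray examples, so it illustrates the selection logic alone. Like the tutorial's loop, it assumes each digit has enough scored examples; with fewer than 2N, the best and worst lists overlap.

```scala
// Hypothetical (digit, score) pairs standing in for per-example reconstruction errors
val scoredExamples: Seq[(Int, Double)] = Seq(
  (0, 0.021), (0, 0.008), (0, 0.090), (1, 0.004), (1, 0.050), (1, 0.012), (1, 0.003))

val n = 2 // the tutorial uses N = 5; smaller here because the toy data is tiny

// Group by digit, sort each group by ascending score, keep the N lowest and N highest
val bestAndWorst: Map[Int, (Seq[Double], Seq[Double])] =
  scoredExamples.groupBy(_._1).map { case (digit, pairs) =>
    val sorted = pairs.map(_._2).sorted
    (digit, (sorted.take(n), sorted.takeRight(n).reverse))
  }

bestAndWorst.toSeq.sortBy(_._1).foreach { case (d, (best, worst)) =>
  println(s"digit $d: best $best, worst $worst")
}
```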
