数据集迭代器允许将数据轻松加载到神经网络中,并帮助组织批处理、转换和掩码。包含在Eclipse DL4J中的迭代器有助于用户提供的数据,或者自动加载公共的基准数据集如MNIST和IRIS。
MultiLayerNetwork model = new MultiLayerNetwork(conf);
model.init();
// 传入一个 MNIST 数据集迭代器,自动获取数据
DataSetIterator mnistTrain = new MnistDataSetIterator(batchSize, true, rngSeed);
net.fit(mnistTrain);
// 直接传递给神经网络
DataSetIterator mnistTest = new MnistDataSetIterator(batchSize, false, rngSeed);
net.eval(mnistTest);
//使用一个评估类
Evaluation eval = new Evaluation(10); //创建一个带有10个可能分类的评估对象
while(mnistTest.hasNext()){
DataSet next = mnistTest.next();
INDArray output = model.output(next.getFeatureMatrix()); //得到网络预测
eval.eval(next.getLabels(), output); //检查对真实分类的预测
}