Android概述
在Android应用中使用深度学习和神经网络
在Android应用中使用深度学习和神经网络
内容
一般来说,训练一个神经网络是一项最适合有多个GPU的强大计算机的任务。但如果你想在你简陋的安卓手机或平板电脑上做呢?好吧,这绝对是可能的。然而,考虑到一个普通的Android设备的规格,它很可能会相当慢。如果这对你来说不是问题,继续读下去。
在本教程中,我将向您展示如何使用 Deeplearning4J,一个流行的基于Java的深度学习库,在Android设备上创建和训练神经网络。
为了获得最佳效果,您需要以下各项:
运行API级别21或更高的安卓设备或模拟器,内部存储空间大约为200 MB。我强烈建议您首先使用模拟器,因为您可以快速调整它,以防内存或存储空间不足。
Android Studio 2.2 或更新版本
这里可以找到在Android应用程序中使用DL4J的更深入的研究。本指南涵盖依赖项、内存管理、保存设备训练模型以及在应用程序中加载预先训练的模型。
要在项目中使用Deeplearning4J,请将以下实现依赖项添加到应用程序模块的build.gradle文件中:
如果选择将依赖项的快照版本与gradle一起使用,则需要在根目录中创建pom.xml文件,并从终端对其运行 mvn -U compile 。您还需要在build.gradle文件的repository{}块中包含mavenLocal()。下面提供了一个pom.xml文件示例。
Android Studio 3.0引入了新的Gradle,现在也应该定义annotationProcessors如果您正在使用它,请向Gradle依赖项添加以下代码:
如您所见,DL4J依赖于ND4J,Java的N维缩写,它是一个提供快速N维数组的库。ND4J在内部依赖于一个名为OpenBLAS的库,该库包含特定于平台的本地代码。因此,您必须加载与您的Android设备架构相匹配的OpenBLAS和ND4J版本。
DL4J和ND4J的依赖项有几个同名的文件。为了避免构建错误,请将以下排除参数添加到packagingOptions中。
编译后的代码将有超过65536个方法。要处理此情况,请在defaultConfig中添加以下选项:
现在,按Sync now更新项目。最后,确保APK不同时包含lib/armeabi和lib/armeabi-v7a子目录。如果是,请将所有文件移动到其中一个或另一个,因为某些Android设备将同时存在这两个文件。
训练神经网络是CPU密集型的,这就是为什么您不想在应用程序的UI线程中进行训练。我不太确定DL4J是否在默认情况下异步训练其网络。为了安全起见,我现在将使用AsyncTask类生成一个单独的线程。
因为createAndUseNetwork()方法还不存在,所以创建它。
DL4J有一个非常直观的API。现在让我们用它来创建一个简单的多层感知器与隐藏层。它将获取两个输入值,并输出一个输出值。要创建层,我们将使用DenseLayer和OutputLayer类。相应地,将以下代码添加到在上一步中创建的createAndUseNetwork()方法中:
现在我们的层已经准备好了,让我们创建一个NeuralNetConfiguration.Builder对象来配置我们的神经网络。
我们现在必须创建一个NeuralNetConfiguration.ListBuilder对象来实际连接我们的层并指定它们的顺序。
另外,通过添加以下代码启用反向传播:
此时,我们可以将神经网络生成并初始化为多层网络类的实例。
为了创建我们的训练数据,我们将使用ND4J提供的INDArray类
正如你可能已经猜到的,我们的神经网络将表现得像一个异或门。训练数据有四个样本,您必须在代码中提到它。
现在,为输入和预期输出创建两个INDArray对象,并用零初始化它们。
注意,输入数组中的列数等于输入层中的神经元数。类似地,输出数组中的列数等于输出层中的神经元数。
用训练数据填充这些数组很容易。只需使用putScalar()方法:
我们不会直接使用INDArray对象。相反,我们将把它们转换成一个DataSet。
此时,我们可以通过调用神经网络的fit()方法并将数据集传递给它来开始训练。for循环控制通过网络的数据集的迭代。在本例中,它被设置为1000次迭代。
就这些。你的神经网络已经可以使用了。
在本教程中,您看到了在Android Studio项目中使用Deeplearning4J库创建和训练神经网络是多么容易。不过,我想提醒你的是,在低功耗、电池供电的设备上训练神经网络可能并不总是一个好主意。
第二个例子DL4J Android应用程序包括一个用户界面可以在这里找到。这个例子使用Anderson的iris数据集在设备上训练一个神经网络,用于iris类型分类。该应用程序包括用户输入的测量值,并返回这些测量值属于三种iris类型(Iris serosa, Iris versicolor, 和 Iris virginica)之一的概率。
移动设备处理能力和电池寿命的限制使得训练健壮、多层网络不可行。作为在设备上训练网络的替代方法,应用程序使用的神经网络可以在桌面机上训练,通过ModelSerializer保存,然后作为预先训练的模型加载到应用程序中。第三个例子DL4J Android应用程序可以在这里找到,它加载一个预先训练的MNIST网络,并使用它对用户绘制的数字进行分类。
Last updated
Was this helpful?