Deeplearning4j
Community ForumND4J JavadocDL4J Javadoc
ZH 1.0.0-beta6
ZH 1.0.0-beta6
  • 核心概念
  • 开始
    • 快速入门
    • 速查表
    • 示例教程
    • 初学者
    • Eclipse贡献者
    • 从源码构建
    • 贡献
    • 基准测试准则
    • 关于
    • 发行说明
  • 配置
    • GPU/CPU设置
    • CPU 与 AVX
    • 内存管理
    • Maven
    • SBT/Gradle和其它构建工具
    • cuDNN
    • 快照
    • 内存工作间
  • ND4J
    • 快速入门
    • 概述
  • SAMEDIFF
    • 变量
    • 操作
    • 添加操作
  • 调优与训练
    • 故障排查
    • 可视化
    • 评估
    • 迁移学习
    • 早停
    • T-SNE数据可视化
  • 分布式深度学习
    • 介绍与入门
    • 在Spark上使用DL4J:操作指南
    • 技术说明
    • Spark数据管道指南
    • API参考
    • 参数服务器
  • Keras导入
    • 概述
    • 入门
    • 支持功能
      • 正则化器
      • 损失
      • 初始化器
      • 约束
      • 激活
      • 优化器
    • Functional模型
    • Sequential模型
  • ARBITER
    • 概述
    • 层空间
    • 参数空间
  • DATAVEC
    • 概述
    • 记录
    • 概要
    • 序列化
    • 转换
    • 分析
    • 读取器
    • 执行器
    • 过滤器
    • 运算
  • 语言处理
    • 概述
    • Word2Vec
    • Doc2Vec
    • SentenceIterator
    • Tokenization
    • Vocabulary Cache
  • 模型
    • 计算图
    • 多层网络
    • 循环神经网络
    • 层
    • 顶点
    • 迭代器
    • 监听器
    • 自定义层
    • 模型持久化
    • 动物园用法
    • 激活
    • 更新器
  • 移动端
    • Android概述
    • Android先决条件
    • Android分类器
    • Android图片分类器
  • FAQ
  • 新闻
  • 支持
  • 为什么要深度学习?
Powered by GitBook
On this page
  • DL4J: Keras模型导入
  • 入门:在60秒内导入一个Keras模型
  • 项目设置
  • 后端
  • 流行的模型与应用
  • 故障排除与支持
  • 为什么需要Keras模型导入?

Was this helpful?

Edit on Git
Export as PDF
  1. Keras导入

概述

模型导入概述

Previous参数服务器Next入门

Last updated 5 years ago

Was this helpful?

DL4J: Keras模型导入

为导入最初用配置和训练的神经网络模型提供了例程,Keras是一个流行的Python深度学习库。

一旦你的模型导入到DL4J,我们的整个生产栈是由你来处理的。我们支持导入所有的Keras模型类型、大多数层和几乎所有的实用功能。请在查看支持的Keras特性的完整列表。

入门:在60秒内导入一个Keras模型

要导入Keras模型,首先需要创建和这样的模型。这里有一个你可以使用的简单例子。该模型是一个简单的MLP,它采用长度为100的小批量向量,具有两个密集层,并预测总共10个类别。在定义模型之后,我们将其序列化为HDF5格式。

from keras.models import Sequential
from keras.layers import Dense

model = Sequential()
model.add(Dense(units=64, activation='relu', input_dim=100))
model.add(Dense(units=10, activation='softmax'))
model.compile(loss='categorical_crossentropy',optimizer='sgd', metrics=['accuracy'])

model.save('simple_mlp.h5')

如果将这个模型文件(simple_mlp.h5)放到项目的资源文件夹的根目录中,则可以将Keras模型加载为DL4J 的 MultiLayerNetwork,如下所示

String simpleMlp = new ClassPathResource("simple_mlp.h5").getFile().getPath();
MultiLayerNetwork model = KerasModelImport.importKerasSequentialModelAndWeights(simpleMlp);

现在可以使用导入的模型进行推断(这里使用简单的数据来简化)

INDArray input = Nd4j.create(256, 100);
INDArray output = model.output(input);

以下是你如何在DL4J中为你导入的模型做训练:

model.fit(input, output);

项目设置

要在现有项目中使用Keras模型导入,只需要将下列依赖项添加到pom.xml中。

<dependency>
    <groupId>org.deeplearning4j</groupId>
    <artifactId>deeplearning4j-modelimport</artifactId>
    <version>1.0.0-beta</version> // 此版本应与其他DL4J项目依赖项匹配。
</dependency>

后端

DL4J Keras 模型导入与后端无关。 不管你选择哪一个后端 (TensorFlow, Theano, CNTK), 你的模号可以导入DL4J。

流行的模型与应用

  • Deep convolutional and Wasserstein GANs

  • UNET

  • ResNet50

  • SqueezeNet

  • MobileNet

  • Inception

  • Xception

故障排除与支持

IncompatibleKerasConfigurationException信息说明你正在尝试导入一个当前不被DL4J支持的Keras模型(要么因为模型导入不覆盖它,要么DL4J不实现该层或特征)。

一旦导入了模型,我们就推荐我们自己的“ModelSerializer”类来进一步保存和重新加载模型。

为什么需要Keras模型导入?

Keras是用Python编写的一个流行的、用户友好的深度学习库。Keras的直观API使得Python轻松地定义和运行你的深度学习模型。Keras允许你选择它在哪个底层库上运行,但是为每个这样的后端提供了统一的API。目前,Keras支持Tensorflow、CNTK和Theano后端,但是Skymind也在为Keras开发ND4J后端。

一个公司的生产系统和它的数据科学家的实验设置之间经常有差距。Keras模型导入允许数据科学家用Python编写他们的模型,但是仍然与生产栈无缝集成。

Keras模型导入主要针对在Python中熟悉用Keras编写模型的用户。通过模型导入,你可以通过允许用户将模型导入DL4J生态圈以进行进一步的训练或评估,从而将Python模型带到生产中。

当项目的试验阶段完成并且需要将模型交付生产时,你应该使用这个模块。Skymind商业支持Keras在企业中的实现。

就是这样!KerasModelImport是模型导入的主要入口点,类负责在内部将Keras映射到DL4J概念。作为用户,你只需要提供你的模型文件,请参阅我们的,以了解将Keras模型加载到DL4J中的更多细节和选项。

在我们的中可以找到刚才所示的完整示例。

如果首先需要开始一个项目,请考虑克隆,并按照仓库中的说明来构建项目。

我们支持为越来越多的应用程序导入,在查看当前所覆盖的模型的完整列表。这些应用包括

你可以通过访问频道进一步咨询。你可能会考虑,这样这个缺失的功能可以放在DL4J开发路线图上,或者甚至向我们发送一个带有必要更改的pull请求!

Keras模型导入
Keras
这里
序列化
入门指南
DL4J示例
DL4J示例
这里
DL4J gitter
通过Github来提交一个特性请求