Deeplearning4j
Community ForumND4J JavadocDL4J Javadoc
ZH 1.0.0-beta6
ZH 1.0.0-beta6
  • 核心概念
  • 开始
    • 快速入门
    • 速查表
    • 示例教程
    • 初学者
    • Eclipse贡献者
    • 从源码构建
    • 贡献
    • 基准测试准则
    • 关于
    • 发行说明
  • 配置
    • GPU/CPU设置
    • CPU 与 AVX
    • 内存管理
    • Maven
    • SBT/Gradle和其它构建工具
    • cuDNN
    • 快照
    • 内存工作间
  • ND4J
    • 快速入门
    • 概述
  • SAMEDIFF
    • 变量
    • 操作
    • 添加操作
  • 调优与训练
    • 故障排查
    • 可视化
    • 评估
    • 迁移学习
    • 早停
    • T-SNE数据可视化
  • 分布式深度学习
    • 介绍与入门
    • 在Spark上使用DL4J:操作指南
    • 技术说明
    • Spark数据管道指南
    • API参考
    • 参数服务器
  • Keras导入
    • 概述
    • 入门
    • 支持功能
      • 正则化器
      • 损失
      • 初始化器
      • 约束
      • 激活
      • 优化器
    • Functional模型
    • Sequential模型
  • ARBITER
    • 概述
    • 层空间
    • 参数空间
  • DATAVEC
    • 概述
    • 记录
    • 概要
    • 序列化
    • 转换
    • 分析
    • 读取器
    • 执行器
    • 过滤器
    • 运算
  • 语言处理
    • 概述
    • Word2Vec
    • Doc2Vec
    • SentenceIterator
    • Tokenization
    • Vocabulary Cache
  • 模型
    • 计算图
    • 多层网络
    • 循环神经网络
    • 层
    • 顶点
    • 迭代器
    • 监听器
    • 自定义层
    • 模型持久化
    • 动物园用法
    • 激活
    • 更新器
  • 移动端
    • Android概述
    • Android先决条件
    • Android分类器
    • Android图片分类器
  • FAQ
  • 新闻
  • 支持
  • 为什么要深度学习?
Powered by GitBook
On this page
  • 导入Keras函数模型入门
  • 载加你的Keras模型

Was this helpful?

Edit on Git
Export as PDF
  1. Keras导入

Functional模型

导入functional模型

导入Keras函数模型入门

假设你使用Keras的函数API开始定义一个简单的MLP:

from keras.models import Model
from keras.layers import Dense, Input

inputs = Input(shape=(100,))
x = Dense(64, activation='relu')(inputs)
predictions = Dense(10, activation='softmax')(x)
model = Model(inputs=inputs, outputs=predictions)
model.compile(loss='categorical_crossentropy',optimizer='sgd', metrics=['accuracy'])

在Keras,有几种保存模型的方法。你可以将整个模型(模型定义、权重和训练配置)存储为HDF5文件,仅存储模型配置(作为JSON或YAML文件)或仅存储权重(作为HDF5文件)。以下是你如何做每一件事:

model.save('full_model.h5')  # save everything in HDF5 format

model_json = model.to_json()  # save just the config. replace with "to_yaml" for YAML serialization
with open("model_config.json", "w") as f:
    f.write(model_json)

model.save_weights('model_weights.h5') # save just the weights.

如果你决定保存完整的模型,那么你将能够访问模型的训练配置,否则你将不访问。因此,如果你想在导入之后在DL4J中进一步训练模型,请记住这一点,并使用model.save(...)来持久化你的模型。

载加你的Keras模型

让我们从推荐的方法开始,将完整模型加载回DL4J(我们假设它在类路径上):

String fullModel = new ClassPathResource("full_model.h5").getFile().getPath();
ComputationGraph model = KerasModelImport.importKerasModelAndWeights(fullModel);

万一你没有编译你的Keras模型,它就不会有一个训练配置。在这种情况下,你需要显式地告诉模型导入忽略训练配置,方法是将enforceTrainingConfig标志设置为false,如下所示:

ComputationGraph model = KerasModelImport.importKerasModelAndWeights(fullModel, false);

若要仅从JSON加载模型配置,请按如下使用KerasModelImport

String modelJson = new ClassPathResource("model_config.json").getFile().getPath();
ComputationGraphConfiguration modelConfig = KerasModelImport.importKerasModelConfiguration(modelJson)

如果另外你还想加载模型权重与配置,那么以下是你要做的:

String modelWeights = new ClassPathResource("model_weights.h5").getFile().getPath();
MultiLayerNetwork network = KerasModelImport.importKerasModelAndWeights(modelJson, modelWeights)

在后面两种情况下,将不读取训练配置。

KerasModel

从Keras(函数API)模型或序列模型配置构建计算图。

KerasModel

public KerasModel(KerasModelBuilder modelBuilder)
            throws UnsupportedKerasConfigurationException, IOException, InvalidKerasConfigurationException

(建议)(函数API)模型的构建器模式构造器。

  • 参数 modelBuilder 构建器对象

  • 抛出 IOException IO 异常

  • 抛出 InvalidKerasConfigurationException 无效的 Keras 配置

  • 抛出 UnsupportedKerasConfigurationException 不支持的 Keras 配置

getComputationGraphConfiguration

public ComputationGraphConfiguration getComputationGraphConfiguration()
            throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException

(不推荐)来自模型配置(JSON或YAML)、训练配置(JSON)、权重和“训练模式”布尔指示符的(函数 API)模型的构造器。当内置在训练模式时,某些不支持的配置(例如,未知的正则化器)将抛出异常。当强制TrainingConfig= false时,这些将生成警告,但将被忽略。

  • 参数 modelJson 模型配置JSON 字符串

  • 参数 modelYaml 模型配置 YAML 字符串

  • 参数 enforceTrainingConfig 是否实施训练相关配置

  • 抛出 IOException IO 异常

  • 抛出 InvalidKerasConfigurationException 无效的 Keras 配置

  • 抛出 UnsupportedKerasConfigurationException 不支持的 Keras 配置

getComputationGraph

public ComputationGraph getComputationGraph()
            throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException

从这个Keras模型配置构建计算图并导入权重。

  • 返回 ComputationGraph

getComputationGraph

public ComputationGraph getComputationGraph(boolean importWeights)
            throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException

从这个Keras模型配置构建计算图并(可选的)导入权重。

  • 参数 importWeights 是否导入权重

  • 返回 ComputationGraph

Previous优化器NextSequential模型

Last updated 5 years ago

Was this helpful?

[源码]