导入Keras函数模型入门
假设你使用Keras的函数API开始定义一个简单的MLP:
Copy from keras.models import Model
from keras.layers import Dense, Input
inputs = Input(shape=(100,))
x = Dense(64, activation='relu')(inputs)
predictions = Dense(10, activation='softmax')(x)
model = Model(inputs=inputs, outputs=predictions)
model.compile(loss='categorical_crossentropy',optimizer='sgd', metrics=['accuracy'])
在Keras,有几种保存模型的方法。你可以将整个模型(模型定义、权重和训练配置)存储为HDF5文件,仅存储模型配置(作为JSON或YAML文件)或仅存储权重(作为HDF5文件)。以下是你如何做每一件事:
Copy model.save('full_model.h5') # save everything in HDF5 format
model_json = model.to_json() # save just the config. replace with "to_yaml" for YAML serialization
with open("model_config.json", "w") as f:
f.write(model_json)
model.save_weights('model_weights.h5') # save just the weights.
如果你决定保存完整的模型,那么你将能够访问模型的训练配置,否则你将不访问。因此,如果你想在导入之后在DL4J中进一步训练模型,请记住这一点,并使用model.save(...)来持久化你的模型。
载加你的Keras模型
让我们从推荐的方法开始,将完整模型加载回DL4J(我们假设它在类路径上):
Copy String fullModel = new ClassPathResource("full_model.h5").getFile().getPath();
ComputationGraph model = KerasModelImport.importKerasModelAndWeights(fullModel);
万一你没有编译你的Keras模型,它就不会有一个训练配置。在这种情况下,你需要显式地告诉模型导入忽略训练配置,方法是将enforceTrainingConfig标志设置为false,如下所示:
Copy ComputationGraph model = KerasModelImport.importKerasModelAndWeights(fullModel, false);
若要仅从JSON加载模型配置,请按如下使用KerasModelImport
Copy String modelJson = new ClassPathResource("model_config.json").getFile().getPath();
ComputationGraphConfiguration modelConfig = KerasModelImport.importKerasModelConfiguration(modelJson)
如果另外你还想加载模型权重与配置,那么以下是你要做的:
Copy String modelWeights = new ClassPathResource("model_weights.h5").getFile().getPath();
MultiLayerNetwork network = KerasModelImport.importKerasModelAndWeights(modelJson, modelWeights)
在后面两种情况下,将不读取训练配置。
KerasModel
[源码]
从Keras(函数API)模型或序列模型配置构建计算图。
KerasModel
Copy public KerasModel(KerasModelBuilder modelBuilder)
throws UnsupportedKerasConfigurationException, IOException, InvalidKerasConfigurationException
(建议)(函数API)模型的构建器模式构造器。
抛出 InvalidKerasConfigurationException 无效的 Keras 配置
抛出 UnsupportedKerasConfigurationException 不支持的 Keras 配置
getComputationGraphConfiguration
Copy public ComputationGraphConfiguration getComputationGraphConfiguration()
throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException
(不推荐)来自模型配置(JSON或YAML)、训练配置(JSON)、权重和“训练模式”布尔指示符的(函数 API)模型的构造器。当内置在训练模式时,某些不支持的配置(例如,未知的正则化器)将抛出异常。当强制TrainingConfig= false时,这些将生成警告,但将被忽略。
参数 modelJson 模型配置JSON 字符串
参数 modelYaml 模型配置 YAML 字符串
参数 enforceTrainingConfig 是否实施训练相关配置
抛出 InvalidKerasConfigurationException 无效的 Keras 配置
抛出 UnsupportedKerasConfigurationException 不支持的 Keras 配置
getComputationGraph
Copy public ComputationGraph getComputationGraph()
throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException
从这个Keras模型配置构建计算图并导入权重。
getComputationGraph
Copy public ComputationGraph getComputationGraph(boolean importWeights)
throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException
从这个Keras模型配置构建计算图并(可选的)导入权重。