介绍
ND4J是JVM的科学计算库。它是用来在生产环境中使用的,而不是作为一个研究工具,这意味着常规的设计是为了以最低的RAM需求快速运行。主要特点是:
多平台功能,包括GPU。
所有主要操作系统: win/linux/osx/android.
此快速入门遵循与Numpy快速入门 相同的布局和方法。这将帮助熟悉Python和Numpy的人快速开始使用Nd4J。
先决条件
您可以从任何JVM语言 中使用Nd4J。(例如:Scala、Kotlin)。可以将Nd4J与任何构建工具一起使用。此快速入门中的示例代码使用以下内容:
为了提高可读性,我们向您展示System.out.println(...)
的输出。但是我们还没有在示例代码中显示print语句。如果您有信心知道如何使用maven和git,请随时跳到基础部分。在本节的其余部分中,我们将构建一个小型的“hello ND4J”应用程序,以验证先决条件设置是否正确。
执行以下命令从github获取项目。
Copy git clone https://github.com/RobAltena/HelloNd4J.git
cd HelloNd4J
mvn install
mvn exec:java -Dexec.mainClass="HelloNd4j"
当一切设置正确时,您应该看到以下输出:
基础
Nd4j的主要特点是具有多功能的n维阵列接口INDArray。为了提高性能,Nd4j使用堆外内存来存储数据。INDArray不同于标准Java数组。
INDArray x的一些关键属性和方法如下:
Copy import org . nd4j . linalg . factory . Nd4j ;
import org . nd4j . linalg . api . buffer . DataType ;
INDArray x = Nd4j . zeros ( 3 , 4 );
// 数组的轴数(维度)。
int dimensions = x . rank ();
// 数组的维数。每个维度的大小。
long [] shape = x . shape ();
// 元素的总数。
long length = x . length ();
// 数组元素的类型。
DataType dt = x . dataType ();
数组创建
要创建INDArray,可以使用Nd4j 类的静态工厂方法。
Nd4j.createFromArray
函数被重载,以便于从常规Java数组创建INDArrays。下面的示例使用Javadouble
数组。类似的create方法对于float、int和long重载。对于所有类型,Nd4j.createFromArray
函数都有高达4d的重载。
Copy double arr_2d[][] = {{ 1.0 , 2.0 , 3.0 } , { 4.0 , 5.0 , 6.0 } , { 7.0 , 8.0 , 9.0 }};
INDArray x_2d = Nd4j . createFromArray (arr_2d);
double arr_1d[] = { 1.0 , 2.0 , 3.0 };
INDArray x_1d = Nd4j . createFromArray (arr_1d);
Nd4j可以使用函数zeros
和ones
创建用0和1初始化的数组。rand函数允许您创建用随机值初始化的数组。创建的INDArray的默认数据类型是float。有些重载允许您设置数据类型。
Copy INDArray x = Nd4j . zeros ( 5 );
//[ 0, 0, 0, 0, 0], FLOAT
int [] shape = { 5 };
x = Nd4j . zeros ( DataType . DOUBLE , 5 );
//[ 0, 0, 0, 0, 0], DOUBLE
// 对于更高的维度,可以提供形状数组。二维随机矩阵示例:
int rows = 4 ;
int cols = 5 ;
int [] shape = {rows , cols};
INDArray x = Nd4j . rand (shape);
使用arange
函数创建一个均匀空间值数组:
Copy INDArray x = Nd4j . arange ( 5 );
// [ 0, 1.0000, 2.0000, 3.0000, 4.0000]
INDArray x = Nd4j . arange ( 2 , 7 );
// [ 2.0000, 3.0000, 4.0000, 5.0000, 6.0000]
linspace
函数允许您指定生成的点数:
Copy //开始数, 停止数, 个数.
INDArray x = Nd4j . linspace ( 1 , 10 , 5 );
// [ 1.0000, 3.2500, 5.5000, 7.7500, 10.0000]
// 对函数进行多点评估。
import static org . nd4j . linalg . ops . transforms . Transforms . sin ;
INDArray x = Nd4j . linspace ( 0.0 , Math . PI , 100 , DataType . DOUBLE );
INDArray y = sin(x) ;
打印数组
INDArray支持Java的toString()
方法。当前的实现具有有限的精度和有限的元素数。输出类似于打印NumPy数组:
Copy //1维数组
INDArray x = Nd4j . arange ( 6 );
//我们现在只给出print命令的输出。
System . out . println (x);
// [ 0, 1.0000, 2.0000, 3.0000, 4.0000, 5.0000]
int [] shape = { 4 , 3 };
//2维数组
x = Nd4j . arange ( 12 ) . reshape (shape);
/*
[[ 0, 1.0000, 2.0000],
[ 3.0000, 4.0000, 5.0000],
[ 6.0000, 7.0000, 8.0000],
[ 9.0000, 10.0000, 11.0000]]
*/
int [] shape2 = { 2 , 3 , 4 };
//3维数组
x = Nd4j . arange ( 24 ) . reshape (shape2);
/*
[[[ 0, 1.0000, 2.0000, 3.0000],
[ 4.0000, 5.0000, 6.0000, 7.0000],
[ 8.0000, 9.0000, 10.0000, 11.0000]],
[[ 12.0000, 13.0000, 14.0000, 15.0000],
[ 16.0000, 17.0000, 18.0000, 19.0000],
[ 20.0000, 21.0000, 22.0000, 23.0000]]]
*/
基本操作
必须使用INDArray方法对数组执行操作。有就地(in-place)和复制重载、标量和元素级重载版本。就地(in-place)运算符返回对数组的引用,因此可以方便地将操作链接在一起。尽可能使用就地(in-place)运算符来提高性能。复制运算符有新的数组创建开销。
Copy //复制
//返回一个新数组,并将标量添加到arr的每个元素。
arr_new = arr . add (scalar);
//返回一个新数组,它是arr和其他arr元素级别的加法。
arr_new = arr . add (other_arr);
//就地
arr_new = arr . addi (scalar);
arr_new = arr . addi (other_arr);
加法: arr.add(...), arr.addi(...) 减法: arr.sub(...), arr.subi(...) 乘法: arr.mul(...), arr.muli(...) 除法 : arr.div(...), arr.divi(...)
执行基本操作时,必须确保基础数据类型相同。
Copy int [] shape = { 5 };
INDArray x = Nd4j . zeros (shape , DataType . DOUBLE );
INDArray x2 = Nd4j . zeros (shape , DataType . INT );
INDArray x3 = x . add (x2);
// java.lang.IllegalArgumentException: Op.X 和 Op.Y must have the same data type, but got INT vs DOUBLE
// 将x2转换为DOUBLE可以解决以下问题:
INDArray x3 = x . add ( x2 . castTo ( DataType . DOUBLE ));
INDArray有实现缩减/累加操作的方法,如 sum
, min
, max
.
Copy int [] shape = { 2 , 3 };
INDArray x = Nd4j . rand (shape);
x;
x . sum ();
x . min ();
x . max ();
/*
[[ 0.8621, 0.9224, 0.8407],
[ 0.1504, 0.5489, 0.9584]]
4.2830
0.1504
0.9584
*/
提供维度参数以在指定维度上应用操作:
Copy INDArray x = Nd4j . arange ( 12 ) . reshape ( 3 , 4 );
/*
[[ 0, 1.0000, 2.0000, 3.0000],
[ 4.0000, 5.0000, 6.0000, 7.0000],
[ 8.0000, 9.0000, 10.0000, 11.0000]]
*/
//每列的总和。
x . sum ( 0 );
//[ 12.0000, 15.0000, 18.0000, 21.0000]
//每行最小值
x . min ( 1 );
//[ 0, 4.0000, 8.0000]
//每行的累计和,
x . cumsum ( 1 );
/*
[[ 0, 1.0000, 3.0000, 6.0000],
[ 4.0000, 9.0000, 15.0000, 22.0000],
[ 8.0000, 17.0000, 27.0000, 38.0000]]
*/
转换操作
Nd4j提供熟悉的数学函数,如sin、cos和exp,这些称为转换操作。结果作为INDArray返回。
Copy import static org . nd4j . linalg . ops . transforms . Transforms . exp ;
import static org . nd4j . linalg . ops . transforms . Transforms . sqrt ;
INDArray x = Nd4j . arange ( 3 );
// [ 0, 1.0000, 2.0000]
exp(x) ;
// [ 1.0000, 2.7183, 7.3891]
sqrt(x) ;
// [ 0, 1.0000, 1.4142]
您可以在Javadoc 中查看转换操作的完整列表
矩陈乘法
我们已经在基本操作中看到了元素的乘法运算。其他矩阵运算有自己的方法:
Copy INDArray x = Nd4j . arange ( 12 ) . reshape ( 3 , 4 );
/*
[[ 0, 1.0000, 2.0000, 3.0000],
[ 4.0000, 5.0000, 6.0000, 7.0000],
[ 8.0000, 9.0000, 10.0000, 11.0000]]
*/
INDArray y = Nd4j . arange ( 12 ) . reshape ( 4 , 3 );
/*
[[ 0, 1.0000, 2.0000],
[ 3.0000, 4.0000, 5.0000],
[ 6.0000, 7.0000, 8.0000],
[ 9.0000, 10.0000, 11.0000]]
*/
// 矩阵乘积.
x . mmul (y);
/*
[[ 42.0000, 48.0000, 54.0000],
[ 114.0000, 136.0000, 158.0000],
[ 186.0000, 224.0000, 262.0000]]
*/
//点积
INDArray x = Nd4j . arange ( 12 );
INDArray y = Nd4j . arange ( 12 );
dot(x , y) ;
//506.0000
索引、切片和迭代
索引、切片和迭代在Java中比Python更困难。 要从INDArray检索单个值,可以使用getDouble
、getFloat
或getInt
方法。INDArrays不能像Java数组那样被索引。可以使用toDoubleVector()
、toDoubleMatrix()
、toFloatVector()
和toFloatMatrix()
从INDArray获取Java数组。
Copy INDArray x = Nd4j . arange ( 12 );
// [ 0, 1.0000, 2.0000, 3.0000, 4.0000, 5.0000, 6.0000, 7.0000, 8.0000, 9.0000, 10.0000, 11.0000]
//单一元素访问。其他方法: getDouble, getInt, ...
float f = x . getFloat ( 3 );
// 3.0
//转换为Java数组。
float [] fArr = x . toFloatVector ();
// [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0]
INDArray x2 = x . get ( NDArrayIndex . interval ( 2 , 6 ));
// [ 2.0000, 3.0000, 4.0000, 5.0000]
//在x的副本上:从开始到位置6(不包括),将每2个元素设置为-1.0
INDArray y = x . dup ();
y . get ( NDArrayIndex . interval ( 0 , 2 , 6 )) . assign ( - 1.0 );
//[ -1.0000, 1.0000, -1.0000, 3.0000, -1.0000, 5.0000, 6.0000, 7.0000, 8.0000, 9.0000, 10.0000, 11.0000]
//y的反向副本。
INDArray y2 = Nd4j . reverse ( y . dup ());
//[ 11.0000, 10.0000, 9.0000, 8.0000, 7.0000, 6.0000, 5.0000, -1.0000, 3.0000, -1.0000, 1.0000, -1.0000]
对于多维数组,应该使用INDArray.get(NDArrayIndex...)
。下面的示例演示如何遍历二维数组的行和列。注意,对于2D数组,我们可以使用getColumn
和getRow
便利方法。
Copy // 在2d数组的行和列上迭代。
int rows = 4 ;
int cols = 5 ;
int [] shape = {rows , cols};
INDArray x = Nd4j . rand (shape);
/*
[[ 0.2228, 0.2871, 0.3880, 0.7167, 0.9951],
[ 0.7181, 0.8106, 0.9062, 0.9291, 0.5115],
[ 0.5483, 0.7515, 0.3623, 0.7797, 0.5887],
[ 0.6822, 0.7785, 0.4456, 0.4231, 0.9157]]
*/
for ( int row = 0 ; row < rows; row ++ ) {
INDArray y = x . get ( NDArrayIndex . point (row) , NDArrayIndex . all ());
}
/*
[ 0.2228, 0.2871, 0.3880, 0.7167, 0.9951]
[ 0.7181, 0.8106, 0.9062, 0.9291, 0.5115]
[ 0.5483, 0.7515, 0.3623, 0.7797, 0.5887]
[ 0.6822, 0.7785, 0.4456, 0.4231, 0.9157]
*/
for ( int col = 0 ; col < cols; col ++ ) {
INDArray y = x . get ( NDArrayIndex . all () , NDArrayIndex . point (col));
}
/*
[ 0.2228, 0.7181, 0.5483, 0.6822]
[ 0.2871, 0.8106, 0.7515, 0.7785]
[ 0.3880, 0.9062, 0.3623, 0.4456]
[ 0.7167, 0.9291, 0.7797, 0.4231]
[ 0.9951, 0.5115, 0.5887, 0.9157]
*/
形状操作
改变数组的形状
沿着每个轴的元素的数量是形状。形状可以用各种方法改变。
Copy INDArray x = Nd4j . rand ( 3 , 4 );
x . shape ();
// [3, 4]
INDArray x2 = x . ravel ();
x2 . shape ();
// [12]
INDArray x3 = x . reshape ( 6 , 2 ) . shape ();
x3 . shape ();
//[6, 2]
//注意x、x2和x3共享相同的数据。
x2 . putScalar ( 5 , - 1.0 );
System . out . println ( x);
/*
[[ 0.0270, 0.3799, 0.5576, 0.3086],
[ 0.2266, -1.0000, 0.1107, 0.4895],
[ 0.8431, 0.6011, 0.2996, 0.7500]]
*/
System . out . println ( x2);
// [ 0.0270, 0.3799, 0.5576, 0.3086, 0.2266, -1.0000, 0.1107, 0.4895, 0.8431, 0.6011, 0.2996, 0.7500]
System . out . println ( x3);
/*
[[ 0.0270, 0.3799],
[ 0.5576, 0.3086],
[ 0.2266, -1.0000],
[ 0.1107, 0.4895],
[ 0.8431, 0.6011],
[ 0.2996, 0.7500]]
*/
将不同的数组堆叠在一起
可以使用vstack
和hstack
方法将数组堆叠在一起。
Copy INDArray x = Nd4j . rand ( 2 , 2 );
INDArray y = Nd4j . rand ( 2 , 2 );
x
/*
[[ 0.1462, 0.5037],
[ 0.1418, 0.8645]]
*/
y;
/*
[[ 0.2305, 0.4798],
[ 0.9407, 0.9735]]
*/
Nd4j . vstack (x , y);
/*
[[ 0.1462, 0.5037],
[ 0.1418, 0.8645],
[ 0.2305, 0.4798],
[ 0.9407, 0.9735]]
*/
Nd4j . hstack (x , y);
/*
[[ 0.1462, 0.5037, 0.2305, 0.4798],
[ 0.1418, 0.8645, 0.9407, 0.9735]]
*/
副本和视图
使用INDArrays时,并不总是复制数据。这里有三种情况你应该注意。
完全没有复制
简单的指派不会复制数据。Java通过引用传递对象。在方法调用上不生成任何副本。
Copy INDArray x = Nd4j . rand ( 2 , 2 );
//y和x指向同一个INDArray对象。
INDArray y = x;
public static void f( INDArray x) {
//没有副本。对x的任何更改在函数调用之后都是可见的。
}
视图或浅复制
一些函数将返回数组的视图。
Copy INDArray x = Nd4j . rand ( 3 , 4 );
INDArray x2 = x . ravel ();
INDArray x3 = x . reshape ( 6 , 2 );
// 修改 x, x2 和 x3
x2 . putScalar ( 5 , - 1.0 );
x
/*
[[ 0.8546, 0.1509, 0.0331, 0.1308],
[ 0.1753, -1.0000, 0.2277, 0.1998],
[ 0.2741, 0.8257, 0.6946, 0.6851]]
*/
x2
// [ 0.8546, 0.1509, 0.0331, 0.1308, 0.1753, -1.0000, 0.2277, 0.1998, 0.2741, 0.8257, 0.6946, 0.6851]
x3
/*
[[ 0.8546, 0.1509],
[ 0.0331, 0.1308],
[ 0.1753, -1.0000],
[ 0.2277, 0.1998],
[ 0.2741, 0.8257],
[ 0.6946, 0.6851]]
*/
深考贝
要复制数组,请使用dup
方法。这将为您提供一个包含新数据的新数组。
Copy INDArray x = Nd4j . rand ( 3 , 4 );
INDArray x2 = x . ravel () . dup ();
//现在只改变x2。
x2 . putScalar ( 5 , - 1.0 );
x
/*
[[ 0.1604, 0.0322, 0.8910, 0.4604],
[ 0.7724, 0.1267, 0.1617, 0.7586],
[ 0.6117, 0.5385, 0.1251, 0.6886]]
*/
x2
// [ 0.1604, 0.0322, 0.8910, 0.4604, 0.7724, -1.0000, 0.1617, 0.7586, 0.6117, 0.5385, 0.1251, 0.6886]
功能和方法概述
数组创建
arange , create , copy , empty , empty_like , eye , linspace , meshgrid , ones , ones_like , rand , readTxt , zeros , zeros_like
转换
convertToDoubles , convertToFloats , convertToHalfs
操作
concatenate , hstack , ravel , repeat , reshape , squeeze , swapaxes , tear , transpose , vstack
排序
argmax , max , min , sort
操作
choice , cumsum , mmul , prod , put , putWhere , sum
基本统计
covarianceMatrix , mean , std , var
基本线性代数
cross , dot , gesvd , mmul