// Core DL4J library. The modules listed below are excluded from its
// transitive dependencies.
implementation(group: 'org.deeplearning4j', name: 'deeplearning4j-core', version: '{{page.version}}') {
    // bytedeco "-platform" modules that are not pulled in through deeplearning4j-core.
    ['opencv-platform', 'leptonica-platform', 'hdf5-platform'].each { platformModule ->
        exclude group: 'org.bytedeco', module: platformModule
    }
    exclude group: 'org.nd4j', module: 'nd4j-base64'
}
// Android classifiers for which each native library is additionally declared.
def androidClassifiers = ['android-arm', 'android-arm64', 'android-x86', 'android-x86_64']

// Native libraries. Each one is declared once without a classifier and then
// once per Android classifier, in the order listed here.
def nativeLibraries = [
    [group: 'org.nd4j', name: 'nd4j-native', version: '{{page.version}}'],
    [group: 'org.bytedeco', name: 'openblas', version: '0.3.7-1.5.2'],
    [group: 'org.bytedeco', name: 'opencv', version: '4.1.2-1.5.2'],
    [group: 'org.bytedeco', name: 'leptonica', version: '1.78.0-1.5.2']
]

nativeLibraries.each { coordinates ->
    implementation(coordinates)
    androidClassifiers.each { abi ->
        implementation(coordinates + [classifier: abi])
    }
}

implementation 'com.google.code.gson:gson:2.8.2'
annotationProcessor 'org.projectlombok:lombok:1.16.16'
// This corrects for a junit version conflict.
configurations.all {
    resolutionStrategy.force 'junit:junit:4.12'
}

// NOTE(review): multiDexEnabled is presumably meant to sit inside
// android { defaultConfig { ... } } -- confirm placement in the real build file.
multiDexEnabled true

// Same junit pin as above, repeated as in the original snippet sequence.
configurations.all {
    resolutionStrategy.force 'junit:junit:4.12'
}

// Background task that loads the model and classifies the saved drawing.
private class AsyncTaskRunner extends AsyncTask<String, Integer, INDArray> {
/**
 * Runs in the UI before the background thread is invoked.
 * No setup is performed here; the call is only forwarded to the superclass.
 */
@Override
protected void onPreExecute() {
super.onPreExecute();
}
/**
 * Main background work: loads the pretrained network from the raw resources
 * and runs it on the saved drawing.
 *
 * <p>The image {@code drawn_image.jpg} is read from the directory given by
 * {@code params[0]}; when no parameter is supplied the {@code absolutePath}
 * field is used instead. The image is converted to a 28x28 single-channel
 * matrix and scaled with {@code ImagePreProcessingScaler(0, 1)} before being
 * passed through the network.
 *
 * @param params optional; {@code params[0]} is the directory that contains the drawing
 * @return the value of the {@code output} field. It is updated only when
 *         inference succeeds; on an {@link IOException} the stack trace is
 *         printed and the field's previous value (possibly {@code null}) is returned
 */
@Override
protected INDArray doInBackground(String... params) {
    // Dimensions the input image is converted to.
    final int height = 28;
    final int width = 28;
    final int channels = 1;

    // Prefer the directory handed to execute() so the background thread does
    // not depend on a field that the UI thread may reassign meanwhile.
    final String imageDirectory =
            (params != null && params.length > 0 && params[0] != null) ? params[0] : absolutePath;

    // try-with-resources guarantees the raw-resource stream is closed on every path.
    try (InputStream inputStream = getResources().openRawResource(R.raw.trained_mnist_model)) {
        // Load the pretrained network.
        MultiLayerNetwork model = ModelSerializer.restoreMultiLayerNetwork(inputStream);
        // Image file to classify.
        File f = new File(imageDirectory, "drawn_image.jpg");
        // NativeImageLoader converts the image into a numeric matrix.
        NativeImageLoader loader = new NativeImageLoader(height, width, channels);
        INDArray image = loader.asMatrix(f);
        // Pixel values need to be scaled before inference.
        DataNormalization scalar = new ImagePreProcessingScaler(0, 1);
        scalar.transform(image);
        // Run the network and keep the result in the shared output field.
        output = model.output(image);
    } catch (IOException e) {
        e.printStackTrace();
    }
    return output;
} // Drawing-input code
/**
 * View the user draws on with a finger. Finished strokes are committed to an
 * off-screen bitmap; when the finger is lifted the drawing is saved, shown in
 * the output view and handed to an {@code AsyncTaskRunner}.
 */
public class DrawingView extends View {
    /** Minimum movement on either axis before the current stroke is extended. */
    private static final float TOUCH_TOLERANCE = 4;

    /** Stroke currently being drawn. */
    private Path mPath;
    /** Paint used when copying the off-screen bitmap onto the view canvas. */
    private Paint mBitmapPaint;
    /** Stroke paint: white, width 60, round joins and caps. */
    private Paint mPaint;
    /** Off-screen bitmap holding the committed strokes. */
    private Bitmap mBitmap;
    /** Canvas that draws into {@link #mBitmap}. */
    private Canvas mCanvas;
    /** Last touch position that was added to the path. */
    private float mX, mY;

    public DrawingView(Context c) {
        super(c);
        mPath = new Path();
        mBitmapPaint = new Paint(Paint.DITHER_FLAG);
        mPaint = new Paint();
        mPaint.setAntiAlias(true);
        mPaint.setStrokeJoin(Paint.Join.ROUND);
        mPaint.setStrokeCap(Paint.Cap.ROUND);
        mPaint.setStrokeWidth(60);
        mPaint.setDither(true);
        mPaint.setColor(Color.WHITE);
        mPaint.setStyle(Paint.Style.STROKE);
    }

    /** Recreates the off-screen bitmap and its canvas at the new view size. */
    @Override
    protected void onSizeChanged(int W, int H, int oldW, int oldH) {
        super.onSizeChanged(W, H, oldW, oldH);
        mBitmap = Bitmap.createBitmap(W, H, Bitmap.Config.ARGB_4444);
        mCanvas = new Canvas(mBitmap);
    }

    /** Draws the committed strokes first, then the stroke in progress on top. */
    @Override
    protected void onDraw(Canvas canvas) {
        canvas.drawBitmap(mBitmap, 0, 0, mBitmapPaint);
        canvas.drawPath(mPath, mPaint);
    }

    /** Starts a new stroke at the given position. */
    private void touch_start(float x, float y) {
        mPath.reset();
        mPath.moveTo(x, y);
        mX = x;
        mY = y;
    }

    /** Extends the stroke once the finger has moved at least TOUCH_TOLERANCE on one axis. */
    private void touch_move(float x, float y) {
        float dx = Math.abs(x - mX);
        float dy = Math.abs(y - mY);
        if (dx >= TOUCH_TOLERANCE || dy >= TOUCH_TOLERANCE) {
            // Quadratic segment: previous point is the control point, the
            // midpoint to the new position is the end point.
            mPath.quadTo(mX, mY, (x + mX) / 2, (y + mY) / 2);
            mX = x;
            mY = y;
        }
    }

    /** Finishes the stroke and commits it to the off-screen bitmap. */
    private void touch_up() {
        mPath.lineTo(mX, mY);
        mCanvas.drawPath(mPath, mPaint);
        mPath.reset();
    }

    /**
     * Drives the drawing from touch events. The order of calls in each branch
     * is significant (the drawing is saved before the bitmap is cleared) and
     * must be kept.
     */
    @Override
    public boolean onTouchEvent(MotionEvent event) {
        float x = event.getX();
        float y = event.getY();
        switch (event.getAction()) {
            case MotionEvent.ACTION_DOWN:
                invalidate();
                clear();
                touch_start(x, y);
                invalidate();
                break;
            case MotionEvent.ACTION_MOVE:
                touch_move(x, y);
                invalidate();
                break;
            case MotionEvent.ACTION_UP:
                touch_up();
                absolutePath = saveDrawing();
                invalidate();
                clear();
                loadImageFromStorage(absolutePath);
                onProgressBar();
                // Start the async task right after the image has been saved.
                new AsyncTaskRunner().execute(absolutePath);
                break;
        }
        return true;
    }

    /** Erases the off-screen bitmap to transparent and requests a redraw. */
    public void clear() {
        mBitmap.eraseColor(Color.TRANSPARENT);
        invalidate();
        // NOTE(review): explicit GC request kept from the original; confirm it is still wanted.
        System.gc();
    }
}

/**
 * Writes the drawing view's drawing cache as {@code drawn_image.jpg} into the
 * private "imageDir" directory. Every new drawing overwrites the previous file.
 * Failures while writing are printed and otherwise ignored (best effort).
 *
 * @return absolute path of the directory that holds {@code drawn_image.jpg}
 */
public String saveDrawing() {
    drawingView.setDrawingCacheEnabled(true);
    Bitmap b = drawingView.getDrawingCache();
    ContextWrapper cw = new ContextWrapper(getApplicationContext());
    // Storage location: the "imageDir" directory obtained from getDir().
    File directory = cw.getDir("imageDir", Context.MODE_PRIVATE);
    File mypath = new File(directory, "drawn_image.jpg");
    // try-with-resources closes the stream only if it was opened. The previous
    // finally-block dereferenced a null stream when opening the file failed.
    try (FileOutputStream fos = new FileOutputStream(mypath)) {
        b.compress(Bitmap.CompressFormat.JPEG, 100, fos);
    } catch (Exception e) {
        // Broad catch kept from the original so that any failure here stays non-fatal.
        e.printStackTrace();
    }
    return directory.getAbsolutePath();
}

/**
 * Reads {@code drawn_image.jpg} from the given directory and shows it in the
 * {@code outputView} image view. If the file cannot be read the stack trace is
 * printed and the image view is left unchanged.
 *
 * @param path directory that contains {@code drawn_image.jpg}
 */
private void loadImageFromStorage(String path) {
    File f = new File(path, "drawn_image.jpg");
    // The input stream is closed as soon as the bitmap has been decoded and displayed.
    try (FileInputStream in = new FileInputStream(f)) {
        Bitmap b = BitmapFactory.decodeStream(in);
        ImageView img = (ImageView) findViewById(R.id.outputView);
        img.setImageBitmap(b);
    } catch (IOException e) {
        // Covers FileNotFoundException as well as a failure while closing the stream.
        e.printStackTrace();
    }
}
// Helper that returns the maximum value in the output array.
/**
 * Returns the largest value contained in {@code arr}.
 *
 * @param arr values to scan; must not be {@code null}
 * @return the maximum element, or {@link Double#NEGATIVE_INFINITY} when {@code arr} is empty
 */
public static double arrayMaximum(double[] arr) {
    double largest = Double.NEGATIVE_INFINITY;
    for (int i = 0; i < arr.length; i++) {
        largest = Math.max(largest, arr[i]);
    }
    return largest;
}
/**
 * Finds the index of the highest confidence score, which is also the predicted digit.
 *
 * @param array scores to scan
 * @return index of the first occurrence of the largest value, or {@code -1}
 *         when {@code array} is {@code null} or empty
 */
public int getIndexOfLargestValue(double[] array) {
    if (array == null || array.length == 0) {
        return -1;
    }
    int bestIndex = 0;
    int candidate = 1;
    while (candidate < array.length) {
        if (array[candidate] > array[bestIndex]) {
            bestIndex = candidate;
        }
        candidate++;
    }
    return bestIndex;
}

/** Makes the "processing" text view visible. */
public void onProgressBar() {
    TextView indicator = findViewById(R.id.processing);
    indicator.setVisibility(View.VISIBLE);
}
/** Hides the "processing" text view again. */
public void offProgressBar() {
    TextView indicator = findViewById(R.id.processing);
    indicator.setVisibility(View.INVISIBLE);
}

/** Main screen: hosts the drawing view and shows the network's prediction. */
public class MainActivity extends AppCompatActivity {
    /** Drawing surface added to the layout in onCreate. */
    MainActivity.DrawingView drawingView;
    /** Directory path returned by saveDrawing() for the most recent drawing. */
    String absolutePath;
    /** Most recent network output produced by the background task. */
    public static INDArray output;
/** Inflates the main layout and adds a new DrawingView to the {@code layout2} container. */
@Override
public void onCreate(Bundle savedInstanceState) {
    super.onCreate(savedInstanceState);
    setContentView(R.layout.activity_main);
    RelativeLayout parent = findViewById(R.id.layout2);
    drawingView = new MainActivity.DrawingView(this);
    parent.addView(drawingView);
}

/** No progress handling of its own; only forwards to the superclass. */
@Override
protected void onProgressUpdate(Integer... values) {
    super.onProgressUpdate(values);
}

/**
 * Shows the predicted digit and its confidence, then hides the processing indicator.
 *
 * @param result network output; the ten values at {@code (0, 0)} .. {@code (0, 9)}
 *               are read. May be {@code null} when the background work produced
 *               nothing, in which case only the indicator is hidden
 */
@Override
protected void onPostExecute(INDArray result) {
    super.onPostExecute(result);

    if (result == null) {
        // Nothing to display (e.g. the model or image could not be loaded).
        // Still hide the indicator so the UI does not stay in "processing".
        offProgressBar();
        return;
    }

    // Controls the number of decimal places shown for the confidence.
    DecimalFormat df2 = new DecimalFormat(".##");

    // Copy the network output for the ten digit classes into an array.
    final int classCount = 10;
    double[] results = new double[classCount];
    for (int i = 0; i < classCount; i++) {
        results[i] = result.getDouble(0, i);
    }

    // Text views that display the predicted digit and its confidence.
    TextView out1 = findViewById(R.id.prediction);
    TextView out2 = findViewById(R.id.confidence);
    out2.setText(df2.format(arrayMaximum(results)));
    out1.setText(String.valueOf(getIndexOfLargestValue(results)));

    // Processing finished: hide the indicator.
    offProgressBar();
}