Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
include model trained by Keras
  • Loading branch information
Piasy committed Jun 1, 2017
commit c1008fbe67eeb29556b391f1d9aec2cd80c064eb
Binary file not shown.
Binary file not shown.
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,11 @@
import android.view.View;
import android.widget.Button;
import android.widget.TextView;
import java.util.ArrayList;
import java.util.List;
import mariannelinhares.mnistandroid.models.Classification;
import mariannelinhares.mnistandroid.models.Classifier;
import mariannelinhares.mnistandroid.models.TensorFlowClassifier;
import mariannelinhares.mnistandroid.views.DrawModel;
import mariannelinhares.mnistandroid.views.DrawView;

Expand All @@ -38,18 +43,12 @@

public class MainActivity extends Activity implements View.OnClickListener, View.OnTouchListener {

// tensorflow input and output
private static final int INPUT_SIZE = 28;
private static final String INPUT_NAME = "input";
private static final String OUTPUT_NAME = "output";
private static final String MODEL_FILE = "opt_mnist_convnet.pb";
private static final String LABEL_FILE = "labels.txt";
private static final int PIXEL_WIDTH = 28;

// ui related
private Button clearBtn, classBtn;
private TextView resText;
private Classifier classifier;
private List<Classifier> mClassifiers = new ArrayList<>();

// views related
private DrawModel drawModel;
Expand Down Expand Up @@ -103,14 +102,16 @@ private void loadModel() {
@Override
public void run() {
try {
classifier = Classifier.create(getApplicationContext().getAssets(),
MODEL_FILE,
LABEL_FILE,
INPUT_SIZE,
INPUT_NAME,
OUTPUT_NAME);
mClassifiers.add(
TensorFlowClassifier.create(getAssets(), "TensorFlow",
"opt_mnist_convnet-tf.pb", "labels.txt", PIXEL_WIDTH,
"input", "output", true));
mClassifiers.add(
TensorFlowClassifier.create(getAssets(), "Keras",
"opt_mnist_convnet-keras.pb", "labels.txt", PIXEL_WIDTH,
"conv2d_1_input", "dense_2/Softmax", false));
} catch (final Exception e) {
throw new RuntimeException("Error initializing TensorFlow!", e);
throw new RuntimeException("Error initializing classifiers!", e);
}
}
}).start();
Expand All @@ -123,19 +124,21 @@ public void onClick(View view) {
drawView.reset();
drawView.invalidate();

resText.setText("Result: ");
resText.setText("");
} else if (view.getId() == R.id.btn_class) {
float pixels[] = drawView.getPixelData();

final Classification res = classifier.recognize(pixels);
String result = "Result: ";
if (res.getLabel() == null) {
resText.setText(result + "?");
} else {
result += res.getLabel();
result += "\nwith probability: " + res.getConf();
resText.setText(result);
String text = "";
for (Classifier classifier : mClassifiers) {
final Classification res = classifier.recognize(pixels);
if (res.getLabel() == null) {
text += classifier.name() + ": ?\n";
} else {
text += String.format("%s: %s, %f\n", classifier.name(), res.getLabel(),
res.getConf());
}
}
resText.setText(text);
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package mariannelinhares.mnistandroid;
package mariannelinhares.mnistandroid.models;

/**
* Created by marianne-linhares on 20/04/17.
Expand All @@ -9,16 +9,12 @@ public class Classification {
private float conf;
private String label;

public Classification(float conf, String label) {
update(conf, label);
}

public Classification() {
this.conf = (float)-1.0;
Classification() {
this.conf = -1.0F;
this.label = null;
}

public void update(float conf, String label) {
void update(float conf, String label) {
this.conf = conf;
this.label = label;
}
Expand All @@ -30,5 +26,4 @@ public String getLabel() {
public float getConf() {
return conf;
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
package mariannelinhares.mnistandroid.models;

/**
* Created by Piasy{github.com/Piasy} on 29/05/2017.
*/

public interface Classifier {
String name();

Classification recognize(final float[] pixels);
}
Original file line number Diff line number Diff line change
@@ -1,38 +1,38 @@
package mariannelinhares.mnistandroid;
package mariannelinhares.mnistandroid.models;

import android.content.res.AssetManager;

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.List;

import org.tensorflow.contrib.android.TensorFlowInferenceInterface;

/**
* Changed from https://github.com/MindorksOpenSource/AndroidTensorFlowMNISTExample/blob/master/app/src/main/java/com/mindorks/tensorflowexample/TensorFlowImageClassifier.java
* Changed from https://github.com/MindorksOpenSource/AndroidTensorFlowMNISTExample/blob/master
* /app/src/main/java/com/mindorks/tensorflowexample/TensorFlowImageClassifier.java
* Created by marianne-linhares on 20/04/17.
*/

public class Classifier {
public class TensorFlowClassifier implements Classifier {

// Only returns if at least this confidence
private static final float THRESHOLD = 0.1f;

private TensorFlowInferenceInterface tfHelper;

private String name;
private String inputName;
private String outputName;
private int inputSize;
private boolean feedKeepProb;

private List<String> labels;
private float[] output;
private String[] outputNames;

static private List<String> readLabels(Classifier c, AssetManager am, String fileName) throws IOException {
BufferedReader br = null;
br = new BufferedReader(new InputStreamReader(am.open(fileName)));
private static List<String> readLabels(AssetManager am, String fileName) throws IOException {
BufferedReader br = new BufferedReader(new InputStreamReader(am.open(fileName)));

String line;
List<String> labels = new ArrayList<>();
Expand All @@ -44,36 +44,46 @@ static private List<String> readLabels(Classifier c, AssetManager am, String fil
return labels;
}

public static TensorFlowClassifier create(AssetManager assetManager, String name,
String modelPath, String labelFile, int inputSize, String inputName, String outputName,
boolean feedKeepProb) throws IOException {
TensorFlowClassifier c = new TensorFlowClassifier();

static public Classifier create(AssetManager assetManager, String modelPath, String labelFile,
int inputSize, String inputName, String outputName)
throws IOException {

Classifier c = new Classifier();
c.name = name;

c.inputName = inputName;
c.outputName = outputName;

c.labels = readLabels(c, assetManager, labelFile);
c.labels = readLabels(assetManager, labelFile);

c.tfHelper = new TensorFlowInferenceInterface(assetManager, modelPath);
int numClasses = 10;

c.inputSize = inputSize;

// Pre-allocate buffer.
c.outputNames = new String[]{ outputName };
c.outputNames = new String[] { outputName };

c.outputName = outputName;
c.output = new float[numClasses];

c.feedKeepProb = feedKeepProb;

return c;
}

@Override
public String name() {
return name;
}

@Override
public Classification recognize(final float[] pixels) {

tfHelper.feed(inputName, pixels, 1, inputSize, inputSize, 1);
tfHelper.feed("keep_prob", new float[] {1.0f});
if (feedKeepProb) {
tfHelper.feed("keep_prob", new float[] { 1 });
}
tfHelper.run(outputNames);

tfHelper.fetch(outputName, output);
Expand Down
69 changes: 36 additions & 33 deletions MnistAndroid/app/src/main/res/layout/activity_main.xml
Original file line number Diff line number Diff line change
@@ -1,47 +1,50 @@
<?xml version="1.0" encoding="utf-8"?>
<LinearLayout
xmlns:android="http://schemas.android.com/apk/res/android"
xmlns:tools="http://schemas.android.com/tools"
android:layout_width="match_parent"
android:layout_height="match_parent"
android:paddingLeft="@dimen/activity_horizontal_margin"
android:paddingRight="@dimen/activity_horizontal_margin"
android:paddingTop="@dimen/activity_vertical_margin"
android:paddingBottom="@dimen/activity_vertical_margin"
android:orientation="vertical"
tools:context="mariannelinhares.mnistandroid.MainActivity">
xmlns:android="http://schemas.android.com/apk/res/android"
xmlns:tools="http://schemas.android.com/tools"
android:layout_width="match_parent"
android:layout_height="match_parent"
android:orientation="vertical"
android:paddingBottom="@dimen/activity_vertical_margin"
android:paddingLeft="@dimen/activity_horizontal_margin"
android:paddingRight="@dimen/activity_horizontal_margin"
android:paddingTop="@dimen/activity_vertical_margin"
tools:context="mariannelinhares.mnistandroid.MainActivity"
>

<mariannelinhares.mnistandroid.views.DrawView
android:layout_width="match_parent"
android:layout_height="0dp"
android:id="@+id/draw"
android:layout_weight="1"/>
android:id="@+id/draw"
android:layout_width="match_parent"
android:layout_height="0dp"
android:layout_weight="1"
/>

<LinearLayout
android:layout_width="match_parent"
android:orientation="horizontal"
android:layout_height="100dp">
android:layout_width="match_parent"
android:layout_height="wrap_content"
android:orientation="horizontal"
>

<Button
android:layout_width="wrap_content"
android:layout_height="wrap_content"
android:text="Clear"
android:id="@+id/btn_clear"/>
android:id="@+id/btn_clear"
android:layout_width="wrap_content"
android:layout_height="wrap_content"
android:text="Clear"
/>

<Button
android:layout_width="wrap_content"
android:layout_height="wrap_content"
android:text="Detect"
android:id="@+id/btn_class"/>
android:id="@+id/btn_class"
android:layout_width="wrap_content"
android:layout_height="wrap_content"
android:text="Detect"
/>
</LinearLayout>

<TextView
<TextView
android:id="@+id/tfRes"
android:layout_width="0dp"
android:layout_height="wrap_content"
android:layout_weight="1"
android:layout_width="match_parent"
android:layout_height="150dp"
android:paddingLeft="10dp"
android:text="Result:"
android:textAppearance="?android:attr/textAppearanceMedium" />
</LinearLayout>

android:textAppearance="?android:attr/textAppearanceMedium"
/>
</LinearLayout>
12 changes: 6 additions & 6 deletions tensorflow_model/mnist_convnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def train(x, keep_prob, y_, train_step, loss, accuracy,
sess.run(init_op)

tf.train.write_graph(sess.graph_def, 'out',
MODEL_NAME + '.graph.bin', False)
MODEL_NAME + '.pbtxt', True)

# op to write logs to Tensorboard
summary_writer = tf.summary.FileWriter('logs/',
Expand All @@ -90,7 +90,7 @@ def train(x, keep_prob, y_, train_step, loss, accuracy,
feed_dict={x: batch[0], y_: batch[1], keep_prob: 0.5})
summary_writer.add_summary(summary, step)

saver.save(sess, 'out/' + MODEL_NAME + '.ckpt')
saver.save(sess, 'out/' + MODEL_NAME + '.chkp')

test_accuracy = accuracy.eval(feed_dict={x: mnist.test.images,
y_: mnist.test.labels,
Expand All @@ -100,8 +100,8 @@ def train(x, keep_prob, y_, train_step, loss, accuracy,
print("training finished!")

def export_model(input_node_names, output_node_name):
freeze_graph.freeze_graph('out/' + MODEL_NAME + '.graph.bin', None, True,
'out/' + MODEL_NAME + '.ckpt', output_node_name, "save/restore_all",
freeze_graph.freeze_graph('out/' + MODEL_NAME + '.pbtxt', None, False,
'out/' + MODEL_NAME + '.chkp', output_node_name, "save/restore_all",
"save/Const:0", 'out/frozen_' + MODEL_NAME + '.pb', True, "")

input_graph_def = tf.GraphDef()
Expand All @@ -112,8 +112,8 @@ def export_model(input_node_names, output_node_name):
input_graph_def, input_node_names, [output_node_name],
tf.float32.as_datatype_enum)

f = tf.gfile.FastGFile('out/opt_' + MODEL_NAME + '.pb', "wb")
f.write(output_graph_def.SerializeToString())
with tf.gfile.FastGFile('out/opt_' + MODEL_NAME + '.pb', "wb") as f:
f.write(output_graph_def.SerializeToString())

print("graph saved!")

Expand Down
Loading