Last active
November 27, 2020 09:07
-
-
Save iamyb/779a14c2da79c7d7fe5d7256e4d71ace to your computer and use it in GitHub Desktop.
This is an example of converting a TensorFlow (v1.13) model with batch normalization to TFLite.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import os | |
| import argparse | |
| import numpy as np | |
| import tensorflow as tf | |
| import tensorflow.contrib.slim as slim | |
| from tensorflow.examples.tutorials.mnist import input_data | |
## Command-line arguments
# The flags select which stage of the pipeline to run. They live in a
# mutually exclusive group, so at most one may be given per invocation:
#   -t / --train   train the network and save the best checkpoint
#   -i / --infer   rebuild the graph with is_training=False and re-save it
#   -f / --frozen  freeze the checkpoint into a single GraphDef file
#   -l / --lite    convert the frozen graph to a TFLite model
parser = argparse.ArgumentParser(prog='python tflite_batch_norm.py')
group = parser.add_mutually_exclusive_group()
for short_flag, long_flag in (
        ('-t', '--train'),
        ('-i', '--infer'),
        ('-f', '--frozen'),
        ('-l', '--lite'),
):
    group.add_argument(short_flag, long_flag, action='store_true')
args = parser.parse_args()
def freeze_graph(model_dir, output_node, frozen_model_nm):
    """Freeze the checkpoint recorded in ``model_dir`` into one GraphDef file.

    The meta graph of the checkpoint is imported into a fresh graph, its
    variables are restored and converted to constants, and the resulting
    GraphDef is serialized to ``model_dir/frozen_model_nm``.

    Args:
        model_dir: Directory containing the checkpoint state and files.
        output_node: Comma-separated names of the output nodes to keep.
        frozen_model_nm: File name of the frozen graph, created in
            ``model_dir``.

    Returns:
        The frozen ``GraphDef``.

    Raises:
        AssertionError: If ``model_dir`` does not exist, no output node name
            is supplied, or ``model_dir`` holds no checkpoint.
    """
    if not tf.gfile.Exists(model_dir):
        raise AssertionError("Export directory doesn't exists: %s" % model_dir)
    if not output_node:
        raise AssertionError("You need to supply the output_node_names.")

    # Tolerate "a, b" style lists and stray commas.
    output_node_names = [nm.strip() for nm in output_node.split(",") if nm.strip()]
    if not output_node_names:
        raise AssertionError("You need to supply the output_node_names.")

    # get_checkpoint_state returns None when the directory has no valid
    # checkpoint state; fail with a clear message instead of an
    # AttributeError on the attribute access below.
    checkpoint = tf.train.get_checkpoint_state(model_dir)
    if checkpoint is None or not checkpoint.model_checkpoint_path:
        raise AssertionError("No checkpoint found in directory: %s" % model_dir)
    input_ckpt = checkpoint.model_checkpoint_path

    # Work in a private graph so nothing already in the default graph leaks
    # into the frozen model.
    with tf.Session(graph=tf.Graph()) as sess:
        saver = tf.train.import_meta_graph(input_ckpt + '.meta', clear_devices=True)
        saver.restore(sess, input_ckpt)
        output_graph_def = tf.graph_util.convert_variables_to_constants(
            sess, sess.graph.as_graph_def(), output_node_names
        )

    output_graph = os.path.join(model_dir, frozen_model_nm)
    with tf.gfile.GFile(output_graph, "wb") as f:
        f.write(output_graph_def.SerializeToString())
    print("%d ops in the final graph." % len(output_graph_def.node))
    return output_graph_def
def build_network(inputs, output_node, training=True):
    """Build the convolutional classifier and return its named logits tensor.

    The input is reshaped to ``[-1, 28, 28, 1]`` and passed through two
    stages of 5x5 convolution (32 then 64 filters), batch norm, ReLU and
    2x2 max pooling, followed by a 1024-unit fully connected layer with
    batch norm and ReLU, and a final ``N_OUTPUT``-unit fully connected layer.

    Args:
        inputs: Input tensor; it must be reshapeable to ``[-1, 28, 28, 1]``.
        output_node: Name given to the identity op wrapping the logits.
        training: Forwarded to every batch-norm layer as ``is_training``.

    Returns:
        The logits tensor (no activation applied), named ``output_node``.
    """
    # updates_collections=None: per the slim batch_norm documentation this
    # makes the moving-statistics updates run in place with the forward pass.
    bn_kwargs = dict(fused=True, is_training=training, updates_collections=None)

    net = tf.reshape(inputs, [-1, 28, 28, 1])

    # Two identical conv stages that differ only in the number of filters.
    for n_filters in (32, 64):
        net = slim.conv2d(net, n_filters, [5, 5], activation_fn=None)
        net = slim.batch_norm(net, **bn_kwargs)
        net = tf.nn.relu(net)
        net = slim.max_pool2d(net, [2, 2])

    net = slim.flatten(net)
    net = slim.fully_connected(net, 1024, activation_fn=None)
    net = slim.batch_norm(net, **bn_kwargs)
    net = tf.nn.relu(net)

    logits = slim.fully_connected(net, N_OUTPUT, activation_fn=None)
    # The identity op gives the output a stable, caller-chosen node name.
    return tf.identity(logits, output_node)
## Global configuration
################################################################################
# Names of the graph's input placeholder and output tensor.
INPUT_NODE = 'net_input'
OUTPUT_NODE = 'net_output'

# Width of one flattened input row and of the network output.
N_INPUT = 784
N_OUTPUT = 10

# Directory that holds every checkpoint and exported model.
MODEL_DIR = "model"


def _in_model_dir(file_name):
    """Return the path of ``file_name`` inside ``MODEL_DIR``."""
    return os.path.join(MODEL_DIR, file_name)


# Checkpoint saved while training, and the one re-saved in inference mode.
BEST_TRAIN_CKPT_NAME = 'best.ckpt'
BEST_INFER_CKPT_NAME = 'best_infer.ckpt'
BEST_TRAIN_CKPT_PATH = _in_model_dir(BEST_TRAIN_CKPT_NAME)
BEST_INFER_CKPT_PATH = _in_model_dir(BEST_INFER_CKPT_NAME)

# Exported model files.
FROZEN_MODEL_NAME = 'frozen_model.pb'
TFLITE_MODEL_NAME = 'model.tflite'
FROZEN_MODEL_PATH = _in_model_dir(FROZEN_MODEL_NAME)
TFLITE_MODEL_PATH = _in_model_dir(TFLITE_MODEL_NAME)
## Model conversion
################################################################################
## Freeze the checkpoint into a single GraphDef file
if args.frozen:
    freeze_graph(MODEL_DIR, OUTPUT_NODE, FROZEN_MODEL_NAME)

## Convert the frozen graph to TFLite
if args.lite:
    frozen_graph_spec = {
        'graph_def_file': FROZEN_MODEL_PATH,
        'input_arrays': [INPUT_NODE],
        'input_shapes': {INPUT_NODE: [None, N_INPUT]},
        'output_arrays': [OUTPUT_NODE],
    }
    converter = tf.compat.v1.lite.TFLiteConverter.from_frozen_graph(
        **frozen_graph_spec
    )
    # Allow both the TFLite builtin ops and the selected TensorFlow ops.
    converter.target_ops = [
        tf.lite.OpsSet.TFLITE_BUILTINS,
        tf.lite.OpsSet.SELECT_TF_OPS,
    ]
    converter.inference_type = tf.float32

    # Convert before opening the output file so that a failed conversion
    # does not leave an empty model file behind.
    tflite_model = converter.convert()
    with open(TFLITE_MODEL_PATH, 'wb') as tflite_file:
        tflite_file.write(tflite_model)
## Build network and session initialization
################################################################################
if args.infer or args.train:
    # The label width must equal the network's output width, so derive it
    # from N_OUTPUT instead of repeating the literal.
    n_classes = N_OUTPUT
    x = tf.placeholder("float", [None, N_INPUT], name=INPUT_NODE)
    y = tf.placeholder("float", [None, n_classes])

    # is_training follows the --train flag: True when training, False when
    # the graph is rebuilt for inference.
    pred = build_network(x, OUTPUT_NODE, args.train)

    cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(
        labels=y, logits=pred))
    optm = tf.train.AdamOptimizer(learning_rate=0.001).minimize(cost)

    # Fraction of samples whose arg-max prediction matches the label.
    corr = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
    accr = tf.reduce_mean(tf.cast(corr, "float"))

    sess = tf.Session()
    sess.run(tf.global_variables_initializer())
    saver = tf.train.Saver(max_to_keep=100)

    # Checkpoints are written into MODEL_DIR; make sure it exists.
    if not os.path.exists(MODEL_DIR):
        os.makedirs(MODEL_DIR)
## Inference model: restore the trained weights into the graph built with
## is_training=False and save it as a separate checkpoint
################################################################################
if args.infer:
    # Load the weights saved during training, then write them back out
    # together with the inference-mode graph.
    saver.restore(sess, BEST_TRAIN_CKPT_PATH)
    saver.save(sess, BEST_INFER_CKPT_PATH)

    # Report the accuracy of the inference graph on the MNIST test split.
    mnist = input_data.read_data_sets(r'data/', one_hot=True)
    batch_size = 100
    total_batch = mnist.test.num_examples // batch_size

    batch_accuracies = []
    for _ in range(total_batch):
        images, labels = mnist.test.next_batch(batch_size)
        batch_accuracies.append(
            sess.run(accr, feed_dict={x: images, y: labels}))

    val_acc = sum(batch_accuracies) / total_batch
    print('test accuracy: %s' % ( val_acc))
## Train the model and save the best checkpoint
################################################################################
if args.train:
    mnist = input_data.read_data_sets(r'data/', one_hot=True)

    training_epochs = 20
    batch_size = 100
    val_acc_max = 0
    total_batch = mnist.train.num_examples // batch_size
    total_batch_val = mnist.validation.num_examples // batch_size

    for epoch in range(training_epochs):
        total_loss = 0.
        for _ in range(total_batch):
            xi, yi = mnist.train.next_batch(batch_size)
            # Fetch the loss in the same run as the optimizer step. A
            # separate sess.run(cost) would repeat the forward pass on the
            # batch and, with batch norm in training mode updating its
            # moving statistics in place, apply that update a second time.
            _, cur_loss = sess.run([optm, cost], feed_dict={x: xi, y: yi})
            total_loss += cur_loss
        avg_loss = total_loss / total_batch

        # NOTE(review): validation runs on the same graph, built with
        # is_training=True, so accuracy is computed in training mode --
        # confirm this is acceptable for model selection.
        total_acc_val = 0
        for _ in range(total_batch_val):
            xi, yi = mnist.validation.next_batch(batch_size)
            total_acc_val += sess.run(accr, feed_dict={x: xi, y: yi})
        val_acc = total_acc_val/total_batch_val

        print('train loss: %s . valid accuracy: %s' % ( avg_loss, val_acc))

        # Keep only the checkpoint with the best validation accuracy so far.
        if val_acc > val_acc_max:
            val_acc_max = val_acc
            saver.save(sess=sess, save_path=BEST_TRAIN_CKPT_PATH)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import tensorflow as tf | |
| import numpy as np | |
| from tensorflow.examples.tutorials.mnist import input_data | |
# Evaluate the converted TFLite model on the MNIST test split.
mnist = input_data.read_data_sets(r'data/', one_hot=True)
testimg = mnist.test.images
testlabel = mnist.test.labels

# Load TFLite model and allocate tensors.
# tf.lite is the namespace the conversion script already uses
# (tf.lite.OpsSet); use it here as well instead of tf.contrib.lite.
interpreter = tf.lite.Interpreter(model_path='model/model.tflite')
interpreter.allocate_tensors()

# Get input and output tensors.
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
input_index = input_details[0]['index']
input_shape = input_details[0]['shape']
output_index = output_details[0]['index']

total_correct = 0
for index in range(mnist.test.num_examples):
    # Named `sample` so the imported `input_data` module is not shadowed.
    sample = testimg[index].reshape(input_shape)
    interpreter.set_tensor(input_index, sample)
    interpreter.invoke()
    output_data = interpreter.get_tensor(output_index)
    pred = int(np.argmax(output_data[0]))
    label = int(np.argmax(testlabel[index]))
    total_correct += int(pred == label)

# float() keeps this a true division under Python 2 as well, where int/int
# would truncate the accuracy to 0 or 1.
print('test acc: %s' % (float(total_correct) / mnist.test.num_examples))
Author
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Execute the commands below one by one to generate a TFLite model with batch_norm.
Test the generated model:
Only tested with tensorflow 1.13.1
The workaround resolves the errors caused by TensorFlow's batch normalization during TFLite conversion, such as the one below: