Skip to content

Instantly share code, notes, and snippets.

@iamyb
Last active November 27, 2020 09:07
Show Gist options
  • Select an option

  • Save iamyb/779a14c2da79c7d7fe5d7256e4d71ace to your computer and use it in GitHub Desktop.

Select an option

Save iamyb/779a14c2da79c7d7fe5d7256e4d71ace to your computer and use it in GitHub Desktop.
This is an example of converting a TensorFlow (v1.13) model that uses batch normalization to TFLite.
import os
import argparse
import numpy as np
import tensorflow as tf
import tensorflow.contrib.slim as slim
from tensorflow.examples.tutorials.mnist import input_data
## Arguments
# Command-line interface: exactly one pipeline stage is run per invocation.
# The stages are meant to be run in order: --train, --infer, --frozen, --lite.
parser = argparse.ArgumentParser(
    prog='python tflite_batch_norm.py',
    description='Train an MNIST CNN with batch normalization and convert it '
                'to TFLite. Run the stages in order: --train, --infer, '
                '--frozen, --lite.')
# required=True: without a stage flag the script would otherwise do nothing
# and exit silently.
group = parser.add_mutually_exclusive_group(required=True)
group.add_argument('-t', '--train', action='store_true',
                   help='train the network and save the best checkpoint')
group.add_argument('-i', '--infer', action='store_true',
                   help='rebuild the graph with is_training=False, re-save '
                        'the best checkpoint for inference and report test '
                        'accuracy')
group.add_argument('-f', '--frozen', action='store_true',
                   help='freeze the latest checkpoint into a .pb graph')
group.add_argument('-l', '--lite', action='store_true',
                   help='convert the frozen graph to a .tflite model')
args = parser.parse_args()
def freeze_graph(model_dir, output_node, frozen_model_nm):
    """Freeze the latest checkpoint in *model_dir* into a serialized GraphDef.

    Restores the checkpoint that the checkpoint-state file in *model_dir*
    records as most recent. It then folds the variables into constants,
    keeping only what the requested output node(s) need, and writes the
    result to ``os.path.join(model_dir, frozen_model_nm)``.

    Args:
        model_dir: Directory holding the checkpoint files and their
            checkpoint-state index.
        output_node: Comma-separated name(s) of the graph node(s) to keep.
        frozen_model_nm: File name of the frozen graph, written in model_dir.

    Returns:
        The frozen ``GraphDef``.

    Raises:
        AssertionError: If model_dir does not exist, no output node name is
            given, or model_dir contains no checkpoint.
    """
    if not tf.gfile.Exists(model_dir):
        raise AssertionError("Export directory doesn't exist: %s" % model_dir)
    # Accept "a,b" as well as "a, b"; drop empty entries.
    output_nodes = [name.strip() for name in (output_node or "").split(",")
                    if name.strip()]
    if not output_nodes:
        raise AssertionError("You need to supply the output_node_names.")
    checkpoint = tf.train.get_checkpoint_state(model_dir)
    # get_checkpoint_state returns None when no checkpoint index exists.
    if checkpoint is None or not checkpoint.model_checkpoint_path:
        raise AssertionError("No checkpoint found in: %s" % model_dir)
    input_ckpt = checkpoint.model_checkpoint_path
    with tf.Session(graph=tf.Graph()) as sess:
        # Recreate the saved graph without device placements, then load the
        # weights into it.
        saver = tf.train.import_meta_graph(input_ckpt + '.meta',
                                           clear_devices=True)
        saver.restore(sess, input_ckpt)
        output_graph_def = tf.graph_util.convert_variables_to_constants(
            sess, sess.graph.as_graph_def(), output_nodes)
    output_graph = os.path.join(model_dir, frozen_model_nm)
    with tf.gfile.GFile(output_graph, "wb") as f:
        f.write(output_graph_def.SerializeToString())
    print("%d ops in the final graph." % len(output_graph_def.node))
    return output_graph_def
def build_network(inputs, output_node, training=True):
    """Build the MNIST CNN and return its named logits tensor.

    Architecture:
      * two stages of [5x5 conv -> batch norm -> ReLU -> 2x2 max-pool],
        with 32 and then 64 filters;
      * a 1024-unit fully connected layer with batch norm and ReLU;
      * a linear fully connected layer with N_OUTPUT units.

    Args:
        inputs: Image tensor; it is reshaped here to [-1, 28, 28, 1].
        output_node: Name assigned to the returned logits tensor.
        training: Passed as ``is_training`` to every batch-norm layer.

    Returns:
        The logits tensor, wrapped in ``tf.identity`` named ``output_node``.
    """
    # updates_collections=None: batch norm updates its moving statistics in
    # place as part of the forward pass rather than via a separate collection.
    bn_kwargs = dict(fused=True, is_training=training,
                     updates_collections=None)

    def bn_relu(tensor):
        # Batch norm first, then ReLU (the convs/FC layers have no activation).
        return tf.nn.relu(slim.batch_norm(tensor, **bn_kwargs))

    net = tf.reshape(inputs, [-1, 28, 28, 1])
    # Layer creation order matters: slim numbers its auto-generated scope
    # names by creation order, and checkpoints are restored by those names.
    for filters in (32, 64):
        net = slim.conv2d(net, filters, [5, 5], activation_fn=None)
        net = bn_relu(net)
        net = slim.max_pool2d(net, [2, 2])
    net = slim.flatten(net)
    net = bn_relu(slim.fully_connected(net, 1024, activation_fn=None))
    logits = slim.fully_connected(net, N_OUTPUT, activation_fn=None)
    return tf.identity(logits, output_node)
## Global configuration
################################################################################
# Graph tensor names shared by training, freezing and TFLite conversion.
INPUT_NODE = 'net_input'
OUTPUT_NODE = 'net_output'
N_INPUT = 784   # flattened 28x28 single-channel image (see build_network)
N_OUTPUT = 10   # number of output logits / classes

# Every artifact (checkpoints, frozen graph, .tflite) lives under MODEL_DIR.
MODEL_DIR = "model"
BEST_TRAIN_CKPT_NAME = 'best.ckpt'
BEST_INFER_CKPT_NAME = 'best_infer.ckpt'
FROZEN_MODEL_NAME = 'frozen_model.pb'
TFLITE_MODEL_NAME = 'model.tflite'

(BEST_TRAIN_CKPT_PATH,
 BEST_INFER_CKPT_PATH,
 FROZEN_MODEL_PATH,
 TFLITE_MODEL_PATH) = (os.path.join(MODEL_DIR, file_name) for file_name in (
     BEST_TRAIN_CKPT_NAME, BEST_INFER_CKPT_NAME,
     FROZEN_MODEL_NAME, TFLITE_MODEL_NAME))
## Model Conversion
################################################################################
## Freeze the checkpoints
if args.frozen:
    # Freeze the most recent checkpoint under MODEL_DIR into
    # MODEL_DIR/FROZEN_MODEL_NAME, keeping only what OUTPUT_NODE needs.
    freeze_graph(MODEL_DIR, OUTPUT_NODE, FROZEN_MODEL_NAME)
## Convert it to tflite
if args.lite:
    # Convert the frozen graph: a single input fed as a batch of flat
    # N_INPUT-long vectors, and a single output (the logits).
    lite_converter = tf.compat.v1.lite.TFLiteConverter.from_frozen_graph(
        graph_def_file=FROZEN_MODEL_PATH,
        input_arrays=[INPUT_NODE],
        input_shapes={INPUT_NODE: [None, N_INPUT]},
        output_arrays=[OUTPUT_NODE],
    )
    # Allow both builtin TFLite kernels and select TensorFlow ops.
    lite_converter.target_ops = [
        tf.lite.OpsSet.TFLITE_BUILTINS,
        tf.lite.OpsSet.SELECT_TF_OPS,
    ]
    lite_converter.inference_type = tf.float32
    # Convert before opening the output so a failed conversion does not leave
    # a truncated model file behind.
    tflite_bytes = lite_converter.convert()
    with open(TFLITE_MODEL_PATH, 'wb') as tflite_file:
        tflite_file.write(tflite_bytes)
## Build the network and initialize the session
################################################################################
if args.infer or args.train:
    # --train and --infer build the same graph. Only batch norm's
    # is_training flag differs, so --infer can restore the --train checkpoint.
    x = tf.placeholder("float", [None, N_INPUT], INPUT_NODE)
    # One-hot labels; width taken from N_OUTPUT instead of a second,
    # independently maintained class count.
    y = tf.placeholder("float", [None, N_OUTPUT])
    # is_training follows --train. With --infer it is False, so batch norm
    # uses its moving statistics and adds no moving-average update ops
    # (the BatchNorm/AssignMovingAvg ops that break TFLite conversion).
    pred = build_network(x, OUTPUT_NODE, args.train)
    cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(
        labels=y, logits=pred))
    optm = tf.train.AdamOptimizer(learning_rate=0.001).minimize(cost)
    corr = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
    accr = tf.reduce_mean(tf.cast(corr, "float"))
    sess = tf.Session()
    sess.run(tf.global_variables_initializer())
    saver = tf.train.Saver(max_to_keep=100)
    # Explicit existence check (not exist_ok=) keeps Python 2 compatibility.
    if not os.path.exists(MODEL_DIR):
        os.makedirs(MODEL_DIR)
## Inference model: restore the trained weights with is_training set to False
################################################################################
if args.infer:
    # Load the best training weights into this is_training=False graph and
    # re-save them; this becomes the checkpoint that --frozen picks up.
    saver.restore(sess, BEST_TRAIN_CKPT_PATH)
    saver.save(sess, BEST_INFER_CKPT_PATH)

    mnist = input_data.read_data_sets(r'data/', one_hot=True)
    batch_size = 100
    # Whole batches only; any remainder of the test split is skipped.
    total_batch = mnist.test.num_examples // batch_size

    def _test_batch_accuracy():
        # Accuracy on the next batch of the MNIST test split.
        images, labels = mnist.test.next_batch(batch_size)
        return sess.run(accr, feed_dict={x: images, y: labels})

    # Mean of per-batch accuracies.
    val_acc = sum(_test_batch_accuracy() for _ in range(total_batch)) / total_batch
    print('test accuracy: %s' % (val_acc))
## Train the model and save the best checkpoint
################################################################################
if args.train:
    mnist = input_data.read_data_sets(r'data/', one_hot=True)
    training_epochs = 20
    batch_size = 100
    val_acc_max = 0
    # Whole batches only; any remainder of each split is skipped.
    total_batch = int(mnist.train.num_examples / batch_size)
    total_batch_val = int(mnist.validation.num_examples / batch_size)
    for epoch in range(training_epochs):
        total_loss = 0.
        for _ in range(total_batch):
            xi, yi = mnist.train.next_batch(batch_size)
            # Fetch the loss in the same run as the update step. A separate
            # sess.run(cost) would repeat the forward pass. Because batch norm
            # updates its moving statistics in place (updates_collections=
            # None), it would also apply a second moving-average update for
            # the same batch. The fetched loss is the pre-update loss.
            _, cur_loss = sess.run([optm, cost], feed_dict={x: xi, y: yi})
            total_loss += cur_loss
        avg_loss = total_loss / total_batch
        # NOTE(review): validation also runs through the is_training=True
        # graph, so it normalizes with batch statistics and updates the moving
        # statistics with validation batches.
        total_acc_val = 0
        for _ in range(total_batch_val):
            xi, yi = mnist.validation.next_batch(batch_size)
            total_acc_val += sess.run(accr, feed_dict={x: xi, y: yi})
        val_acc = total_acc_val / total_batch_val
        print('train loss: %s . valid accuracy: %s' % (avg_loss, val_acc))
        # Overwrite the single "best" checkpoint whenever validation improves.
        if val_acc > val_acc_max:
            val_acc_max = val_acc
            saver.save(sess=sess, save_path=BEST_TRAIN_CKPT_PATH)
import tensorflow as tf
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets(r'data/', one_hot=True)
testimg = mnist.test.images
testlabel = mnist.test.labels
# Load the TFLite model and allocate its tensors. tf.lite.Interpreter is the
# non-contrib location of the interpreter (tf.contrib.lite is deprecated).
interpreter = tf.lite.Interpreter(model_path='model/model.tflite')
interpreter.allocate_tensors()
# Assumes a single input and a single output tensor, as produced by the
# --lite conversion step.
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
input_index = input_details[0]['index']
output_index = output_details[0]['index']
input_shape = input_details[0]['shape']
n_test = mnist.test.num_examples
total_correct = 0
for index in range(n_test):
    # Named `sample` so it does not shadow the imported input_data module.
    sample = testimg[index].reshape(input_shape)
    interpreter.set_tensor(input_index, sample)
    interpreter.invoke()
    output_data = interpreter.get_tensor(output_index)
    pred = int(np.argmax(output_data[0]))
    label = int(np.argmax(testlabel[index]))
    total_correct += int(pred == label)
# float() keeps this a true division under Python 2 as well.
print('test acc: %s' % (total_correct / float(n_test)))
@iamyb
Copy link
Copy Markdown
Author

iamyb commented Nov 27, 2020

Run the commands below, in order, to generate a TFLite model with batch normalization.

python tflite_batch_norm.py --train
python tflite_batch_norm.py --infer
python tflite_batch_norm.py --frozen
python tflite_batch_norm.py --lite

Test the generated model:

python try_tflite_model.py

Only tested with TensorFlow 1.13.1.
This workaround avoids errors caused by TensorFlow's batch normalization during TFLite conversion, such as:

tensorflow.python.framework.errors_impl.InvalidArgumentError: Input 0 of node BatchNorm/AssignMovingAvg was passed float from BatchNorm/moving_mean:0 incompatible with expected float_ref.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment