Retrain a classification model on-device with weight imprinting
Weight imprinting is a technique for retraining a neural network (classification models only) using
a small set of sample data, based on the method described in Low-Shot Learning with Imprinted
Weights. It's designed to update the weights for only the
last layer of the model, but in a way that can retain existing classes while adding new ones. We've
implemented this technique in the
ImprintingEngine
API, allowing you to accelerate transfer-learning with the Edge TPU.
The ImprintingEngine
API is available in both PyCoral (Python) and Libcoral (C++), but this guide
describes only the Python API.
Overview
To use the ImprintingEngine
API, you need to
provide a specially-designed model that separates the embedding extractor from the last layer where
classification occurs. This is necessary because once a model is compiled for the Edge TPU, the
network's weights are locked and cannot be changed—by separating the last layer and compiling only
the base of the graph, we can update weights in the classification layer. Additionally, the weight
imprinting technique requires a few changes to the model architecture to facilitate more accurate
weights for the last layer (such as an additional L2-normalization layer and an added
scaling factor). (For all the details about the model architecture, read Low-Shot Learning with
Imprinted Weights.)
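To make the last-layer mechanics concrete, here's a minimal NumPy sketch of the idea (an illustration of the paper's approach, not our actual TensorFlow graph): embeddings and per-class weight vectors are L2-normalized, logits are scaled cosine similarities, and a new class is "imprinted" by averaging the normalized embeddings of its few samples (the scale value is an arbitrary illustrative default):
import numpy as np

def l2_normalize(x, axis=-1):
    return x / np.linalg.norm(x, axis=axis, keepdims=True)

def classify(embedding, class_weights, scale=10.0):
    # Each column of class_weights is one class. Logits are scaled cosine
    # similarities between the normalized embedding and each weight vector.
    logits = scale * (l2_normalize(embedding) @ l2_normalize(class_weights, axis=0))
    exp = np.exp(logits - logits.max())
    return exp / exp.sum()  # softmax scores

def imprint_new_class(class_weights, sample_embeddings):
    # Imprinting: the new class's weight vector is the normalized mean of the
    # normalized embeddings computed from its few training samples.
    new_weights = l2_normalize(l2_normalize(sample_embeddings).mean(axis=0))
    return np.concatenate([class_weights, new_weights[:, np.newaxis]], axis=1)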
However, unlike the on-device backpropagation
technique, the model you provide for
weight imprinting must be the complete graph (not just the embedding extractor). The model is still
divided into separate parts for the embedding extractor and the classification layer, and only the
base portion is compiled for the Edge TPU, but the two parts are recombined so that the
classifications from the original model are preserved. However, the original classes cannot be
retrained—you can train and update only new classes that you add using the
ImprintingEngine
API.
Of course, this strategy has both benefits and drawbacks:
Benefits:
- Transfer-learning happens on-device, at near-realtime speed.
- Very few sample images are required (fewer than 10 training samples can achieve high accuracy).
- You don't need to recompile the model.
Drawbacks:
- It has difficulty learning from datasets with large intra-class variation (when the data for a given class contains large variation across samples, such as major differences in the subject angle or size). If your use-case expects data with high intra-class variance, consider instead using on-device transfer learning with backpropagation (it requires a larger training dataset).
- The last fully-connected layer executes on the CPU, not the Edge TPU. However, this layer represents a very small portion of the overall network, so impact on the inference speed is minimal.
- It has specific model architecture requirements. We've shared a version of MobileNet v1 that is compatible (see below), but if you prefer a different model, then you must make the necessary changes to your model.
If weight imprinting doesn't suit your application, consider the SoftmaxRegression API,
which instead uses backpropagation to update the weights of the last layer. For a comparison of
these two techniques, read Transfer-learning on-device.
API summary
The ImprintingEngine
class encapsulates
the entire model that you want to train. Once you create an instance with a compatible model,
you can pass it training data to update the weights in the last layer, and then immediately use
the model to perform inferencing.
The basic procedure to perform weight imprinting with the
ImprintingEngine
API
is as follows:
-
Create an instance of
ImprintingEngine
by specifying a compatible TensorFlow Lite model. Most applications should use our pre-trained model (mobilenet_v1_1.0_224_l2norm_quant_edgetpu.tflite
), but you can also retrain the MobileNet model or build your own model. The initialization function allows you to specify whether to keep the classifications from the pre-trained model or abandon them and use only the classes you're about to add.
-
Create an instance of
Interpreter
for the Edge TPU, using the ImprintingEngine model provided by
serialize_extractor_model(). For example:
engine = ImprintingEngine(model_path)
extractor_interpreter = make_edgetpu_interpreter(engine.serialize_extractor_model())
-
For each training image, run an inference with the
Interpreter
and collect the output (which is the image embedding). -
Then train a new class or continue training an existing class by calling
train()
, which takes the image embedding and a label ID. -
Save the retrained model using
serialize_model()
. For example:
with open(output_path, 'wb') as f:
    f.write(engine.serialize_model())
-
Then use the new model to run inferences with PyCoral and TensorFlow Lite.
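For reference, here's a condensed sketch of that whole flow using the PyCoral utilities. The model path, output path, and the train_images dictionary (mapping label names to image files) are placeholders, and we use pycoral.utils.edgetpu.make_interpreter to create the Edge TPU interpreter; see imprinting_learning.py in the PyCoral examples for the complete script:
from PIL import Image

from pycoral.adapters import classify, common
from pycoral.learn.imprinting.engine import ImprintingEngine
from pycoral.utils.edgetpu import make_interpreter

# Hypothetical inputs: train_images maps each label name to a list of image paths.
model_path = 'mobilenet_v1_1.0_224_l2norm_quant_edgetpu.tflite'
output_path = 'retrained_imprinting_model.tflite'
train_images = {'cat': ['cat1.jpg', 'cat2.jpg'], 'dog': ['dog1.jpg', 'dog2.jpg']}

# keep_classes=False discards the original classes; pass True to keep them.
engine = ImprintingEngine(model_path, keep_classes=False)

# Run the embedding extractor (the compiled base graph) on the Edge TPU.
extractor = make_interpreter(engine.serialize_extractor_model())
extractor.allocate_tensors()
size = common.input_size(extractor)

for class_id, (name, filenames) in enumerate(train_images.items()):
  for filename in filenames:
    image = Image.open(filename).convert('RGB').resize(size, Image.NEAREST)
    common.set_input(extractor, image)
    extractor.invoke()
    embedding = classify.get_scores(extractor)
    engine.train(embedding, class_id=class_id)

# Write out the retrained model, ready for inference with the TF Lite Interpreter.
with open(output_path, 'wb') as f:
  f.write(engine.serialize_model())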
See the next section for a walkthrough with our example code.
Retrain a model on-device with our sample
To show you how this works, we've created a sample script, imprinting_learning.py,
which uses ImprintingEngine
to perform
on-device transfer learning with a given model.
The model you'll retrain with this sample is a modified MobileNet v1 model that's pre-trained to
understand 1,000 classes from the ImageNet dataset. The
ImprintingEngine
API allows you to keep the
original classes learned from pre-training, but in this sample, you'll abandon those and retrain it
to understand just 10 classes (the model retains all the feature extractors from the base
model—we only reset the final classifications).
If you're using the Dev Board, execute the following commands on the board's terminal. If you're using the USB Accelerator, be sure it's connected to the host computer where you'll run these commands.
-
Set up the directory where you'll save all your work:
DEMO_DIR=$HOME/edgetpu/retrain-imprinting
mkdir -p $DEMO_DIR -
Download our pre-trained model (a custom version of MobileNet v1):
wget https://github.com/google-coral/test_data/raw/master/mobilenet_v1_1.0_224_l2norm_quant_edgetpu.tflite -P $DEMO_DIR
This model is pre-trained with 1,000 classes from ImageNet, so the base model has very good feature extractors. But if you want to pre-train this model with your own dataset, see the section below about how to retrain the base MobileNet model (this is not the usual procedure to train MobileNet).
-
Download our sample training dataset (10 classes with about 20 photos each):
wget https://dl.google.com/coral/sample_data/imprinting_data_script.tar.gz -P $DEMO_DIR
tar zxf $DEMO_DIR/imprinting_data_script.tar.gz -C $DEMO_DIR
bash $DEMO_DIR/imprinting_data_script/download_imprinting_test_data.sh $DEMO_DIR
This takes a couple of minutes to download the images (depending on your internet speed).
-
Download and navigate to the sample code that performs retraining:
mkdir coral && cd coral
git clone https://github.com/google-coral/pycoral.git
cd pycoral/examples/ -
Start transfer learning on the Edge TPU:
python3 imprinting_learning.py \
  --model_path ${DEMO_DIR}/mobilenet_v1_1.0_224_l2norm_quant_edgetpu.tflite \
  --data ${DEMO_DIR}/open_image_v4_subset \
  --output ${DEMO_DIR}/retrained_imprinting_model.tflite
This should take 1 - 2 minutes when using our sample dataset. When it's done, the newly trained model is saved at
${DEMO_DIR}/retrained_imprinting_model.tflite
. -
Try the retrained model by running it through the
classify_image.py
script:
# Download a new cat photo from Open Images:
curl -o ${DEMO_DIR}/cat.jpg https://c4.staticflickr.com/4/3685/10013800466_8f2fb8697e_z.jpg
python3 classify_image.py \
  --model ${DEMO_DIR}/retrained_imprinting_model.tflite \
  --label ${DEMO_DIR}/retrained_imprinting_model.txt \
  --input ${DEMO_DIR}/cat.jpg
You should see results such as this:
---------------------------
Cat
Score : 0.9921875
That's it! You've just trained a model with weight imprinting on the Edge TPU.
To repeat this demo with your own dataset, just add a new directory inside the
open_image_v4_subset
directory, and add some photos of a new class (even just 5 - 10 photos
should work). Then repeat steps 5 and 6 to retrain the model and perform an inference.
Retrain the base MobileNet model
The MobileNet model we shared for the above demo was trained with 1,000 classes from ImageNet ILSVRC2012, which results in a model with very good feature extractors for a variety of image classification tasks. However, if you want to fine-tune the base MobileNet model with your own training dataset, you can do so as follows.
Although the training above was accelerated by the Edge TPU, the following retraining of the base MobileNet model cannot run on the Edge TPU, and some of the required tools are not compatible with the Coral Dev Board, so you should perform these steps on a powerful desktop computer.
Requirements:
- Python 3
- TensorFlow stable release (tested with v1.14)
- Bazel 0.26.1
-
Checkpoint for the MobileNet v1 model with quantized weights
You can use the checkpoint from any MobileNet v1 model, as long as all values are quantized. Or use the checkpoint from our pre-trained model in mobilenet_v1_1.0_224_l2norm_quant.tar.gz.
-
Training data in TFRecord format
Note: The .tfrecord files for training must begin with "train-" (for example, train-00000-of-00005.tfrecord) and files for validation must begin with "validation-".
For an example conversion script, see download_and_convert_flowers.py.
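If you want to write your own conversion script instead, the following is a rough sketch of what each record needs. The feature keys follow the slim image-classification convention used by download_and_convert_flowers.py, and the convert_split helper and its arguments are hypothetical:
import tensorflow as tf
from PIL import Image

def _bytes_feature(value):
  return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

def _int64_feature(value):
  return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

def image_to_tfexample(encoded_jpeg, height, width, class_id):
  # Feature keys follow the slim image-classification dataset convention.
  return tf.train.Example(features=tf.train.Features(feature={
      'image/encoded': _bytes_feature(encoded_jpeg),
      'image/format': _bytes_feature(b'jpg'),
      'image/class/label': _int64_feature(class_id),
      'image/height': _int64_feature(height),
      'image/width': _int64_feature(width),
  }))

def convert_split(image_paths_by_class, output_path):
  # output_path must begin with "train-" or "validation-", as noted above.
  with tf.io.TFRecordWriter(output_path) as writer:
    for class_id, paths in enumerate(image_paths_by_class):
      for path in paths:
        with open(path, 'rb') as f:
          encoded = f.read()
        width, height = Image.open(path).size
        writer.write(
            image_to_tfexample(encoded, height, width, class_id).SerializeToString())

# Hypothetical usage: one list of JPEG paths per class, in label order.
# convert_split([['cats/1.jpg', 'cats/2.jpg'], ['dogs/1.jpg']],
#               'train-00000-of-00001.tfrecord')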
Pre-train the model
-
Sync our Git repo that contains the training scripts:
git clone https://github.com/google-coral/imprinting-training
cd imprinting-training
git submodule init && git submodule update
export PYTHONPATH=$(pwd):$(pwd)/models/research/slim -
Build our modified MobileNet v1 model with L2-normalization:
cd classification
bazel build mobilenet_v1_l2norm -
Start the training script with the model checkpoint and dataset (set the variables for your own data paths):
# Location of your TFRecord files
DATASET_DIR=/home/edgetpu/classify/flowers
# Location of your checkpoint (the common path for all .ckpt files)
FINETUNE_CHECKPOINT_PATH=/home/edgetpu/classify/train/model.ckpt-300
# Destination for the training logs
CHECKPOINT_DIR=/home/edgetpu/l2norm-training
python3 mobilenet_v1_l2norm_train.py \
  --quantize=True \
  --dataset_dir=${DATASET_DIR} \
  --fine_tune_checkpoint=${FINETUNE_CHECKPOINT_PATH} \
  --checkpoint_dir=${CHECKPOINT_DIR} \
  --freeze_base_model=True \
  --number_of_steps=100000
-
When you're ready to evaluate the model performance, you can do so as follows:
CHECKPOINT_FILE=$CHECKPOINT_DIR/model.ckpt-654321
python3 mobilenet_v1_l2norm_eval.py \
  --quantize=True \
  --checkpoint_dir=$CHECKPOINT_FILE \
  --dataset_dir=$DATASET_DIR
Now you have a pre-trained MobileNet model (with L2-normalization).
The next section shows how to convert the model for the Edge TPU.
Export the graph for the Edge TPU
-
Save a GraphDef of the model:
# Still inside the imprinting-training/classification/ directory
python3 export_inference_graph_l2norm.py \
  --quantize=True \
  --output_file=$CHECKPOINT_DIR/mobilenet_v1_l2norm_inf_graph.pb
-
Freeze the graph with your new checkpoint:
# Use the tensorflow repo inside imprinting-training/
cd ../tensorflow/
bazel build tensorflow/python/tools:freeze_graph
# Check the output of the build command for the freeze_graph location
freeze_graph \
  --input_graph=$CHECKPOINT_DIR/mobilenet_v1_l2norm_inf_graph.pb \
  --input_checkpoint=$CHECKPOINT_FILE \
  --input_binary=true \
  --output_graph=$CHECKPOINT_DIR/frozen_mobilenet_v1_l2norm.pb \
  --output_node_names=MobilenetV1/Predictions/Reshape_1
-
Now we need to strip out the L2-norm operator that we added to the base graph, because this operation is not supported on the Edge TPU (removing it now has no effect because the new weights are already frozen the way we want them):
bazel build tensorflow/tools/graph_transforms:transform_graph
transform_graph \
  --in_graph=$CHECKPOINT_DIR/frozen_mobilenet_v1_l2norm.pb \
  --out_graph=$CHECKPOINT_DIR/frozen_mobilenet_v1_l2norm_optimized.pb \
  --inputs=input \
  --outputs=MobilenetV1/Predictions/Reshape_1 \
  --transforms='strip_unused_nodes fold_constants'
-
Now run the following commands to separate the model base (the embedding extractor) and the model head (the classification layer) into individual graphs.
-
First convert the entire frozen graph to a TensorFlow Lite file (or else the model head will have the wrong input parameters):
tflite_convert \
  --graph_def_file=$CHECKPOINT_DIR/frozen_mobilenet_v1_l2norm_optimized.pb \
  --output_file=$CHECKPOINT_DIR/mobilenet_v1_l2norm_quant.tflite \
  --inference_type=QUANTIZED_UINT8 \
  --mean_values=128 \
  --std_dev_values=128 \
  --input_arrays=input \
  --output_arrays=MobilenetV1/Predictions/Reshape_1
-
Create the base graph as its own file (using toco because tflite_convert does not support .tflite files as input):
# You should build the following version of toco because the
# packaged version does not support the 'input_file' argument
bazel build tensorflow/lite/toco:toco
toco \
  --input_file=$CHECKPOINT_DIR/mobilenet_v1_l2norm_quant.tflite \
  --output_file=$CHECKPOINT_DIR/mobilenet_v1_embedding_extractor.tflite \
  --input_format=TFLITE \
  --output_format=TFLITE \
  --inference_type=QUANTIZED_UINT8 \
  --input_arrays=input \
  --output_arrays=MobilenetV1/Logits/AvgPool_1a/AvgPool
-
Create the head graph:
toco \
  --input_file=$CHECKPOINT_DIR/mobilenet_v1_l2norm_quant.tflite \
  --output_file=$CHECKPOINT_DIR/mobilenet_v1_last_layers.tflite \
  --input_format=TFLITE \
  --output_format=TFLITE \
  --inference_type=QUANTIZED_UINT8 \
  --input_arrays=MobilenetV1/Logits/AvgPool_1a/AvgPool \
  --output_arrays=MobilenetV1/Predictions/Reshape_1
-
-
Compile the base graph with the Edge TPU Compiler:
edgetpu_compiler -o $CHECKPOINT_DIR \
  $CHECKPOINT_DIR/mobilenet_v1_embedding_extractor.tflite
-
Then re-join the compiled base graph to the head graph using our
join_tflite_models
tool. First clone our repo and build the tools with Docker:
git clone --recurse-submodules https://github.com/google-coral/libcoral.git
cd libcoral
make DOCKER_IMAGE="ubuntu:18.04" DOCKER_CPUS="k8" DOCKER_TARGETS="tools" docker-build
Now run the join tool (using the path in
out/
based on what you built above):
# Your prompt should be in the libcoral/ root:
./out/k8/tools/join_tflite_models \
  --input_graph_base=$CHECKPOINT_DIR/mobilenet_v1_embedding_extractor_edgetpu.tflite \
  --input_graph_head=$CHECKPOINT_DIR/mobilenet_v1_last_layers.tflite \
  --output_graph=$CHECKPOINT_DIR/mobilenet_v1_l2norm_quant_edgetpu.tflite
Now you're done.
You can move the mobilenet_v1_l2norm_quant_edgetpu.tflite
file to your Edge TPU device and
use it with the sample script above or your own code
using the ImprintingEngine
API.
Build a different model for ImprintingEngine
Everything above uses a version of the MobileNet v1 model we created specifically for weight
imprinting, because ImprintingEngine
is not compatible with ordinary classification models or embedding extractors.
Creating a different classification model that's compatible with ImprintingEngine
is possible, but it's a significant
undertaking that demands expert TensorFlow knowledge.
A complete description of the architecture we implemented for our model is beyond the scope of this
document, but you can inspect our implementation in mobilenet_v1_l2norm.py
.
We also suggest you carefully read the research paper that this design is based upon: Low-Shot Learning with Imprinted Weights.