Retrain a classification model on-device with backpropagation
If you're familiar with backpropagation, then you know it's used to train a neural network by updating the weights in every layer after you determine the model's current loss. However, you can also use backpropagation to update the weights of only the last layer, which allows you to retrain your model very quickly. And it's this technique that our SoftmaxRegression API provides so you can accelerate transfer-learning with the Edge TPU.
The SoftmaxRegression API is available in both PyCoral (Python) and libcoral (C++), but this guide describes only the Python API.
Overview
Ordinarily, because a TensorFlow Lite model must be compiled to run on the Edge TPU, the weights inside the neural network are locked and cannot be modified by training on the device. However, if you remove the last layer from the model before compiling it (thus creating an embedding extractor model that outputs an image embedding), then you can implement the last layer on the device in a way that allows for retraining of that layer. So that's what we do to enable transfer-learning with SoftmaxRegression.
The SoftmaxRegression class is an on-device implementation of the fully-connected layer with softmax activation that performs the final classification. With its APIs, you can train the weights of that layer using stochastic gradient descent (SGD), immediately run inferences using the new weights, and save the retrained model as a new .tflite file.
Of course, this strategy has both benefits and drawbacks:
Benefits:
- Transfer-learning happens on-device, at near-realtime speed.
- You don't need to recompile the model.
Drawbacks:
- The fully-connected layer with softmax activation executes on the CPU, not the Edge TPU. However, this layer represents a very small portion of the overall network, so the impact on inference speed is minimal.
- It's compatible with image classification models only—officially, only MobileNet and Inception.
An alternative API for on-device transfer learning is ImprintingEngine, which uses weight imprinting instead of backpropagation to update the weights of the last layer. For a comparison of these two techniques, read Transfer-learning on-device.
API summary
The SoftmaxRegression class represents only the softmax layer for a classification model. Unlike the ImprintingEngine, it does not encapsulate the entire model graph. So in order to perform training, you must run the training data through the base model (the embedding extractor) and then feed the results to this softmax layer.
The basic procedure to train using backpropagation with the SoftmaxRegression API is as follows:
- Create an instance of Interpreter for the Edge TPU, such as with make_interpreter(), specifying your embedding extractor model.
- For each training image, run an inference with the Interpreter and collect the output (which is the image embedding).
- Create an instance of SoftmaxRegression and call train_with_sgd(), passing it all the image embeddings. This is where the new training happens.
- Save the retrained model using serialize_model(), passing it the embedding extractor model. For example:
  with open(output_path, 'wb') as f:
    f.write(softmax_model.serialize_model(embedding_model_path))
- Then use the new model to run inferences with PyCoral and TensorFlow Lite.
See the next section for a walkthrough with our example code.
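Before the full walkthrough, here's a condensed Python sketch of the steps above. This is only a sketch, not the backprop_last_layer.py script: the dataset-loading code is omitted, and train_paths, train_labels, val_paths, and val_labels are hypothetical placeholders (lists of image file paths and integer class IDs); the iteration count and learning rate are arbitrary example values.

# A minimal sketch of the steps above (not the full backprop_last_layer.py script).
# train_paths/val_paths and train_labels/val_labels are hypothetical placeholders.
import numpy as np
from PIL import Image

from pycoral.adapters import classify, common
from pycoral.learn.backprop.softmax_regression import SoftmaxRegression
from pycoral.utils.edgetpu import make_interpreter

def extract_embeddings(image_paths, interpreter):
  """Runs each image through the embedding extractor and returns its embedding."""
  width, height = common.input_size(interpreter)
  feature_dim = classify.num_classes(interpreter)  # Extractor output size = embedding size.
  embeddings = np.empty((len(image_paths), feature_dim), dtype=np.float32)
  for i, path in enumerate(image_paths):
    img = Image.open(path).convert('RGB').resize((width, height), Image.NEAREST)
    common.set_input(interpreter, img)
    interpreter.invoke()
    embeddings[i, :] = classify.get_scores(interpreter)
  return embeddings

# 1. Create an Interpreter for the Edge TPU with the embedding extractor model.
extractor_path = 'mobilenet_v1_1.0_224_quant_embedding_extractor_edgetpu.tflite'
interpreter = make_interpreter(extractor_path)
interpreter.allocate_tensors()

# 2. Run every training and validation image through the extractor to collect embeddings.
data = {
    'data_train': extract_embeddings(train_paths, interpreter),
    'labels_train': np.array(train_labels),
    'data_val': extract_embeddings(val_paths, interpreter),
    'labels_val': np.array(val_labels),
}

# 3. Train the last (softmax) layer on the embeddings using SGD.
softmax_model = SoftmaxRegression(
    feature_dim=data['data_train'].shape[1],
    num_classes=len(set(train_labels)))
softmax_model.train_with_sgd(data, num_iter=500, learning_rate=0.01, batch_size=100)

# 4. Append the learned weights to the extractor and save a complete .tflite model.
with open('retrained_model_edgetpu.tflite', 'wb') as f:
  f.write(softmax_model.serialize_model(extractor_path))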
Retrain a model with our sample code
To better illustrate how you can use the SoftmaxRegression API, we've created a sample script: backprop_last_layer.py.
Follow the procedure below to try it with a flowers dataset.
If you're using the Dev Board, execute these commands on the board's terminal; if you're using the USB Accelerator, be sure it's connected to the host computer where you'll run these commands.
- Set up the directory where you'll save all your work:
  DEMO_DIR=$HOME/edgetpu/retrain-backprop
  mkdir -p $DEMO_DIR
- Download and extract the flowers dataset:
  wget http://download.tensorflow.org/example_images/flower_photos.tgz
  tar zxf flower_photos.tgz -C $DEMO_DIR
- Download our embedding extractor (a version of the neural network without the final fully-connected layer, and pre-trained on ImageNet):
  wget https://github.com/google-coral/test_data/raw/master/mobilenet_v1_1.0_224_quant_embedding_extractor_edgetpu.tflite -P $DEMO_DIR
  We've also created embedding extractors for all three sizes of EfficientNet. You can find the links on our Models page. For example, here's the medium size:
  wget https://github.com/google-coral/test_data/raw/master/efficientnet-edgetpu-M_quant_embedding_extractor_edgetpu.tflite -P $DEMO_DIR
  If you want to use your own model, see the section below about how to create your own embedding extractor.
- Download and navigate to the sample code:
  mkdir coral && cd coral
  git clone https://github.com/google-coral/pycoral.git
  cd pycoral/examples/
- Start transfer learning on the Edge TPU:
  python3 backprop_last_layer.py \
    --data_dir ${DEMO_DIR}/flower_photos \
    --embedding_extractor_path \
      ${DEMO_DIR}/mobilenet_v1_1.0_224_quant_embedding_extractor_edgetpu.tflite \
    --output_dir ${DEMO_DIR}
  This takes 1 - 2 minutes, and you should see training logs printed to the console.
- Verify that the retrained model works by running it through the classify_image.py script. (A Python alternative using PyCoral directly is sketched after these steps.)
  # Download a rose image from Open Images:
  curl -o ${DEMO_DIR}/rose.jpg https://c2.staticflickr.com/4/3062/3067374593_f2963e50b7_o.jpg
  python3 classify_image.py \
    --model ${DEMO_DIR}/retrained_model_edgetpu.tflite \
    --label ${DEMO_DIR}/label_map.txt \
    --input ${DEMO_DIR}/rose.jpg
  You should see results such as this:
  ---------------------------
  roses
  Score : 0.99609375
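If you'd rather call the retrained model from your own Python code instead of the classify_image.py script, the following is a minimal PyCoral sketch; the file names are those produced by the demo above (prepend your $DEMO_DIR as needed).

# A minimal sketch for running the retrained model with PyCoral directly.
# The file names below are the ones produced by the demo above.
from PIL import Image

from pycoral.adapters import classify, common
from pycoral.utils.dataset import read_label_file
from pycoral.utils.edgetpu import make_interpreter

interpreter = make_interpreter('retrained_model_edgetpu.tflite')
interpreter.allocate_tensors()

# Resize the test image to the model's input size and run one inference.
image = Image.open('rose.jpg').convert('RGB').resize(
    common.input_size(interpreter), Image.LANCZOS)
common.set_input(interpreter, image)
interpreter.invoke()

# Print the top result with its human-readable label.
labels = read_label_file('label_map.txt')
for c in classify.get_classes(interpreter, top_k=1):
  print(labels.get(c.id, c.id), c.score)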
Create an embedding extractor
To use this backpropagation technique with your own model, you need to compile your TensorFlow Lite model with its last layer removed. Doing so creates a model called an embedding extractor, which outputs an image embedding (also called a feature embedding tensor).
Separating the embedding extractor allows the last fully-connected layer to be implemented on-device (with SoftmaxRegression) so we can backpropagate new weights. Assuming you've already trained a classification model with one of the supported model architectures, you can follow the steps below to create an embedding extractor from that pre-trained model.
This technique does not work for EfficientNet, but you can instead use the EfficientNet tools to create the embedding extractor, or download a pre-compiled version from our Models page.
- Identify the feature embedding tensor. The feature embedding tensor is the input tensor for the last fully-connected layer. For the classification model architectures we officially support, the following table lists their feature embedding tensor names and feature dimensions.

  | Model name | Feature embedding tensor name | Size |
  | --- | --- | --- |
  | mobilenet_v1_1.0_224_quant | MobilenetV1/Logits/AvgPool_1a/AvgPool | 1024 |
  | mobilenet_v2_1.0_224_quant | MobilenetV2/Logits/AvgPool | 1280 |
  | inception_v1_224_quant | InceptionV1/Logits/AvgPool_0a_7x7/AvgPool | 1024 |
  | inception_v2_224_quant | InceptionV2/Logits/AvgPool_1a_7x7/AvgPool | 1024 |
  | inception_v3_224_quant | InceptionV3/Logits/AvgPool_1a_8x8/AvgPool | 2048 |
  | inception_v4_224_quant | InceptionV4/Logits/AvgPool_1a/AvgPool | 1536 |

  (You can also find the feature embedding tensor name when you visualize your model or list all the layers of your model using tools such as tflite_convert. A Python sketch for listing candidate node names from a frozen graph appears after these steps.)
- Cut off the last fully-connected layer from the pre-trained classification model. Because you'll be changing the weights in the last fully-connected layer, your embedding extractor model is just a new version of the existing model with this last layer removed. You'll remove the layer using the tflite_convert tool, which converts the TensorFlow frozen graph into the TensorFlow Lite format; you just need to specify, as the output array, the input for the last fully-connected layer (the feature embedding tensor).

  For example, the following command creates the embedding extractor from a MobileNet v1 model and saves it as a TensorFlow Lite model:

  # Create embedding extractor from MobileNet v1 classification model
  tflite_convert \
    --output_file=mobilenet_v1_embedding_extractor.tflite \
    --graph_def_file=mobilenet_v1_1.0_224_quant_frozen.pb \
    --inference_type=QUANTIZED_UINT8 \
    --mean_values=128 \
    --std_dev_values=128 \
    --input_arrays=input \
    --output_arrays=MobilenetV1/Logits/AvgPool_1a/AvgPool
- Compile the embedding extractor. You now have a version of the embedding extractor that's compiled for a CPU, so you need to recompile it for the Edge TPU using the Edge TPU Compiler. (This is no different than compiling a full classification model.)
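If you're not sure of the feature embedding tensor name for your own model (step 1 above), one option is to load the frozen graph and print candidate node names. The following is a minimal sketch assuming a TF1-style frozen graph file such as mobilenet_v1_1.0_224_quant_frozen.pb; the filter strings are just examples.

# A minimal sketch for listing node names in a TF1-style frozen graph so you
# can locate the feature embedding tensor (typically the AvgPool op that feeds
# the final fully-connected layer). The .pb file name is an example.
import tensorflow.compat.v1 as tf

graph_def = tf.GraphDef()
with open('mobilenet_v1_1.0_224_quant_frozen.pb', 'rb') as f:
  graph_def.ParseFromString(f.read())

# Print ops whose names suggest they're near the classification head.
for node in graph_def.node:
  if 'AvgPool' in node.name or 'Logits' in node.name:
    print(node.op, node.name)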
Now just follow the procedures described in the API summary to perform training, or pass your model to the backprop_last_layer.py demo script.