Preprocess Keras Model for TensorSpace

Written by zchholmes | Published 2018/11/17

How to preprocess a Keras model to be TensorSpace-compatible for 3D neural network visualization

“TensorSpace is a neural network 3D visualization framework. — TensorSpace.org”

“Keras is a high-level neural network API. — keras.io ”

Introduction

You may have learned that TensorSpace can be used to visualize neural networks in 3D. You might have read my previous introduction to TensorSpace. Maybe you found the model preprocessing a little complicated.

So today I want to cover model preprocessing for TensorSpace in more detail: specifically, how to preprocess a deep learning model built with Keras so that it is TensorSpace-compatible.

Fig. 1 — Use TensorSpace to visualize an LeNet built by Keras

What should we have?

To make a model built with Keras TensorSpace-compatible, the model must satisfy two key requirements:

  • Support multiple outputs from intermediate layers.
  • Support a TensorSpace-compatible format that is browser-friendly.

In the following sections, I will use a LeNet as an example to walk through the workflow for preprocessing a Keras model.

To follow along, I will assume we have a proper Python environment set up:

  • Python 3.6
  • keras and numpy installed and importable
  • tfjs-converter installed
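
As a quick sanity check of this environment, a minimal sketch like the following reports which packages cannot be found. The helper name `missing_packages` is hypothetical, and the sketch assumes tfjs-converter was installed as the `tensorflowjs` pip package, which is importable under that name.

```python
import importlib.util
import sys

def missing_packages(names):
    """Return the subset of top-level package `names` that cannot be located."""
    return [name for name in names if importlib.util.find_spec(name) is None]

if __name__ == "__main__":
    # Assumption: tfjs-converter is provided by the `tensorflowjs` package.
    required = ["keras", "numpy", "tensorflowjs"]
    print("Python version:", sys.version.split()[0])
    print("Missing packages:", missing_packages(required) or "none")
```

If the list is not empty, install the missing packages before continuing.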

You can find all the resources for this example in the TensorSpace preprocess Keras directory.

The source file for training a sample LeNet model is keras_model.py. You can also use your own model, as long as it is valid and works properly with sample input data.

Before preprocessing, what do we get from the model?

Before preprocessing, the pre-trained LeNet model is effectively a black box: it is fed a 28x28 image and returns a list of 10 doubles. Each double represents the probability of a digit from ‘0’ to ‘9’.

Fig. 2 — Classic pre-trained model with single output

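We can load the model and run a prediction like this:

from keras.models import load_model
import numpy as np

# Load the pre-trained single-output LeNet
model = load_model("/PATH/TO/OUTPUT/keras_model.h5")
# Mock a random 28x28 input image, then add the batch dimension
input_sample = np.ndarray(shape=(28,28), buffer=np.random.rand(28,28))
input_sample = np.expand_dims(input_sample, axis=0)
print(model.predict(input_sample))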

We use a 28x28 numpy array to mock a random input image. The sample output looks like this:

Fig. 3 — Single list prediction output from trained model
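
To read a prediction like this, pick the index with the highest probability. The sketch below uses a made-up probability list rather than real model output, and `predicted_digit` is a hypothetical helper name.

```python
def predicted_digit(probabilities):
    """Return (digit, probability) for the most likely class.

    `probabilities` holds the 10 values for one image, i.e. one row
    of what model.predict returns.
    """
    if len(probabilities) != 10:
        raise ValueError("expected 10 class probabilities")
    digit = max(range(10), key=lambda i: probabilities[i])
    return digit, probabilities[digit]

# Made-up values for illustration only, not actual model output.
sample = [0.01, 0.02, 0.05, 0.70, 0.02, 0.05, 0.05, 0.04, 0.03, 0.03]
digit, prob = predicted_digit(sample)
print(f"predicted digit: {digit} (p={prob:.2f})")
```

With a real model, the row to pass in would be `model.predict(input_sample)[0]`.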

If the model is well trained, it is good enough to use in an application: we just need to add an input pad for handwriting and a suitable output function to show the predictions.

However, like “magic”, the model makes it hard for people to follow how a prediction is produced, since there is a gap between the input and the output. Preprocessing is a way to expose some of the “magic” hidden in that gap.

What is the Preprocess for TensorSpace?

Preprocessing a model for TensorSpace involves three steps:

  • Detect necessary data
  • Extract intermediate outputs from hidden layers
  • Convert to TensorSpace compatible model formats

This series of actions should be completed before applying the model to the TensorSpace framework. Preprocessing is how we satisfy TensorSpace's basic requirements: the data gathered from the pre-trained deep learning model is used to render the TensorSpace visualization model.

How to preprocess a Keras model?

Calling model.summary() makes it easy to check the information for each layer.

Fig. 4 — Model summary and layer names

Here, for example, we want to collect all the layer names.

output_layer_names = [
    "Conv2D_1", "MaxPooling2D_1", "Conv2D_2", "MaxPooling2D_2",
    "Dense_1", "Dense_2", "Softmax"
]
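
The names in this list were copied by hand from the summary. As a hedged alternative, they can be gathered from `model.layers`, where each Keras layer has a `name` attribute. The helper `collect_layer_names` and its `skip_types` default are hypothetical choices for this sketch: it assumes layers such as Flatten or Reshape are not wanted in the visualization, so adjust the filter to your own model.

```python
def collect_layer_names(model, skip_types=("InputLayer", "Flatten", "Reshape")):
    """Return the names of the model's layers, skipping the given layer classes.

    Works with any object exposing `.layers` whose items have `.name`,
    as Keras models do.
    """
    return [
        layer.name
        for layer in model.layers
        if type(layer).__name__ not in skip_types
    ]

# Usage with a loaded Keras model (left as a comment because it needs Keras):
# output_layer_names = collect_layer_names(model)
```

Compare the result with the model summary before using it, since the order and the names must match the layers you want TensorSpace to show.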

Next, we want to construct a new model based on the original model and the layer names we just collected.

from keras.models import Model

def generate_encapsulate_model_with_output_layer_names(model, output_layer_names):
    # Same input as the original model; one output per named layer
    enc_model = Model(
        inputs=model.input,
        outputs=list(map(lambda oln: model.get_layer(oln).output, output_layer_names)))
    return enc_model

enc_model = generate_encapsulate_model_with_output_layer_names(model,output_layer_names)

enc_model.save("/PATH/TO/ENC_MODEL/enc_keras_model.h5")

Then we can use the new encapsulated model to check that its last output is still a valid prediction, as before.

input_sample = np.ndarray(shape=(28,28), buffer=np.random.rand(28,28))
input_sample = np.expand_dims(input_sample, axis=0)
print(enc_model.predict(input_sample))

The encapsulated model returns a list of outputs, one for each intermediate layer named in our “output_layer_names” list.

Fig. 5 — Multiple list outputs after preprocessing
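
The raw printed list is hard to read. A small sketch that pairs each layer name with the shape of its output can help, assuming `enc_model.predict` returned one array per name, in the same order. The helper `output_shapes` is hypothetical.

```python
def output_shapes(layer_names, outputs):
    """Map each layer name to the shape of its corresponding output array."""
    if len(layer_names) != len(outputs):
        raise ValueError("expected one output per layer name")
    return {name: tuple(out.shape) for name, out in zip(layer_names, outputs)}

# Usage with the encapsulated model (left as a comment because it needs Keras):
# outputs = enc_model.predict(input_sample)
# for name, shape in output_shapes(output_layer_names, outputs).items():
#     print(name, shape)
```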

The last output is a list of 10 doubles, which is the original output of the LeNet model.

Fig. 6 — Last list output is the same as the original inferences
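
This check can also be done in code. Because the earlier snippets draw a new random input each time, the comparison only makes sense when both models receive the same input. A minimal sketch, with `same_prediction` as a hypothetical helper and an arbitrarily chosen tolerance:

```python
import math

def same_prediction(original, encapsulated_last, tol=1e-6):
    """Return True if two probability vectors match element by element."""
    if len(original) != len(encapsulated_last):
        return False
    return all(
        math.isclose(a, b, abs_tol=tol)
        for a, b in zip(original, encapsulated_last)
    )

# Usage (left as a comment because it needs Keras); note the shared input:
# original = model.predict(input_sample)[0]
# last = enc_model.predict(input_sample)[-1][0]
# print(same_prediction(original, last))
```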

Finally, we can use tfjs-converter to convert the Keras model into a TensorFlow.js model, which TensorSpace can use directly.

tensorflowjs_converter \
    --input_format=keras \
    /PATH/TO/ENC_MODEL/enc_keras_model.h5 \
    /PATH/TO/OUTPUT_DIR/
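
After the command finishes, the output directory should contain a `model.json` file plus one or more binary weight files. The sketch below checks that every weight file listed in `model.json` is present. It assumes the `weightsManifest` / `paths` layout that tfjs-converter writes for Keras models; the helper name `missing_weight_files` is hypothetical.

```python
import json
import os

def missing_weight_files(output_dir):
    """Return weight files referenced by model.json that are not on disk."""
    with open(os.path.join(output_dir, "model.json")) as f:
        model_json = json.load(f)
    missing = []
    # Assumed layout: a list of groups, each with a "paths" list of file names.
    for group in model_json.get("weightsManifest", []):
        for path in group.get("paths", []):
            if not os.path.isfile(os.path.join(output_dir, path)):
                missing.append(path)
    return missing

# Usage:
# print(missing_weight_files("/PATH/TO/OUTPUT_DIR/"))
```

An empty list means all referenced weight files were found, so the directory can be served to the browser as a whole.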

After preprocessing, what do we expect from the encapsulated model?

After preprocessing, we should have a model that:

  • contains everything from the original model
  • has the ability to provide data outputs from intermediate layers
  • is in a TensorSpace compatible format

TensorSpace can collect the data from the intermediate layers and use it to render visualization objects in the TensorSpace model.

Fig. 7 — TensorSpace compatible model with intermediate outputs

Finally, we can apply our preprocessed model to TensorSpace for visualization!

Conclusion

Preprocessing for TensorSpace is an important step before applying the TensorSpace API. It gathers the intermediate data needed to render the 3D visualization.

Now we can collect more data from the deep learning model. The next step is to use and analyze the data from the intermediate layers wisely. Data visualization is one way to observe that data, and it may require tools such as TensorSpace.


Published by HackerNoon on 2018/11/17