An example script for working with Zetane and Keras. In this script we will:

  1. Build a convolutional neural network (CNN) in Keras and train it,

  2. Send the resulting model to the Zetane Engine, and

  3. Test the model in the Zetane Engine with sample images.

from __future__ import print_function
import sys
import os
import time
os.environ['TF_KERAS'] = '1'
import tensorflow.keras as keras
from tensorflow.keras.datasets import mnist
from tensorflow.keras.models import Sequential, model_from_json
from tensorflow.keras.layers import Dense, Dropout, Flatten
from tensorflow.keras.layers import Conv2D, MaxPooling2D
from tensorflow.keras import backend as K
import tensorflow as tf
physical_devices = tf.config.experimental.list_physical_devices('GPU')
for physical_device in physical_devices:
    tf.config.experimental.set_memory_growth(physical_device, True)

import numpy as np
from PIL import Image

sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../..')))
import zetane.context as ztn

Here’s our function to create a basic CNN: two convolutional layers, a max-pooling layer, and two fully connected layers, the last of which uses a softmax activation. Feel free to experiment with different architectures, but keep in mind that MNIST is a simple dataset with only 10 output classes, so making the model more complex will not necessarily yield better results.

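def create_model():
  model = Sequential()
  model.add(Conv2D(32, kernel_size=(3, 3),
                   activation='relu',
                   input_shape=input_shape))
  model.add(Conv2D(64, (3, 3), activation='relu'))
  model.add(MaxPooling2D(pool_size=(2, 2)))
  model.add(Dropout(0.25))
  # collapse the 2D feature maps into a vector for the dense layers
  model.add(Flatten())
  model.add(Dense(128, activation='relu'))
  model.add(Dropout(0.5))
  model.add(Dense(num_classes, activation='softmax'))

  # targets are one-hot encoded, so use categorical cross-entropy
  model.compile(loss=keras.losses.categorical_crossentropy,
                optimizer='adam',
                metrics=['accuracy'])

  # serialize model to JSON
  model_json = model.to_json()
  with open(os.path.join(dir_path, "model.json"), "w") as json_file:
    json_file.write(model_json)
  return model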
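As a sanity check on the architecture, the feature-map sizes and parameter counts can be worked out by hand. This sketch assumes Keras defaults (padding='valid' and stride 1 for the convolutions, pooling stride equal to the pool size), a single-channel 28x28 input, and a Flatten layer between the pooling layer and the first Dense layer; Dropout layers add no parameters. The helpers conv_out and pool_out are illustrative only.

```python
def conv_out(size, kernel, stride=1):
    """Output length of a 'valid' (unpadded) convolution along one axis."""
    return (size - kernel) // stride + 1


def pool_out(size, pool):
    """Output length of non-overlapping max pooling (stride == pool size)."""
    return size // pool


h = w = 28                              # MNIST image height and width
h, w = conv_out(h, 3), conv_out(w, 3)   # Conv2D(32, 3x3)   -> 26 x 26 x 32
h, w = conv_out(h, 3), conv_out(w, 3)   # Conv2D(64, 3x3)   -> 24 x 24 x 64
h, w = pool_out(h, 2), pool_out(w, 2)   # MaxPooling2D(2x2) -> 12 x 12 x 64
flat = h * w * 64                       # Flatten           -> 9216 features

# Trainable parameters per layer: weights + biases
conv1 = 3 * 3 * 1 * 32 + 32     # 320
conv2 = 3 * 3 * 32 * 64 + 64    # 18496
dense1 = flat * 128 + 128       # 1179776
dense2 = 128 * 10 + 10          # 1290
total = conv1 + conv2 + dense1 + dense2

print(h, w, flat, total)        # 12 12 9216 1199882
```

If the assumptions above hold, model.summary() should report the same total; note that almost all of the parameters sit in the first Dense layer.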

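def train(model, epochs=5, batch_size=64):
    model.fit(x_train, y_train,
              batch_size=batch_size,
              epochs=epochs,
              verbose=1,
              validation_data=(x_test, y_test))
    # save the trained weights so later runs can reload them
    model.save_weights(os.path.join(dir_path, "model_weights_trained.h5"))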

Now, the actual inference loop, where we feed the test examples to the model one at a time. Because each example arrives on its own, we need to add a batch dimension so that the input is compatible with our CNN, which expects data in batches; hence the tf.expand_dims(x, 0) call. We first run inference in Keras to obtain the predicted class pred, then update the ztxt_target and ztxt_pred text objects to show the target and predicted classes.

We can then send the same input to the Zetane model and inspect its intermediate layers in Zetane by calling zmodel.update(inputs=x_ext), where inputs may be a NumPy array or a file path to a .npy or .npz file.

def predict_in_zetane(model, test_dataset):
    for x, y in test_dataset:
        x_ext = tf.expand_dims(x, 0).numpy()
        y_out = model.predict(x_ext, steps=1)

        target = tf.argmax(y).numpy()
        pred = tf.argmax(y_out[0]).numpy()

        ztxt_target.text("Target class: " + str(target)).update()
        ztxt_pred.text("Predicted class: " + str(pred)).update()

        # Send inputs to viz
        zmodel.update(inputs=x_ext)
        # pause so each example stays on screen long enough to inspect
        time.sleep(1)
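The batch-dimension and argmax logic in the loop can be illustrated with NumPy alone. In this sketch the image, label and score vector are synthetic placeholders (not MNIST data or real model output), and np.expand_dims stands in for tf.expand_dims.

```python
import numpy as np

# Synthetic stand-ins for one preprocessed image (channels_last) and its
# one-hot label; these are not real MNIST values.
x = np.zeros((28, 28, 1), dtype=np.float32)
y = np.eye(10, dtype=np.float32)[7]        # one-hot label for class 7

# Add a leading batch axis, mirroring tf.expand_dims(x, 0).
x_ext = np.expand_dims(x, 0)
print(x_ext.shape)                         # (1, 28, 28, 1)

# predict() returns one row of class scores per batch item, so the output
# has shape (1, 10) and y_out[0] is the score vector for our single image.
y_out = np.array([[0.01, 0.02, 0.05, 0.02, 0.01,
                   0.03, 0.01, 0.80, 0.03, 0.02]])
target = int(np.argmax(y))                 # index of the 1 in the one-hot label
pred = int(np.argmax(y_out[0]))            # index of the highest score
print(target, pred)                        # 7 7
```

Indexing y_out[0] before argmax matters: taking argmax over the whole (1, 10) array without an axis would still work here, but only by accident of the batch size being 1.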

Here we define some global variables we’ll refer to throughout the script. The load variable specifies whether to load a previously trained model instead of training one from scratch. Don’t worry if you don’t have a model ready: the script will train one from scratch if it can’t find one.

dir_path = os.path.dirname(os.path.realpath(__file__))
batch_size = 128
num_classes = 10
epochs = 5
load = True
# input image dimensions
img_rows, img_cols = 28, 28
# the data, split between train and test sets
(x_train, y_train), (x_test, y_test) = mnist.load_data()

Note the if/else on K.image_data_format(): it determines whether the channels dimension comes first or last after reshaping, so that the data matches the layout the Keras backend expects. Afterwards, the data is converted to floats and normalized to the range [0, 1], and the target values are one-hot encoded.

if K.image_data_format() == 'channels_first':
    x_train = x_train.reshape(x_train.shape[0], 1, img_rows, img_cols)
    x_test = x_test.reshape(x_test.shape[0], 1, img_rows, img_cols)
    input_shape = (1, img_rows, img_cols)
else:
    x_train = x_train.reshape(x_train.shape[0], img_rows, img_cols, 1)
    x_test = x_test.reshape(x_test.shape[0], img_rows, img_cols, 1)
    input_shape = (img_rows, img_cols, 1)

x_train = x_train.astype('float32')
x_test = x_test.astype('float32')
x_train /= 255
x_test /= 255
print('x_train shape:', x_train.shape)
print(x_train.shape[0], 'train samples')
print(x_test.shape[0], 'test samples')

# convert class vectors to binary class matrices
y_train = keras.utils.to_categorical(y_train, num_classes)
y_test = keras.utils.to_categorical(y_test, num_classes)
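The preprocessing steps above can be condensed into a single NumPy function, sketched here on synthetic data. The preprocess function is a hypothetical helper, and indexing an identity matrix with np.eye is used in place of keras.utils.to_categorical; for integer labels it should produce the same one-hot matrix.

```python
import numpy as np


def preprocess(images, labels, num_classes=10, channels_first=False):
    """Add a channel axis, scale pixels to [0, 1], and one-hot encode labels."""
    n, rows, cols = images.shape
    if channels_first:
        images = images.reshape(n, 1, rows, cols)
    else:
        images = images.reshape(n, rows, cols, 1)
    images = images.astype('float32') / 255
    # Row i of the identity matrix is the one-hot vector for class i.
    one_hot = np.eye(num_classes, dtype='float32')[labels]
    return images, one_hot


# Synthetic uint8 "images" (all white) and integer labels, not real MNIST.
imgs = np.full((3, 28, 28), 255, dtype=np.uint8)
labels = np.array([0, 3, 9])

x, y = preprocess(imgs, labels)
print(x.shape, x.max(), y.shape)   # (3, 28, 28, 1) 1.0 (3, 10)
print(y[1])                        # one-hot vector with a 1 at index 3
```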

if load:
    try:
        # load json and create model
        with open(os.path.join(dir_path, 'model.json'), 'r') as json_file:
            loaded_model_json = json_file.read()
        model = model_from_json(loaded_model_json)
        # load weights into new model
        model.load_weights(os.path.join(dir_path, "model_weights_trained.h5"))
        print("Loaded model from disk")
    except OSError:
        # model.json or the weights file is missing: train from scratch
        model = create_model()
        train(model, epochs, batch_size)
else:
    model = create_model()
    train(model, epochs, batch_size)
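The script falls back to training when it can’t find a saved model. One way to make that decision explicit is to check for both files before loading anything, so that unrelated errors are not silently treated as "no model". The sketch below shows that pattern with a hypothetical load_or_build helper; the callables passed to it are stand-ins for the real Keras loading and training code.

```python
import os
import tempfile


def load_or_build(json_path, weights_path, load_fn, build_fn, load=True):
    """Call load_fn if loading is enabled and both files exist, else build_fn.

    load_fn(json_path, weights_path) and build_fn() are supplied by the caller;
    this helper only decides which one to run.
    """
    if load and os.path.isfile(json_path) and os.path.isfile(weights_path):
        return load_fn(json_path, weights_path)
    return build_fn()


def fake_load(json_path, weights_path):
    return "loaded"      # stand-in for model_from_json + load_weights


def fake_build():
    return "trained"     # stand-in for create_model + train


with tempfile.TemporaryDirectory() as d:
    json_path = os.path.join(d, "model.json")
    weights_path = os.path.join(d, "model_weights_trained.h5")

    before = load_or_build(json_path, weights_path, fake_load, fake_build)
    # Create empty placeholder files so both paths now exist.
    for path in (json_path, weights_path):
        open(path, "w").close()
    after = load_or_build(json_path, weights_path, fake_load, fake_build)
    forced = load_or_build(json_path, weights_path, fake_load, fake_build,
                           load=False)

print(before, after, forced)   # trained loaded trained
```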

We start by creating a ztn.Context() object. Any data sent to Zetane from Python goes through this zcontext; strings, for example, are sent via zcontext.text(). We can then modify the look of this text via calls like .position(), .font() and .gradient().

zcontext = ztn.Context()
zmodel = zcontext.model()

Let’s also create a zimg = zcontext.image() object here. It is not attached to any data yet, but we will use this zimg object in the predict function to visualize our input images in Zetane.

zimg = zcontext.image().position(-1.0, 1.5, 0.0)

ztxt_target = zcontext.text("Target class: ").position(-1.5, -1, 1).font('roboto-mono').font_size(0.16).billboard(True) \
    .color((1, 0, 1)).highlight((0.5, 0.5, 1))
ztxt_pred = zcontext.text("Predicted class: ").position(-1.5, -1.5, 1).font('roboto-mono').font_size(0.16).billboard(True) \
    .color((1, 0, 1)).highlight((0.5, 0.5, 1))

test_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test))
predict_in_zetane(model, test_dataset)
