PyTorch Fashion MNIST

This script provides a basic example of how the Zetane Python API and the Zetane Engine work together, using PyTorch and the FashionMNIST dataset. The script involves:

  • Building a convolutional neural network (CNN) in PyTorch and training it,

  • Sending the resulting model to the Zetane Engine,

  • Training and testing the model while visualizing sample images in the Zetane Engine, and

  • Producing Grad-CAM images during training and inference in the engine.

# imports
import numpy as np

import torch
import torchvision
import torchvision.transforms as transforms

import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.onnx
import sys
import os
import math
from PIL import Image
import time

sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../..')))
from zetane.explain.torch import gradcam
from zetane.explain.torch import preprocess_image, get_layers
from zetane.utils import remap
import zetane.context as ztn

Creating the Model

Here’s our CNN class: two convolutional layers, each followed by 2×2 max pooling, and three fully connected layers. Feel free to experiment with different architectures, but keep in mind that FashionMNIST is a simple dataset with only 10 output classes, so increasing the complexity of the model will not necessarily yield better results.
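The input size of fc1 in the class below, 32 * 4 * 4, follows from the spatial arithmetic of the layers: a 5×5 convolution with stride 1 and no padding shrinks each side by 4, and a 2×2 max pool with stride 2 halves it. A small stdlib-only sketch that traces a 28×28 FashionMNIST image through the layers (conv2d_out and trace_shapes are hypothetical helpers, not PyTorch functions):

```python
def conv2d_out(size, kernel, stride=1, padding=0):
    """Output side length of a square Conv2d/MaxPool2d layer (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1


def trace_shapes(side=28):
    """Follow a square single-channel image through the Net layers."""
    shapes = [("input", 1, side)]
    side = conv2d_out(side, kernel=5)            # conv1: 1 -> 6 channels
    shapes.append(("conv1", 6, side))
    side = conv2d_out(side, kernel=2, stride=2)  # pool1
    shapes.append(("pool1", 6, side))
    side = conv2d_out(side, kernel=5)            # conv2: 6 -> 32 channels
    shapes.append(("conv2", 32, side))
    side = conv2d_out(side, kernel=2, stride=2)  # pool2
    shapes.append(("pool2", 32, side))
    return shapes


for name, channels, side in trace_shapes():
    print(f"{name:6s} {channels:2d} x {side:2d} x {side:2d}")

_, channels, side = trace_shapes()[-1]
print("flattened features:", channels * side * side)  # 512 = 32 * 4 * 4
```

If you change the kernel sizes or add layers, rerunning this arithmetic tells you the new value to use in both fc1 and the x.view(...) call in forward().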

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.pool1 = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 32, 5)
        self.pool2 = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(32 * 4 * 4, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool1(F.relu(self.conv1(x)))
        x = self.pool2(F.relu(self.conv2(x)))
        x = x.view(-1, 32 * 4 * 4)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

# Function to produce GradCAM heatmap images for predicted classes
def run_gradcam(net, out_class, class_dict, prep_img):
    temp_g = gradcam(net, out_class, class_dict, prep_img)

def evaluation(net, dataloader):
    total, correct = 0, 0

    # keeping the network in evaluation mode, without tracking gradients
    net.eval()
    with torch.no_grad():
        for data in dataloader:
            inputs, labels = data
            # moving the inputs and labels to the GPU, if one is available
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = net(inputs)
            _, pred = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (pred == labels).sum().item()
    return 100 * correct / total
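evaluation() reports top-1 accuracy as a percentage: torch.max(outputs, 1) returns the maximum score along the class dimension together with its index, and that index is compared with the label. The same arithmetic in plain Python, on made-up logits used purely for illustration (top1_accuracy is a hypothetical helper):

```python
def top1_accuracy(batches):
    """Percentage of examples whose highest-scoring class equals the label.

    `batches` is an iterable of (logits, labels) pairs, where `logits` is a
    list of per-example score lists, mirroring what evaluation() iterates over.
    """
    total, correct = 0, 0
    for logits, labels in batches:
        for scores, label in zip(logits, labels):
            # index of the maximum score, like torch.max(outputs, 1)[1]
            pred = max(range(len(scores)), key=scores.__getitem__)
            total += 1
            correct += int(pred == label)
    return 100 * correct / total


# two toy "batches" of 3-class logits (illustrative values only)
batches = [
    ([[2.0, 0.1, -1.0], [0.3, 0.2, 1.5]], [0, 2]),  # both correct
    ([[0.0, 1.0, 0.5], [1.2, 0.9, 0.8]], [2, 0]),   # one correct
]
print(top1_accuracy(batches))  # 75.0
```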

Training in PyTorch

Our training loop is a traditional PyTorch training loop:

def train(net, criterion, optimizer, epochs=5, visual_training=False):
    loss_arr = []
    loss_epoch_arr = []
    running_loss = 0.0

    # The training loop - this is where the fun begins...
    for epoch in range(epochs):
        net.train()  # make sure the network is in training mode
        total, correct = 0, 0
        ztxt_epoch.text("Epoch: " + str(epoch + 1)).update()

        for i, data in enumerate(trainloader, 0):
            # get the inputs; data is a list of [inputs, labels]
            inputs, labels = data
            inputs, labels = inputs.to(device), labels.to(device)

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

For each mini-batch (a single example here, since trainloader uses batch_size=1), we zero the optimizer gradients, call net(inputs), and compute the loss via criterion(outputs, labels). We then perform a backward pass on the loss to compute gradients and step the optimizer. Additionally, we record some variables to display metrics in the engine. If we want to display the inputs and outputs in Zetane, a few additional lines take care of that; the visual_training argument must be set to True to enable them.
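The criterion used here, nn.CrossEntropyLoss(), applies a log-softmax to the raw network outputs and takes the negative log-probability of the correct class, averaged over the batch with its default reduction='mean'. A plain-Python sketch of that computation (cross_entropy is a hypothetical helper, not a PyTorch function):

```python
import math


def cross_entropy(logits, labels):
    """Mean cross-entropy over a batch, computed from raw logits.

    Mirrors nn.CrossEntropyLoss() with reduction='mean':
    log-softmax followed by negative log-likelihood.
    """
    losses = []
    for scores, label in zip(logits, labels):
        m = max(scores)  # subtract the max for numerical stability
        log_sum_exp = m + math.log(sum(math.exp(s - m) for s in scores))
        losses.append(log_sum_exp - scores[label])  # -log softmax[label]
    return sum(losses) / len(losses)


# equal logits over 10 classes give a loss of ln(10)
print(round(cross_entropy([[0.0] * 10], [3]), 4))  # 2.3026
```

This also gives a quick sanity check: with 10 classes, equal scores for every class give ln 10 ≈ 2.30, so the loss of a freshly initialized network typically starts somewhere near that value.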

            target = labels.cpu().numpy()[0]
            pred = torch.argmax(F.softmax(outputs, dim=1)).detach().cpu().numpy()
            total += labels.size(0)
            if pred == target:
                correct += 1

            running_loss += loss.item()
            # let's display the whole pipeline every 100 mini-batches...
            if visual_training and i % 100 == 1:
                acc = correct / total

                ztxt_target.text("Target class: " + classes[target] + " - " + str(target)).update()
                ztxt_pred.text("Predicted class: " + classes[pred] + " - " + str(pred)).update()

                mean = [0.5, 0.5, 0.5]
                std = [0.5, 0.5, 0.5]

                np_img = inputs.cpu().numpy()
                normalized = remap(np_img, (-1.0, 1.0))

                zonnx.torch(net, inputs).update()

                #img = inputs[0, 0].cpu().numpy().astype(np.uint8)
                #img = np.array([img, img, img]).transpose(1, 2, 0)
                #prep_img = preprocess_image(img, mean, std, size=(32, 32), resize_im=False)
                #run_gradcam(net, pred, classes, prep_img)

We run the PyTorch forward pass on each example to get the predicted class, and update the ztxt_target and ztxt_pred text objects to show the target and predicted classes (the running accuracy is also computed as acc). Next, we send our model to Zetane via zonnx.torch(net, inputs).update(). This converts the PyTorch model to an ONNX model that is renderable in Zetane, and also runs the inputs through the model to produce the intermediate visualizations. Finally, the commented-out lines build a three-channel copy of the input image and call run_gradcam(net, pred, classes, prep_img), which uses Zetane's gradcam helper to produce Grad-CAM heatmaps from the model and the input image and send them to the engine; uncomment them to enable it. For more information on Grad-CAM, see the original paper (Selvaraju et al.) or this explanatory introduction.
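For reference, the core Grad-CAM computation from the original paper weights each feature map of a convolutional layer by the spatial average of the target-class gradient flowing into it, sums the weighted maps, and clips negative values with a ReLU. The sketch below does this in plain Python on toy values; it is not Zetane's gradcam implementation, and the upsampling to input resolution and colour mapping that a full implementation performs are omitted (grad_cam is a hypothetical helper):

```python
def grad_cam(activations, gradients):
    """Grad-CAM heatmap from one layer's activations and gradients.

    activations, gradients: lists of K feature maps, each an H x W nested list,
    where gradients[k][i][j] is d(target-class score) / d(activations[k][i][j]).
    Returns an H x W map scaled to [0, 1] (or all zeros).
    """
    # channel weights: global-average-pool the gradients of each feature map
    weights = [
        sum(sum(row) for row in grad) / (len(grad) * len(grad[0]))
        for grad in gradients
    ]
    height, width = len(activations[0]), len(activations[0][0])
    # weighted sum of the feature maps, followed by a ReLU
    cam = [
        [max(0.0, sum(w * fmap[i][j] for w, fmap in zip(weights, activations)))
         for j in range(width)]
        for i in range(height)
    ]
    peak = max(max(row) for row in cam)
    # normalize to [0, 1] for display; leave an all-zero map unchanged
    return [[v / peak for v in row] for row in cam] if peak > 0 else cam


# two toy 2x2 feature maps: the first supports the class, the second opposes it
activations = [[[1.0, 0.0], [0.0, 2.0]], [[0.0, 3.0], [1.0, 0.0]]]
gradients = [[[1.0, 1.0], [1.0, 1.0]], [[-1.0, -1.0], [-1.0, -1.0]]]
print(grad_cam(activations, gradients))  # [[0.5, 0.0], [0.0, 1.0]]
```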

        running_loss = 0.0
        print('Epoch: %d/%d, Test acc: %0.2f, Train acc: %0.2f' %
              (epoch + 1, epochs, evaluation(net, testloader_batch), evaluation(net, trainloader_batch)))

    print('Finished Training')

Finally, we have testing! The inference loop is almost identical to the training loop, except that, naturally, we do not perform a backward pass or an optimizer step. The inference process is visualized the same way as training:

def predict_in_zetane(net, criterion):
    running_loss = 0.0
    total, correct = 0, 0
    for i, data in enumerate(testloader, 0):
        # get the inputs; data is a list of [inputs, labels]
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)

        # set the model to eval mode
        net.eval()
        if i == 0:
            zonnx.torch(net, inputs).update()

        # get predictions
        outputs = net(inputs)
        loss = criterion(outputs, labels)

        target = labels.cpu().numpy()[0]
        pred = torch.argmax(F.softmax(outputs, dim=1)).detach().cpu().numpy()
        total += labels.size(0)
        if (pred == target):
            correct += 1

        running_loss += loss.item()
        # let's display the results every 10 examples...
        if i % 10 == 1:
            acc = (correct / total) * 100

            ztxt_target.text("Target class: " + classes[target] + "- " + str(target)).update()
            ztxt_pred.text("Predicted class: " + classes[pred] + "- " + str(pred)).update()

            #run_gradcam(net, pred, classes, prep_img)  # prep_img built as in train()
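Both loops compute pred as torch.argmax(F.softmax(outputs, dim=1)). Softmax preserves the ordering of the scores (it applies exp, which is increasing, and divides every entry by the same sum), so it never changes which class scores highest; taking the argmax of the raw outputs gives the same prediction, and the softmax only matters when the probabilities themselves are needed. A plain-Python check on arbitrary example scores:

```python
import math


def softmax(scores):
    m = max(scores)  # shift by the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def argmax(values):
    return max(range(len(values)), key=values.__getitem__)


logits = [1.3, -0.2, 4.1, 0.0, 2.7]
probs = softmax(logits)
print(argmax(logits), argmax(probs))  # 2 2
print(round(sum(probs), 6))           # 1.0
```

Note also that torch.argmax called without a dim argument returns an index into the flattened tensor; it yields the class index here only because trainloader and testloader use batch_size=1.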

# Now, the main function: we initialize our model, define our loss and optimizer, train, and then run inference in Zetane

def main(load_model=True):
    print("Building pytorch model")
    net = Net().to(device)  # keep the model on the same device as the inputs
    # Initialize convnet, define loss and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
    train(net, epochs=5, criterion=criterion,
            optimizer=optimizer, visual_training=True)

    predict_in_zetane(net, criterion=criterion)
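main() trains with optim.SGD(net.parameters(), lr=0.001, momentum=0.9). With the remaining defaults (no dampening, weight decay or Nesterov), each optimizer.step() keeps a velocity buffer per parameter and applies the update shown below. This sketch runs it on a single scalar parameter and a toy quadratic loss instead of the network; sgd_momentum and quadratic_grad are hypothetical helpers:

```python
def sgd_momentum(grad_fn, p0, lr=0.001, momentum=0.9, steps=1000):
    """Run SGD with momentum on one scalar parameter.

    Update rule (PyTorch's SGD with dampening=0, no weight decay, no Nesterov):
        v <- momentum * v + grad
        p <- p - lr * v
    """
    p, v = p0, 0.0
    for _ in range(steps):
        g = grad_fn(p)
        v = momentum * v + g  # velocity buffer accumulates past gradients
        p = p - lr * v
    return p


def quadratic_grad(p):
    """Gradient of the toy loss f(p) = (p - 3)^2, minimized at p = 3."""
    return 2.0 * (p - 3.0)


print(round(sgd_momentum(quadratic_grad, p0=0.0), 4))  # 3.0
```

The PyTorch documentation points out that this formulation differs slightly from the one in some papers, where the learning rate is applied inside the velocity update rather than outside it.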

Data Loading and Processing

PyTorch already has FashionMNIST available in torchvision.datasets.FashionMNIST, but it can also be downloaded manually. Note that we create two pairs of loaders: one pair loads the data in batches of 128 (optimized for performance, and used by evaluation() to compute accuracy), while the other loads examples one at a time so that we can visualize them sequentially in Zetane.

# transforms
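transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5,), (0.5,))])

# datasets
trainset = torchvision.datasets.FashionMNIST('data', download=True, train=True,
                                             transform=transform)
testset = torchvision.datasets.FashionMNIST('data', download=True, train=False,
                                            transform=transform)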

# dataloaders
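trainloader = torch.utils.data.DataLoader(trainset, batch_size=1,
                                          shuffle=True, num_workers=0)

testloader = torch.utils.data.DataLoader(testset, batch_size=1,
                                         shuffle=False, num_workers=0)

trainloader_batch = torch.utils.data.DataLoader(trainset, batch_size=128,
                                                shuffle=True, num_workers=0)

testloader_batch = torch.utils.data.DataLoader(testset, batch_size=128,
                                               shuffle=False, num_workers=0)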

# Class names
classes = ('T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
           'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle Boot')

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)
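Two consequences of this setup can be checked with simple arithmetic. ToTensor() scales pixels into [0, 1] and Normalize((0.5,), (0.5,)) then computes (x - 0.5) / 0.5, so the network sees values in [-1, 1]. And since FashionMNIST has 60,000 training and 10,000 test images, each epoch of train() runs 60,000 single-example iterations, while a pass over each batched loader needs far fewer. A stdlib-only sketch (normalize and num_batches are hypothetical helpers):

```python
import math

# FashionMNIST split sizes
TRAIN_SIZE, TEST_SIZE = 60_000, 10_000


def normalize(pixel, mean=0.5, std=0.5):
    """What Normalize((0.5,), (0.5,)) does to one value already in [0, 1]."""
    return (pixel - mean) / std


def num_batches(n_examples, batch_size):
    """Iterations per pass of a DataLoader with drop_last=False."""
    return math.ceil(n_examples / batch_size)


print(normalize(0.0), normalize(0.5), normalize(1.0))  # -1.0 0.0 1.0
for name, n, bs in [("trainloader", TRAIN_SIZE, 1),
                    ("testloader", TEST_SIZE, 1),
                    ("trainloader_batch", TRAIN_SIZE, 128),
                    ("testloader_batch", TEST_SIZE, 128)]:
    print(f"{name}: {num_batches(n, bs)} iterations per pass")
```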

Preparing the Zetane Engine

Before we start training, let’s set up our variables with the Zetane API. We first launch a zetane.Context() (imported as ztn), which we’ll use to send all Python data to the Zetane engine. We then initialize two zcontext.image() objects, one to display the input image and the other for the Grad-CAM heatmaps, and a zcontext.model() object through which we will send our PyTorch model to Zetane. A zcontext.chart() with Loss and Accuracy metrics is created for plotting, and three zcontext.text() objects display the target class, the predicted class and the current epoch. We can modify the look of this text via calls like .position(), .font(), .font_size(), .color() and .highlight().
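The long chained calls used for the text objects below work because each method returns an object that the next call can act on, a pattern often called a fluent interface. The toy class here only illustrates that pattern: FakeText is hypothetical and unrelated to Zetane's actual implementation, and treating update() as the point where pending changes are pushed out is an assumption for illustration, suggested by the script calling .update() after each change.

```python
class FakeText:
    """Hypothetical chainable text object; NOT Zetane's implementation.

    Each setter stores a property and returns self, which is what allows
    calls to be chained. update() stands in for pushing the pending changes
    somewhere; in this toy it just prints them and clears the queue.
    """

    def __init__(self, text=""):
        self.state = {"text": text}
        self.pending = dict(self.state)

    def _set(self, key, value):
        self.state[key] = value
        self.pending[key] = value
        return self  # returning self makes chaining possible

    def text(self, value):
        return self._set("text", value)

    def position(self, x, y, z=0.0):
        return self._set("position", (x, y, z))

    def font_size(self, size):
        return self._set("font_size", size)

    def update(self):
        print("sending", self.pending)
        self.pending = {}
        return self


# build, style and push once, then later change only the text
label = FakeText("Target class: ").position(-2.7, 2.95).font_size(0.16).update()
label.text("Target class: Trouser - 1").update()
```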

zcontext = ztn.Context().launch()

# Zetane image object for the input
zimg_input = zcontext.image().position(x=-1.88, y=3.76).update()

# Zetane image object for GradCAM heatmaps
zimg_gradcam = zcontext.image().position(0.0, -4.0, 0.1).update()

# Zetane ONNX object for the convnet
zonnx = zcontext.model()

# Graphs
zgraph = zcontext.chart('Loss Chart', [0.0, 1.0], [0.0, 1.0], visual_type='Line').scale(x=0.25, y=0.25).position(x=-2.6, y=-3.9)
loss_metric = zgraph.metric(label = 'Loss')
accuracy_metric = zgraph.metric(label = 'Accuracy')

# Zetane text objects to display target and predicted classes
# Note that these calls are long because we are demonstrating several of the methods available on text objects.
ztxt_target = zcontext.text("Target class: ").position(-2.7, 2.95, 0).font('slab').font_size(0.16).billboard(True) \
    .color((1, 0, 1)).highlight((0.5, 0.5, 1)).update()
ztxt_pred = zcontext.text("Predicted class: ").position(-2.7, 2.45, 0).font('slab').font_size(0.16).billboard(True) \
    .color((1, 0, 1)).highlight((0.5, 0.5, 1)).update()
ztxt_epoch = zcontext.text("Epoch: ").position(-2.7, 0.23, 0).font('slab').font_size(0.16).billboard(True) \
    .color((1, 0, 1)).highlight((0.5, 0.5, 1)).update()

Run the script via python

if __name__ == "__main__":
    main()

Gallery generated by Sphinx-Gallery