Figure: Main detector circuits and their activation maps for three primitive detectors of an MNIST classifier, for each digit class from 0 to 9. See section 3 for an in-depth explanation.
1 Introduction

In recent years, a great deal of research has been done in the field of interpretability, providing us with powerful tools for understanding our models. Methods like Grad-CAM enabled researchers to get a hint of what image models look at to make their predictions.

Understanding the inner workings of our models is not only of research interest; it also provides a baseline for developing safe AI for broader use, and perhaps even for extracting knowledge from trained models.

This article focuses on circuit-based interpretability, a new way of looking at the atomic structures inside neural networks. We will use the methods developed in this field to gain useful insights into how a classifier model classifies handwritten digits.

It is also intended to act as a boilerplate for anyone trying to understand their own models or needing a head start in the field of circuit-based interpretability.

2 Classification

Let's begin by looking at the model and the task we are going to inspect in this article.

We are going to train a classifier on the MNIST dataset of handwritten digits.
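As a starting point, here is a minimal sketch of such a classifier, assuming PyTorch. The architecture and names (`MnistClassifier`, the layer sizes) are illustrative; the article's exact model specification is given in the appendix.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MnistClassifier(nn.Module):
    """Small, illustrative CNN for 28x28 grayscale digit images."""
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(32 * 7 * 7, 64)
        self.fc2 = nn.Linear(64, 10)  # pre-softmax logits, one per digit

    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), 2)  # 28x28 -> 14x14
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)  # 14x14 -> 7x7
        x = x.flatten(1)
        x = F.relu(self.fc1(x))
        return self.fc2(x)  # softmax is applied inside the loss

model = MnistClassifier()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

def train_step(images, labels):
    """One optimization step on a batch; returns the loss value."""
    optimizer.zero_grad()
    loss = F.cross_entropy(model(images), labels)
    loss.backward()
    optimizer.step()
    return loss.item()
```

In practice the training loop would iterate over the torchvision MNIST loader; the sketch above only shows the per-batch step.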

3 Circuit-based interpretability

A circuit is a combination of primitive feature detectors that enables a neural network to detect essential features of its input.
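As a rough illustration of what the wiring between detectors means, the sketch below reads the connection strength between two channels directly off a convolution's weight tensor. It assumes PyTorch; the two untrained layers and their sizes are purely illustrative stand-ins for a trained model's layers.

```python
import torch
import torch.nn as nn

# Illustrative conv layers; in a real analysis these come from
# the trained classifier.
conv1 = nn.Conv2d(1, 16, kernel_size=3, padding=1)
conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)

def connection_strength(conv, out_channel, in_channel):
    """How strongly (and with what sign) `in_channel` of the previous
    layer drives `out_channel`: the sum of the kernel slice that
    connects the two detectors."""
    # conv.weight has shape (out_channels, in_channels, kH, kW)
    return conv.weight[out_channel, in_channel].sum().item()

# A crude circuit readout: which input channels most strongly
# excite output channel 0 of conv2.
strengths = torch.tensor(
    [connection_strength(conv2, 0, i) for i in range(conv2.in_channels)]
)
top = strengths.topk(3).indices.tolist()
```

A positive strength suggests the earlier detector excites the later one, a negative strength that it inhibits it; tracing such strong connections across layers is one way to surface candidate circuits.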

3.1 Feature Visualization

Figure: Feature visualization of the last layer before the softmax activation, for each class. This shows us the combination of primitive detector circuits used by the model to classify handwritten digits. The detailed model specification and a visualization of every channel can be found in the appendix.

Using feature visualization, we can understand visually what certain parts of our model strongly react to. By optimizing the activation of certain neurons, channels, regions, or whole layers, we can generate an input that shows us what excites the inspected part of the model the most.

In general, there are five main types of feature visualization:

Neuron activation can be used as an optimization objective to visualize the input pattern a single unit responds to most strongly.

Channel activation can be used as an optimization objective to visualize what an entire feature map (one channel across all spatial positions) detects.

Layer activation can be used as an optimization objective to visualize the combined features of a whole layer at once.

Pre-softmax activation can be used as an optimization objective to visualize what drives the logit of a single class, independent of the other classes.

Post-softmax activation can be used as an optimization objective to visualize what increases the probability of a class, which can also mean suppressing evidence for the competing classes.

These objectives are the building blocks for more complex analyses of our models. We can also perform joint optimizations by combining several objectives into a single optimization target.
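To make the optimization concrete, here is a minimal feature-visualization loop, a sketch assuming PyTorch; the small `nn.Sequential` model and the hooked layer index are illustrative stand-ins for the trained classifier. It performs gradient ascent on the input image to maximize the mean activation of a single channel, i.e. the channel objective described above.

```python
import torch
import torch.nn as nn

# Illustrative model; in practice this is the trained classifier.
model = nn.Sequential(
    nn.Conv2d(1, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Conv2d(8, 16, kernel_size=3, padding=1),
    nn.ReLU(),
)
model.eval()

activation = {}
def hook(module, inputs, output):
    # store the layer output (with gradients) on every forward pass
    activation["target"] = output
model[2].register_forward_hook(hook)  # inspect the second conv layer

def visualize_channel(channel, image=None, steps=64, lr=0.1):
    """Gradient ascent on the input image to maximize the mean
    activation of `channel` in the hooked layer."""
    if image is None:
        image = torch.randn(1, 1, 28, 28)
    image = image.clone().requires_grad_(True)
    optimizer = torch.optim.Adam([image], lr=lr)
    for _ in range(steps):
        optimizer.zero_grad()
        model(image)  # fills activation["target"] via the hook
        # negate: the optimizer minimizes, we want to maximize
        loss = -activation["target"][0, channel].mean()
        loss.backward()
        optimizer.step()
    return image.detach()

img = visualize_channel(channel=3)
```

Swapping the loss line for a neuron, layer, or logit term (or a weighted sum of several) gives the other objective types; real feature visualizations usually also add regularizers such as blurring or transformation robustness to keep the images interpretable.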

activation = {}  # dictionary to store the activation of a layer

def create_hook(name):
    def hook(m, i, o):
        # copy the output of the given layer
        activation[name] = o
    return hook

model.fc2.register_forward_hook(create_hook("fc2"))

class Objective:
    def __init__(self, model):
        self.activation = {}
        self.model = model

    def add_forward_hook(self, layer, name):
        def create_hook(name):
            def hook(m, i, o):
                # copy the output of the given layer,
                # keeping the batch dimension intact
                self.activation[name] = o
            return hook
        self.model[layer].register_forward_hook(create_hook(name))

    def __call__(self, name):
        raise NotImplementedError

class neuron_objective(Objective):
    def __call__(self, name, channel, neuron):
        # negated: minimizing this loss maximizes the activation
        return -self.activation[name][:, channel, neuron].mean()

class channel_objective(Objective):
    def __call__(self, name, channel):
        return -self.activation[name][:, channel].mean()

class layer_objective(Objective):
    def __call__(self, name):
        return -self.activation[name].mean()

3.2 Circuits
