And how can it be improved so our machine learning model trains better?
Most of the time, we can’t answer these questions. The usual metrics we use to measure how well our model is performing, from ROC curves to F1 scores, measure a model’s aggregate performance across the whole dataset. Ask which subsets of the data are causing problems, or which patterns in the data are problematic, and our toolbox comes up empty.
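To make the gap concrete, here is a minimal Python sketch. It assumes each example carries a group tag; the "article" and "forum" values below are hypothetical. It computes accuracy once over the whole dataset and once per group, which shows how a respectable aggregate can hide one subset that the model gets badly wrong.

```python
from collections import defaultdict


def accuracy(y_true, y_pred):
    """Fraction of predictions that match the ground truth."""
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)


def accuracy_by_group(y_true, y_pred, groups):
    """Accuracy computed separately for each group tag."""
    totals = defaultdict(int)
    correct = defaultdict(int)
    for t, p, g in zip(y_true, y_pred, groups):
        totals[g] += 1
        correct[g] += int(t == p)
    return {g: correct[g] / totals[g] for g in totals}


if __name__ == "__main__":
    # Hypothetical data: 80 formal articles, all classified correctly,
    # and 20 forum posts, half of them misclassified.
    y_true = [1] * 100
    y_pred = [1] * 80 + [1, 0] * 10
    groups = ["article"] * 80 + ["forum"] * 20
    print(accuracy(y_true, y_pred))                    # 0.9
    print(accuracy_by_group(y_true, y_pred, groups))   # {'article': 1.0, 'forum': 0.5}
```

The per-group breakdown only helps if you already know which grouping to try. That is exactly the hard part when the problematic subset is unknown.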
I have personal experience with this problem. Our text categorisation model had disappointingly low per-category F1 scores, yet our AUC scores were somehow all around 0.99. We knew our data was somewhat dirty, for some categories more than others. But we couldn’t say which parts of the data were affecting the predictions, or how to fix the data problem. We made some vague recommendations about getting more labelled data, but were otherwise at a loss.
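High AUC and low F1 can coexist in a common way, though I’m not claiming it was the cause in our case. AUC only measures how well the scores rank positives above negatives, while F1 depends on the threshold used to turn scores into labels. The minimal sketch below uses made-up scores for a single category to show the effect.

```python
def roc_auc(y_true, scores):
    """Probability that a randomly chosen positive outscores a randomly
    chosen negative (ties count as half). Requires both classes present."""
    pos = [s for s, y in zip(scores, y_true) if y == 1]
    neg = [s for s, y in zip(scores, y_true) if y == 0]
    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(pos) * len(neg))


def f1_at_threshold(y_true, scores, threshold=0.5):
    """F1 score after labelling every score >= threshold as positive."""
    y_pred = [int(s >= threshold) for s in scores]
    tp = sum(1 for y, p in zip(y_true, y_pred) if y == 1 and p == 1)
    fp = sum(1 for y, p in zip(y_true, y_pred) if y == 0 and p == 1)
    fn = sum(1 for y, p in zip(y_true, y_pred) if y == 1 and p == 0)
    return 2 * tp / (2 * tp + fp + fn) if tp else 0.0


if __name__ == "__main__":
    # Made-up scores: positives always outrank negatives, but never reach 0.5.
    y_true = [0, 0, 0, 0, 1, 1]
    scores = [0.05, 0.10, 0.20, 0.30, 0.40, 0.45]
    print(roc_auc(y_true, scores))                 # 1.0
    print(f1_at_threshold(y_true, scores))         # 0.0
    print(f1_at_threshold(y_true, scores, 0.35))   # 1.0
```

The same scores give a perfect ranking metric and a zero thresholded metric. The two kinds of score can therefore tell quite different stories.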
We just didn’t have the mental framework or tools to debug our training data. It didn’t help that most of the literature we found on machine learning interpretability focused on understanding models rather than the data used to train them. For example, random forests have feature importance scores that show how much a model relies on each feature for its predictions. Neural networks have Concept Activation Vectors, which measure how sensitive a network’s predictions are to a human-defined concept, such as “stripes” in an image.
There’s a whole book on algorithmic transparency, local and global model interpretability, and other model interpretability methods. However, apart from manual and painstaking error analysis, we didn’t have a systematic way of tying model predictions and performance back to the training data.
This dataset-as-a-black-box approach to machine learning limits how well we can optimise our training processes. We evaluate the model but never connect its performance back to the data. As a result, our main recourse for improving accuracy is usually to tweak the model: we try a different architecture or tune some hyperparameters, or we just add more data.
The reality is that models and data go together. If a model doesn’t perform well in practice, we are not confined to finding better hyperparameters. One viable solution is to define better labels. For example, in text classification, instead of using general labels like “environment”, we might define more granular labels like “climate change” or “pollution”. Suppose there are articles about these topics in our datasets, and domain experts find these topics interesting. We can then focus on getting the model to learn these more specific, relevant patterns from the data. Another solution is to constrain the type of data we use to train our model. Rather than feeding it both forum posts and formal articles, we can stick to one type of text data. That way, our model doesn’t have to learn a shared representation across different text formats. Yet although we talk a lot about iterating on our models, we rarely talk about iterating on our datasets.
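Before committing to finer labels, it helps to check roughly how many existing examples would fall under each one. Here is a minimal sketch that assumes a simple keyword rule per proposed sub-label. The keyword lists and texts are hypothetical placeholders. In practice, the sub-labels and their definitions should come from domain experts, and someone should review the matches by hand.

```python
from collections import Counter

# Hypothetical keyword rules for proposed sub-labels of "environment".
# Real sub-labels and their definitions should come from domain experts.
SUBLABEL_KEYWORDS = {
    "climate change": ("climate", "emissions", "warming"),
    "pollution": ("pollution", "smog", "contaminat"),
}


def propose_sublabels(texts, rules=SUBLABEL_KEYWORDS):
    """For each text, list the proposed sub-labels whose keywords it contains."""
    proposals = []
    for text in texts:
        lowered = text.lower()
        proposals.append([label for label, keywords in rules.items()
                          if any(kw in lowered for kw in keywords)])
    return proposals


def sublabel_counts(proposals):
    """How many texts each sub-label would cover; 'unmatched' texts fit none."""
    counts = Counter()
    for matches in proposals:
        if not matches:
            counts["unmatched"] += 1
        counts.update(matches)
    return counts


if __name__ == "__main__":
    environment_texts = [
        "Global warming accelerates, report finds",
        "River contaminated by factory waste",
        "New recycling scheme launched",
        "Smog and emissions rise in the capital",
    ]
    proposals = propose_sublabels(environment_texts)
    print(proposals)   # [['climate change'], ['pollution'], [], ['climate change', 'pollution']]
    print(dict(sublabel_counts(proposals)))   # {'climate change': 2, 'pollution': 2, 'unmatched': 1}
```

Substring matching is crude. The point is only to estimate whether each finer label would have enough examples, and how many texts would fit none of them.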
That said, there are interesting, if somewhat scattered, developments in this space. SHAP values, for example, show how each feature of a single datapoint contributes positively or negatively to its final prediction (local interpretability). There are also prototypes and counterfactual examples: meaningful instances, selected or constructed, that help explain a machine learning model. These approaches even have Python implementations for practitioners.
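To show what a local attribution looks like, here is a minimal from-scratch sketch of exact Shapley values, the quantity SHAP is built on. It is a simplification in two ways. First, it replaces “missing” features with a single baseline point; this is an assumption, since the shap library averages over background data and uses much faster estimators. Second, it enumerates every feature subset, so it only suits a handful of features.

```python
from itertools import combinations
from math import factorial


def shapley_values(predict, x, baseline):
    """Exact Shapley values for one datapoint `x`.

    Features outside a coalition take their value from `baseline`
    (a single-baseline simplification). Cost grows as 2**n_features.
    """
    n = len(x)

    def value(coalition):
        # Use x for features in the coalition and the baseline otherwise.
        z = [x[i] if i in coalition else baseline[i] for i in range(n)]
        return predict(z)

    phi = [0.0] * n
    for i in range(n):
        others = [j for j in range(n) if j != i]
        for size in range(n):
            # Shapley weight for coalitions of this size that exclude feature i.
            weight = factorial(size) * factorial(n - size - 1) / factorial(n)
            for subset in combinations(others, size):
                s = set(subset)
                phi[i] += weight * (value(s | {i}) - value(s))
    return phi


if __name__ == "__main__":
    # A hypothetical linear model: contributions should equal coef * (x - baseline).
    def model(z):
        return 2 * z[0] + 3 * z[1] - z[2] + 1

    x, baseline = [1.0, 2.0, 3.0], [0.0, 0.0, 0.0]
    phi = shapley_values(model, x, baseline)
    print([round(v, 6) for v in phi])   # [2.0, 6.0, -3.0]
    # Attributions sum to the gap between the prediction and the baseline prediction.
    print(round(sum(phi), 6), model(x) - model(baseline))   # 5.0 5.0
```

Positive values push this datapoint’s prediction above the baseline prediction, and negative values push it below.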
I am particularly excited about a few projects. Uber’s Manifold tool takes an elegant approach based on creating meaningful subsets and clusters of the data. In this example, one segment of the data (perhaps one with a lower model loss) has a different feature distribution from another segment (perhaps one with a higher loss). This chart effectively lets us see how patterns in the data are affecting the model’s predictions.
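A rough approximation of this idea can be sketched with NumPy for a binary classifier. This minimal sketch does not cluster the way Manifold does. Instead, it simply splits examples at the median per-example log loss, an assumption made for brevity. It then ranks features by how much their means differ between the high-loss and low-loss segments, using a standardised mean difference. Comparing only means is also a simplification, since a feature can differ in spread without differing in mean.

```python
import numpy as np


def log_loss_per_example(y_true, p_pred, eps=1e-12):
    """Binary cross-entropy for each example separately."""
    p = np.clip(p_pred, eps, 1 - eps)
    return -(y_true * np.log(p) + (1 - y_true) * np.log(1 - p))


def compare_loss_segments(X, y_true, p_pred, feature_names):
    """Rank features by the standardised mean difference between
    high-loss (above median) and low-loss examples."""
    loss = log_loss_per_example(y_true, p_pred)
    high = loss > np.median(loss)
    rows = []
    for j, name in enumerate(feature_names):
        hi_vals, lo_vals = X[high, j], X[~high, j]
        pooled_sd = np.sqrt((hi_vals.var() + lo_vals.var()) / 2)
        diff = hi_vals.mean() - lo_vals.mean()
        effect = diff / pooled_sd if pooled_sd > 0 else 0.0
        rows.append((name, float(lo_vals.mean()), float(hi_vals.mean()), float(effect)))
    # Largest absolute differences first: these features separate the segments most.
    return sorted(rows, key=lambda row: abs(row[3]), reverse=True)


if __name__ == "__main__":
    # Synthetic data: the "model" is systematically wrong whenever f2 > 1.
    rng = np.random.default_rng(0)
    X = rng.normal(size=(1000, 3))
    y = (X[:, 0] > 0).astype(float)
    p = 1 / (1 + np.exp(-4 * X[:, 0]))
    p = np.where(X[:, 2] > 1, 1 - p, p)
    for name, lo_mean, hi_mean, effect in compare_loss_segments(X, y, p, ["f0", "f1", "f2"]):
        print(f"{name}: low-loss mean {lo_mean:+.2f}, high-loss mean {hi_mean:+.2f}, effect {effect:+.2f}")
```

In this synthetic setup, f2 should come out on top. That points at the region of the data where the model fails, which a single aggregate loss would not reveal.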
Google’s What-If Tool also lets users interactively select datapoints and view their counterfactuals (examples that are similar in their features but received a different prediction).
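The core of a nearest-counterfactual lookup is small. This minimal sketch assumes numeric features. It standardises each feature and returns the closest existing datapoint whose prediction differs from the selected one. The Euclidean distance is an assumption, and other distance measures are equally reasonable. The sketch only searches existing datapoints rather than generating new ones.

```python
import numpy as np


def nearest_counterfactual(X, preds, idx):
    """Index of the closest datapoint whose prediction differs from preds[idx],
    or None if every datapoint has the same prediction."""
    # Standardise so no feature dominates the distance just because of its scale.
    mu, sd = X.mean(axis=0), X.std(axis=0)
    sd[sd == 0] = 1.0  # constant features: avoid division by zero
    Z = (X - mu) / sd
    candidates = np.flatnonzero(preds != preds[idx])
    if candidates.size == 0:
        return None
    dists = np.linalg.norm(Z[candidates] - Z[idx], axis=1)
    return int(candidates[np.argmin(dists)])


if __name__ == "__main__":
    # Hypothetical toy data: two numeric features and a binary prediction per row.
    X = np.array([[0.0, 0.0], [1.0, 0.0], [5.0, 5.0], [0.2, 0.1]])
    preds = np.array([0, 1, 1, 0])
    print(nearest_counterfactual(X, preds, 0))   # 1
```

Comparing a datapoint with its nearest counterfactual shows which feature differences accompany the change in prediction.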
Taken together, there are three principles that I feel are relevant to effectively debugging training datasets:
- subset the data (interactively, in the style of exploratory data analysis, or with predefined approaches like clustering)
- examine each subset in the context of its features, predictions and ground truth
- then visualise the results!
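Here is a rough illustration of the three principles together, as a minimal NumPy sketch built on a few assumptions. The features are numeric. The subsets come from a basic k-means clustering written out by hand, with farthest-point initialisation (scikit-learn’s KMeans would do the same job). Each cluster is summarised by its size, accuracy and positive rate, taken from predictions and ground truth. The “visualisation” is a plain-text bar chart standing in for a proper plot.

```python
import numpy as np


def kmeans(X, k, n_iter=100, seed=0):
    """Lloyd's-algorithm k-means with farthest-point initialisation;
    returns a cluster index per row of X."""
    rng = np.random.default_rng(seed)
    # Farthest-point init: a random first centre, then repeatedly the point
    # farthest from all chosen centres (simple, but sensitive to outliers).
    centers = [X[rng.integers(len(X))]]
    for _ in range(1, k):
        d = np.min([((X - c) ** 2).sum(axis=1) for c in centers], axis=0)
        centers.append(X[d.argmax()])
    centers = np.array(centers)
    for _ in range(n_iter):
        # Squared distance from every point to every centre, shape (n, k).
        dists = ((X[:, None, :] - centers[None, :, :]) ** 2).sum(axis=2)
        labels = dists.argmin(axis=1)
        new_centers = np.array([
            X[labels == c].mean(axis=0) if np.any(labels == c) else centers[c]
            for c in range(k)
        ])
        if np.allclose(new_centers, centers):
            break
        centers = new_centers
    return labels


def summarise_clusters(labels, y_true, y_pred):
    """Per-cluster size, accuracy and ground-truth positive rate."""
    rows = []
    for c in np.unique(labels):
        m = labels == c
        rows.append({
            "cluster": int(c),
            "size": int(m.sum()),
            "accuracy": float((y_true[m] == y_pred[m]).mean()),
            "positive_rate": float(y_true[m].mean()),
        })
    return rows


def text_bar_chart(rows, width=40):
    """Crude text 'plot' of accuracy per cluster, worst clusters first."""
    lines = []
    for row in sorted(rows, key=lambda r: r["accuracy"]):
        bar = "#" * int(round(row["accuracy"] * width))
        lines.append(f"cluster {row['cluster']} (n={row['size']}): {bar} {row['accuracy']:.2f}")
    return "\n".join(lines)


if __name__ == "__main__":
    # Synthetic data: two well-separated blobs; the model is right on the first
    # blob and wrong on every other example in the second.
    rng = np.random.default_rng(1)
    X = np.vstack([rng.normal(0, 0.5, (100, 2)), rng.normal(10, 0.5, (100, 2))])
    y_true = np.arange(200) % 2
    y_pred = y_true.copy()
    y_pred[100::2] = 1 - y_pred[100::2]
    print(text_bar_chart(summarise_clusters(kmeans(X, 2), y_true, y_pred)))
```

Sorting the clusters worst-first puts the subset that most needs attention at the top of the chart.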
I also feel it’s important to note that none of these principles requires fancy algorithms or metrics. Subsetting and visualisation are simple, but used well, they can be very effective.