Data-driven decision-making, particularly through machine-learned algorithms, is more prevalent now than ever. But being data-driven only matters if you have the right data, which raises the topic of “distribution shift.” Distribution shift is a mismatch between training data and real-world data, and it can arise in several ways. For example, circumstances may evolve over time and change the incoming data, or a lack of real data samples can force a company to rely on narrowly constructed artificial datasets. Distribution shift is a real challenge that requires a thoughtful mitigation strategy, but it can be addressed.
Generalizability is a classic and persistent problem in AI
The remarkable capabilities of ChatGPT, especially in zero-shot and few-shot learning, may give the impression that artificial intelligence (AI) has advanced to a stage where training data is merely a formality. But the rules for AI development haven’t changed. AI is still sensitive to spurious characteristics in its training data that can cause it to “miss the big picture,” such as:
- Correlations in the data that arise by chance
- Imperceptible artifacts in images from the camera system used to take them
- Specific phrases disproportionately present in training text
Gradient-descent-based training can behave greedily, exploiting whichever of these spurious characteristics most readily reduce the training loss. As a result, training may appear successful, but the model, having latched onto these spurious features, fails to generalize to real-world data that lacks them. In this sense, the model “misses the big picture.”
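To make this failure mode concrete, here is a minimal synthetic sketch (our own illustration, not drawn from any real system) using NumPy and a hand-written logistic regression trained by gradient descent. In training, a hypothetical `spurious` feature tracks the label almost perfectly, while a `core` feature is genuinely predictive but noisy. After the shift, the spurious feature looks the same but no longer relates to the label. The data generator and every parameter are assumptions chosen for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)


def make_data(n, spurious_is_informative):
    """Return features [core, spurious] and binary labels y (synthetic, hypothetical)."""
    y = rng.integers(0, 2, size=n)
    sign = 2 * y - 1                                  # map {0, 1} -> {-1, +1}
    core = sign + rng.normal(0.0, 1.0, size=n)        # predictive, but noisy
    if spurious_is_informative:
        shortcut = sign                               # training: mirrors the label
    else:
        shortcut = rng.choice([-1, 1], size=n)        # after shift: unrelated to y
    spurious = shortcut + rng.normal(0.0, 0.1, size=n)
    return np.column_stack([core, spurious]), y


def train_logistic_regression(X, y, lr=0.5, steps=2000):
    """Full-batch gradient descent on the mean logistic (log) loss."""
    w = np.zeros(X.shape[1])
    b = 0.0
    for _ in range(steps):
        p = 1.0 / (1.0 + np.exp(-(X @ w + b)))       # predicted P(y = 1)
        w -= lr * X.T @ (p - y) / len(y)
        b -= lr * np.mean(p - y)
    return w, b


def accuracy(X, y, w, b):
    return float(np.mean((X @ w + b > 0).astype(int) == y))


X_train, y_train = make_data(5000, spurious_is_informative=True)
X_shift, y_shift = make_data(5000, spurious_is_informative=False)
w, b = train_logistic_regression(X_train, y_train)
train_acc = accuracy(X_train, y_train, w, b)
shift_acc = accuracy(X_shift, y_shift, w, b)
print("weights (core, spurious):", np.round(w, 2))
print(f"train accuracy {train_acc:.3f}, accuracy after shift {shift_acc:.3f}")
```

In this setup the weight on the spurious feature should end up larger than the weight on the core feature. Training accuracy should be near perfect, while accuracy on the shifted data drops well below it, even though the core feature alone still carries useful signal.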
Fortunately, techniques like dropout, gradient clipping, multi-task learning, data augmentation, and large-scale pretraining can help overcome these data limitations and improve generalizability. However, distribution shift is a greater challenge because it stems from a fundamental information gap between training and real-world data.
What is distribution shift?
Imagine training an algorithm to distinguish driver’s licenses from different states. If the model were trained on decades-old licenses and applied to contemporary formats, how effective would it be? If the training set were heavily imbalanced, with most licenses coming from high-population states and only a few from low-population states, how well would the model perform on a large collection from the latter?
This example captures two common types of distribution shift: one resulting from changes over time, and another resulting from disparities in data proportions. Although these shifts pose real difficulties for an AI system, humans effortlessly make sense of varying license formats, including new ones. How can we train our AI systems to adapt to these distribution shifts nearly as effectively?
Do you have a distribution shift problem?
One straightforward way to detect distribution shift is through its harmful effect on accuracy, for example by tracking accuracy on a labeled sample of real-world data. Another strategy is to compare the distribution of class labels on real-world data with the distribution of class labels in the training data. Significant deviations suggest a likely distribution shift.
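As a sketch of the label-distribution comparison, assume you have class labels for the training data and labels for real-world data (true labels, or the model’s predictions when ground truth is unavailable). The example uses total variation distance as the measure of deviation. That metric, the state labels, the counts, and the 0.1 threshold are hypothetical choices; a statistical test such as chi-squared would be a reasonable alternative.

```python
from collections import Counter


def label_distribution(labels):
    """Fraction of examples carrying each class label."""
    counts = Counter(labels)
    total = sum(counts.values())
    return {label: count / total for label, count in counts.items()}


def total_variation_distance(p, q):
    """Half the L1 distance between two discrete distributions: 0 = identical, 1 = disjoint."""
    labels = set(p) | set(q)
    return 0.5 * sum(abs(p.get(k, 0.0) - q.get(k, 0.0)) for k in labels)


# Hypothetical data: training skewed toward populous states, live traffic is not.
train_labels = ["CA"] * 50 + ["TX"] * 40 + ["WY"] * 10
live_labels = ["CA"] * 20 + ["TX"] * 20 + ["WY"] * 60

DRIFT_THRESHOLD = 0.1  # hypothetical; tune on historical data
tvd = total_variation_distance(label_distribution(train_labels),
                               label_distribution(live_labels))
print(f"TVD = {tvd:.2f}, shift suspected: {tvd > DRIFT_THRESHOLD}")
# prints: TVD = 0.50, shift suspected: True
```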
In addition to class labels, model confidence levels can reveal distribution shift. Even if the distribution of class labels does not change significantly, the confidences behind those labels may. A significant change in the distribution of confidences might indicate the presence of distribution shift.
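One way to compare confidence distributions, which is our assumption rather than a method the article prescribes, is the two-sample Kolmogorov-Smirnov statistic. It measures the largest vertical gap between the empirical CDFs of training-time and real-world confidences (for instance, each prediction’s top softmax probability). The beta-distributed confidences below are synthetic stand-ins.

```python
import numpy as np


def ks_statistic(a, b):
    """Two-sample Kolmogorov-Smirnov statistic: max gap between the two empirical CDFs."""
    a = np.sort(np.asarray(a, dtype=float))
    b = np.sort(np.asarray(b, dtype=float))
    # The supremum of |F_a - F_b| is attained at one of the observed values.
    grid = np.concatenate([a, b])
    cdf_a = np.searchsorted(a, grid, side="right") / len(a)
    cdf_b = np.searchsorted(b, grid, side="right") / len(b)
    return float(np.max(np.abs(cdf_a - cdf_b)))


rng = np.random.default_rng(0)
train_conf = rng.beta(8, 1, size=2000)  # synthetic: mostly confident predictions
live_conf = rng.beta(3, 1, size=2000)   # synthetic: noticeably less confident
print(round(ks_statistic(train_conf, live_conf), 3))
```

If SciPy is available, `scipy.stats.ks_2samp` computes the same statistic along with a p-value, which can stand in for a hand-picked threshold.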
Statistics of the data that are independent of the model can be considered as well, by analyzing manually developed features or features from an unsupervised model. Monitoring these increases the chances of catching even subtle distribution shifts over time.
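For model-independent monitoring of individual features, one commonly used statistic is the population stability index (PSI). The sketch below computes it for a single numeric feature, with bins taken from the training quantiles. The feature values are synthetic, and the same function could be applied to each manually developed feature or to each dimension of an unsupervised model’s output (principal components, for example). The often-quoted rule of thumb that a PSI above roughly 0.25 signals a significant shift is an industry convention, not a guarantee.

```python
import numpy as np


def population_stability_index(expected, actual, n_bins=10, eps=1e-6):
    """PSI of one numeric feature: reference sample (training) vs. new sample (production)."""
    expected = np.asarray(expected, dtype=float)
    actual = np.asarray(actual, dtype=float)
    # Interior edges at training quantiles, so each bin holds ~1/n_bins of the training data;
    # the outermost bins are open-ended and catch values outside the training range.
    inner_edges = np.quantile(expected, np.linspace(0.0, 1.0, n_bins + 1)[1:-1])
    exp_bins = np.searchsorted(inner_edges, expected, side="right")
    act_bins = np.searchsorted(inner_edges, actual, side="right")
    exp_frac = np.bincount(exp_bins, minlength=n_bins) / len(expected)
    act_frac = np.bincount(act_bins, minlength=n_bins) / len(actual)
    # Guard against empty bins before taking logs.
    exp_frac = np.clip(exp_frac, eps, None)
    act_frac = np.clip(act_frac, eps, None)
    return float(np.sum((act_frac - exp_frac) * np.log(act_frac / exp_frac)))


rng = np.random.default_rng(0)
train_feature = rng.normal(0.0, 1.0, size=5000)
same_feature = rng.normal(0.0, 1.0, size=5000)     # synthetic: no shift
shifted_feature = rng.normal(1.0, 1.0, size=5000)  # synthetic: the mean has moved

print(round(population_stability_index(train_feature, same_feature), 3))
print(round(population_stability_index(train_feature, shifted_feature), 3))
```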
Mitigating distribution shift
If distribution shift over time is the concern, retraining can serve as a reliable mitigation strategy. Retraining can be done periodically or whenever distribution shift is detected.
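The two retraining triggers can be combined into a simple policy. In this sketch, `RetrainPolicy` is a hypothetical helper, and the 30-day maximum age and 0.25 drift threshold are placeholder values. The `drift_score` could come from any detector, such as a comparison of label, confidence, or feature distributions.

```python
from dataclasses import dataclass
from datetime import datetime, timedelta


@dataclass
class RetrainPolicy:
    """Hypothetical policy combining periodic and detection-triggered retraining."""
    max_model_age: timedelta  # retrain at least this often
    drift_threshold: float    # retrain early if the drift score exceeds this

    def should_retrain(self, trained_at: datetime, now: datetime, drift_score: float) -> bool:
        too_old = now - trained_at >= self.max_model_age
        drifted = drift_score > self.drift_threshold
        return too_old or drifted


policy = RetrainPolicy(max_model_age=timedelta(days=30), drift_threshold=0.25)
trained_at = datetime(2024, 1, 1)
print(policy.should_retrain(trained_at, datetime(2024, 1, 10), drift_score=0.05))  # False
print(policy.should_retrain(trained_at, datetime(2024, 1, 10), drift_score=0.40))  # True: drift
print(policy.should_retrain(trained_at, datetime(2024, 2, 15), drift_score=0.05))  # True: age
```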
However, if distribution shift results from a lack of sufficient real-world data, different strategies are needed. These strategies require some knowledge of the expected types of distribution shift to make up for the information gap. We can demonstrate two straightforward strategies using the earlier driver’s license example:
- Apply a two-stage classification algorithm of (1) optical character recognition (OCR) to extract text followed by (2) a classifier to identify the issuing state based on that text.
- Augment the training dataset with synthetic data. The synthetic data can be made by moving elements of a driver’s license to various positions, altering foreground and background imagery, and incorporating other variations that appear to be reasonable modifications to a driver’s license.
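The two-stage approach in the first bullet can be sketched as follows. The OCR step is passed in as a function because the article does not name an engine. `fake_ocr` is a hypothetical stand-in that you would replace with an OCR library of your choice. The second stage is a keyword lookup over a hypothetical `STATE_KEYWORDS` table, standing in for a trained text classifier.

```python
import re
from typing import Any, Callable, Optional

# Hypothetical keyword table standing in for a trained text classifier.
STATE_KEYWORDS = {
    "California": ["california"],
    "New York": ["new york"],
    "Texas": ["texas"],
    "Wyoming": ["wyoming"],
}


def classify_state_from_text(text: str) -> Optional[str]:
    """Stage 2: return the state with the most matching keywords, or None if none match."""
    lowered = text.lower()
    scores = {
        state: sum(bool(re.search(rf"\b{re.escape(kw)}\b", lowered)) for kw in keywords)
        for state, keywords in STATE_KEYWORDS.items()
    }
    best = max(scores, key=scores.get)
    return best if scores[best] > 0 else None


def identify_issuing_state(image: Any, ocr: Callable[[Any], str]) -> Optional[str]:
    """Two-stage pipeline: image -> text (stage 1, OCR) -> issuing state (stage 2)."""
    text = ocr(image)
    return classify_state_from_text(text)


def fake_ocr(image: Any) -> str:
    """Hypothetical OCR stand-in that ignores the image and returns fixed text."""
    return "CALIFORNIA DRIVER LICENSE  DL 1234567  EXP 01/01/2030"


print(identify_issuing_state(image=None, ocr=fake_ocr))  # California
```

Because stage 2 only sees text, a redesigned card layout affects it only to the extent that it changes what the OCR step reads.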
In the first case, we use a feature (the license’s text) that we expect to be largely unaffected by distribution shifts over time. In the second case, we leverage our knowledge of driver’s license formats to create new formats that remain plausible. In both cases, we explicitly account for an information gap by incorporating our intuitive understanding of real-world data.
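The synthetic-data strategy can be sketched with NumPy by treating a license as a grayscale image and each element (photo, text block, seal) as a patch pasted at a random position over a randomized background. The helper `synthesize_license`, the card size, and the patch shapes are hypothetical. A realistic generator would also vary fonts, colors, and layouts within what real licenses plausibly look like.

```python
import numpy as np


def synthesize_license(elements, card_shape=(200, 320), rng=None):
    """Hypothetical generator: paste element patches at random spots on a random background.

    `elements` is a list of 2-D grayscale arrays with values in [0, 1].
    """
    if rng is None:
        rng = np.random.default_rng()
    h, w = card_shape
    # Randomized background: base tone + left-to-right gradient + mild noise.
    base = rng.uniform(0.3, 0.9)
    gradient = np.linspace(0.0, rng.uniform(-0.2, 0.2), w)
    noise = rng.normal(0.0, 0.02, size=(h, w))
    card = np.clip(base + gradient[None, :] + noise, 0.0, 1.0)
    # Move each element to a random position that keeps it fully on the card.
    for patch in elements:
        ph, pw = patch.shape
        top = rng.integers(0, h - ph + 1)
        left = rng.integers(0, w - pw + 1)
        card[top:top + ph, left:left + pw] = patch
    return card


photo = np.zeros((60, 50))            # dark portrait placeholder
text_block = np.full((20, 120), 0.1)  # placeholder for a line of text
seal = np.ones((30, 30))              # bright seal placeholder
rng = np.random.default_rng(1)
samples = [synthesize_license([photo, text_block, seal], rng=rng) for _ in range(4)]
print([s.shape for s in samples])
```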
We live in a world run by technology, where being data-driven is often equated with success. Data-driven, machine-learned algorithms can be essential not only for accuracy in high-stakes situations but also for routine frontline work, where they allow companies to produce real-time insights that help predict and improve performance. As AI systems become more prevalent in decision-making, it is increasingly important for companies to understand both their limitations and the strategies for mitigating those limitations, so that these systems can be leveraged effectively for sustained growth.
About the Author
Michael Rinehart is the VP of Artificial Intelligence at Securiti.ai, a unified data control company that manages security, compliance, and privacy risks. Throughout his career, Michael has deployed machine learning and data science systems in numerous domains, including Internet security, health care, power electronics, automotive and marketing. Prior to joining Securiti, he led the research and development of a machine learning-based wireless communications jamming technology at BAE Systems. Michael has also held roles in cloud security, big data and engineering at companies including Elastica (acquired by Symantec) and Verizon. Michael holds a Ph.D. in electrical engineering from MIT.