Training Neural Networks: How AI Learns to Learn

In previous articles, we explored a high-level overview of two kinds of neural networks: supervised and unsupervised. Today, we'll peel back another layer of this digital cortex and explore how these networks evolve from their 'naive' initial state to a 'trained' state where they can make decisions, identify patterns, or generate new creations. We previously explained these concepts using coins as an analogy, and we'll continue a little further with that theme.

Supervised or Unsupervised Learning? How to Train a Machine

In the world of machine learning, 'training' a neural network is akin to coaching a new employee through the intricacies of sorting coins. As we saw earlier in our discussion, training can be approached via supervised learning or unsupervised learning. To grasp these concepts, it's crucial to understand 'labels.' Imagine each coin has a tag specifying its year of minting—that tag is a 'label.' It gives clear information about the coin, which can be used to guide the sorting process.
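
To make the idea of a 'label' concrete, here is a minimal sketch in Python; the coin measurements and year tags below are invented purely for illustration.

```python
# Labeled data: each coin comes with its year tag (the label).
labeled_coins = [
    ((19.0, 2.5), 1995),   # ((diameter_mm, mass_g), year)
    ((21.2, 5.0), 1996),
    ((24.3, 5.7), 1995),
]

# Unlabeled data: the same kind of measurements, but no year tag to learn from.
unlabeled_coins = [
    (19.1, 2.4),
    (24.2, 5.6),
]

features, labels = zip(*labeled_coins)
print(features)   # ((19.0, 2.5), (21.2, 5.0), (24.3, 5.7))
print(labels)     # (1995, 1996, 1995)
```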


Now, let's explore how our coin analogy applies to these two foundational types of neural network training:

| Aspect | Supervised Learning | Unsupervised Learning |
| --- | --- | --- |
| Data Type | Labeled Data | Unlabeled Data |
| Learning | The algorithm learns from the provided labels. | The algorithm infers patterns from the data. |
| Feedback | Direct Feedback (Corrective) | Indirect Feedback (No explicit correction) |
| Success Metric | Accuracy of sorting based on labels. | Quality of the discovered groupings or patterns. |


In supervised learning, the 'training' involves a set of data that's already labeled—like our coins with year tags. This method is akin to giving the employee a reference guide to sort coins into pre-defined categories. They receive direct feedback—such as being corrected when a coin is placed in the wrong year—and their performance is measured by how accurately they can apply these learnings to new coins.
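
As a rough sketch of what this looks like in code, the example below trains a small decision-tree classifier with scikit-learn (assumed to be installed); the coin measurements, feature names, and year labels are all invented for illustration.

```python
from sklearn.tree import DecisionTreeClassifier

# Features per coin: [diameter_mm, mass_g, brightness]; labels: minting year.
X_train = [
    [19.0, 2.50, 0.81], [19.1, 2.49, 0.80],   # coins tagged 1995
    [21.2, 5.00, 0.62], [21.1, 5.02, 0.63],   # coins tagged 1996
]
y_train = [1995, 1995, 1996, 1996]

model = DecisionTreeClassifier()
model.fit(X_train, y_train)        # learning with direct, labeled feedback

# The trained model sorts coins it has never seen before.
X_new = [[19.05, 2.51, 0.79], [21.15, 5.01, 0.61]]
print(model.predict(X_new))        # expected: [1995 1996]
```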


Conversely, unsupervised learning deals with unlabeled data, akin to a pile of coins without year tags. The employee—or algorithm in this case—learns to discern and create categories based on patterns they observe, such as color, size, and weight, sometimes employing clustering algorithms like K-means or using dimensionality reduction techniques like PCA to uncover these patterns. They adjust their sorting criteria based on the inherent structure of the coins, without knowing if there's a 'right' or 'wrong' way to sort them. The success isn't about matching a label but about how effectively the algorithm can group coins and apply this to new sets. Autoencoders (covered later) could also be employed for unsupervised learning, where success would be to shorten the description of each coin without losing any important information that helps identify it.
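
Here is a hedged sketch of the unsupervised case, using K-means from scikit-learn (assumed installed); no labels are supplied, and the measurements are again invented.

```python
from sklearn.cluster import KMeans

# Unlabeled coins, described only by [diameter_mm, mass_g].
X = [
    [19.0, 2.5], [19.1, 2.4], [18.9, 2.6],   # smaller, lighter coins
    [24.3, 5.7], [24.2, 5.6], [24.4, 5.8],   # larger, heavier coins
]

# Ask for two groups; the algorithm decides what those groups are.
kmeans = KMeans(n_clusters=2, n_init=10, random_state=0).fit(X)
print(kmeans.labels_)                          # e.g. [0 0 0 1 1 1]

# New coins are assigned to whichever discovered group they resemble most.
print(kmeans.predict([[19.05, 2.55], [24.35, 5.75]]))
```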


Understanding these two approaches helps businesses decide how to implement machine-learning solutions. Whether sorting coins or sorting through complex data, the choice between supervised and unsupervised learning hinges on the nature of the data at hand and the specific goals of the task.

Balancing Knowledge and Flexibility in Training

When you're teaching an employee to sort coins by year, the goal is for them to grasp the broad features that characterize coins from different eras, rather than memorizing the specific details of each coin in front of them. It's akin to guiding a student to understand the principles behind the lessons instead of memorizing the textbook. Just as an overzealous student might memorize facts without grasping the underlying concepts and then struggle to adapt that knowledge to new problems, an employee might focus too narrowly on the coins they've already handled. To prevent this sort of "overfitting" in training, we introduce a variety of coins from different piles or sets throughout the process, while also ensuring that the model is not so simplistic that it misses the underlying trend in the data, which would lead to "underfitting."


The image below illustrates the concepts of "overfitting" and "underfitting", with the ideal model complexity shown between these two extremes. On the left, the model labeled "Overfitting" has an excessive number of connections and layers, representing a highly complex network that may perform exceptionally on training data but poorly on unseen data because it captures noise rather than the underlying pattern. On the right, the "Underfitting" model has sparse connections and layers, suggesting a model that is too simple to capture the complexity of the data, resulting in poor performance on both training and new data. The middle model, marked "Ideal," shows a balanced structure with an optimal number of layers and connections, indicating a well-tuned model that can apply its accumulated learning to new data. This visual metaphor conveys the importance of model complexity in machine learning and the trade-off between a model's ability to learn from data and its capacity to generalize from that learning.
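
For readers who prefer numbers to pictures, here is a minimal sketch of the same trade-off using NumPy and scikit-learn (both assumed installed): polynomials of increasing degree are fitted to noisy data, and training error is compared with error on held-out data.

```python
import numpy as np
from sklearn.preprocessing import PolynomialFeatures
from sklearn.linear_model import LinearRegression
from sklearn.pipeline import make_pipeline
from sklearn.metrics import mean_squared_error

rng = np.random.RandomState(0)
X = np.sort(rng.uniform(0, 1, 40)).reshape(-1, 1)
y = np.sin(2 * np.pi * X).ravel() + rng.normal(0, 0.2, 40)   # true pattern + noise

X_train, X_test = X[::2], X[1::2]    # alternate points: half for training, half held out
y_train, y_test = y[::2], y[1::2]

for degree in (1, 4, 15):            # too simple, roughly right, too complex
    model = make_pipeline(PolynomialFeatures(degree), LinearRegression())
    model.fit(X_train, y_train)
    train_err = mean_squared_error(y_train, model.predict(X_train))
    test_err = mean_squared_error(y_test, model.predict(X_test))
    print(f"degree {degree:2d}  train {train_err:.3f}  held-out {test_err:.3f}")

# Typically: degree 1 is poor on both sets (underfitting), while degree 15 looks
# excellent on the training set but worse on held-out data (overfitting).
```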


Cross-Validation Explained Through Coin Sorting

In cross-validation, we don't just show the neural network one set of data during training. Instead, we divide our data into several parts, or 'folds'. The neural network trains on some of these folds and then validates what it has learned on a different fold, a bit like a pop quiz. We rotate which fold is used for validation so that every part of the data takes a turn testing the model; in practice, this rotation is known as k-fold cross-validation.
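
Here is a minimal k-fold cross-validation sketch with scikit-learn (assumed installed); the coin features and year labels are invented for illustration.

```python
from sklearn.model_selection import KFold, cross_val_score
from sklearn.tree import DecisionTreeClassifier

X = [[19.0, 2.5], [19.1, 2.4], [18.9, 2.6], [19.2, 2.5],   # coins tagged 1995
     [24.3, 5.7], [24.2, 5.6], [24.4, 5.8], [24.1, 5.7]]   # coins tagged 1996
y = [1995, 1995, 1995, 1995, 1996, 1996, 1996, 1996]

# Four folds: train on three, validate on the held-out fold (the 'pop quiz'),
# and rotate until every fold has taken a turn as the quiz.
folds = KFold(n_splits=4, shuffle=True, random_state=0)
scores = cross_val_score(DecisionTreeClassifier(), X, y, cv=folds)

print(scores)         # one accuracy score per rotation
print(scores.mean())  # average performance across all folds
```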


Returning to our coin analogy, it would be like having several bags of coins and asking the employee to sort one bag at a time while using coins from the other bags to test their sorting rules. This way, you make sure they can't just memorize the coins in front of them; they need to learn the general sorting principles that can apply to any bag of coins they might encounter.

The Role of Cross-Validation in Avoiding Overfitting

By using cross-validation, you help ensure the employee doesn't overfit to one specific set of coins. They develop a robust understanding of how to sort any coin by year, not just the ones they've seen. In machine learning, cross-validation helps us catch overfitting early by showing us how well the model performs on different subsets of the data—giving us confidence that it's truly learning the patterns rather than memorizing the noise.

This balanced training, with the help of cross-validation, is crucial for creating a model—or training an employee—that performs well in the real world, ready for the variety and unpredictability of new coins or data it will encounter.

How AI Learns to Improve: The Importance of Feedback in Machine Learning

The loss function in a neural network is analogous to a performance review that guides an employee's improvement. It's a measure of how well the network is doing its job. If our employee incorrectly sorts a 1995 coin into the 1996 pile, the loss function is akin to the supervisor pointing out the mistake and suggesting a closer look at the coin's features.


In technical terms, the loss function calculates the difference between the neural network's predictions and the actual target values. It's the cornerstone of learning, as it provides a quantitative basis for the network to adjust its weights, which means improving its sorting strategy. A good loss function will guide the network towards making fewer mistakes over time, just as constructive feedback helps our employee become more adept at sorting coins.
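
As a toy illustration of that idea, the snippet below computes a mean squared error between a handful of predictions and their targets (NumPy assumed installed); a real coin-sorting classifier would more likely use a cross-entropy loss, but the arithmetic is easier to follow this way.

```python
import numpy as np

targets     = np.array([1995, 1996, 1995, 1996])   # the correct piles
predictions = np.array([1995, 1995, 1995, 1996])   # the network's current guesses

loss = np.mean((predictions - targets) ** 2)       # mean squared error
print(loss)   # 0.25 -- one coin is off by a year; a perfect sort would score 0.0
```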

Wise use of a loss function ensures that our digital 'employee' doesn't just memorize the data (the result of "overfitting") but learns to apply its rules to sort through any new coins, or data, it encounters. This is critical for successfully training neural networks.


Through iterations of this process, bolstered by feedback mechanisms like backpropagation and optimization algorithms, the neural network fine-tunes its parameters. This is how it evolves from a naïve state, with random guesses, to a trained state that makes informed predictions and decisions, much like our employee grows from a novice sorter to an expert coin classifier.
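
The core of that fine-tuning is a simple update rule applied to every weight; the sketch below shows a single gradient-descent step with made-up numbers.

```python
weight        = 0.80   # current value of one parameter (made up)
gradient      = 0.30   # dLoss/dWeight reported by backpropagation (made up)
learning_rate = 0.10   # how large a correction to apply

# Nudge the weight in the direction that reduces the loss.
weight = weight - learning_rate * gradient
print(weight)   # 0.77 -- a slightly better coin sorter than a moment ago
```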

The Need for Computational Power in Training: The Role of GPUs

Training an AI model is a resource-intensive task that involves processing vast amounts of data and performing complex mathematical operations millions or even billions of times. This is where GPUs (Graphics Processing Units) come into play. Originally designed for rendering graphics, GPUs are incredibly efficient at matrix and vector computations, which are fundamental to the operations in neural network training. Their architecture allows them to execute many parallel operations simultaneously, significantly accelerating the training process.

Parallel Processing Power

Unlike CPUs, which are optimized for sequential task processing and handling a broad range of computations, GPUs are composed of thousands of smaller, more efficient cores designed for parallel processing. When training a neural network, a GPU can update thousands of weights at once, making it vastly faster than a CPU for this kind of task.
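
As a hedged sketch of what this parallelism buys, the snippet below runs the same large matrix multiplication on the CPU and, if one is available, on a CUDA-capable GPU; it assumes PyTorch is installed.

```python
import torch

a = torch.randn(4096, 4096)
b = torch.randn(4096, 4096)

c_cpu = a @ b                          # executed on the CPU

if torch.cuda.is_available():          # only if a CUDA-capable GPU is present
    a_gpu, b_gpu = a.cuda(), b.cuda()  # copy the tensors into GPU memory
    c_gpu = a_gpu @ b_gpu              # the same multiply, spread across thousands of cores
    print(torch.allclose(c_cpu, c_gpu.cpu(), atol=1e-2))   # results agree up to floating-point noise
```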

The Computational Heaviness of Training

During training, neural networks go through a process of backpropagation, where errors are calculated and propagated back through the network to adjust the weights. This process requires a considerable amount of computation and memory bandwidth. GPUs excel in this area due to their high number of cores and specialized design, which allows them to handle multiple calculations at lightning speeds. Additionally, tasks like gradient descent optimization and the tuning of hyperparameters are computationally expensive operations that benefit from the raw power of GPUs.

Why GPUs Aren't as Necessary for Inference

Once a model is trained, the heavy lifting has been done. The model no longer needs to learn; it simply applies what it has learned to make predictions. This task, while still computationally demanding, is less intense and can be handled efficiently by CPUs, which are more common in everyday devices (although, depending on the model, a GPU may still be necessary for performance). During inference, the neural network performs a straightforward series of matrix multiplications as data passes through the trained network, a task well within the capabilities of modern CPUs, especially when optimized for these operations.
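
To make the contrast concrete, here is a minimal CPU-only inference sketch with PyTorch (assumed installed); the tiny network and the coin measurements are hypothetical, and in practice you would load weights saved during training.

```python
import torch
import torch.nn as nn

# A hypothetical, untrained stand-in for a trained coin-sorting network.
model = nn.Sequential(nn.Linear(3, 16), nn.ReLU(), nn.Linear(16, 2))
model.eval()                           # inference mode: no more learning or weight updates

coin_features = torch.tensor([[19.0, 2.5, 0.8]])   # invented measurements
with torch.no_grad():                  # skip the gradient bookkeeping needed only for training
    scores = model(coin_features)      # a straightforward pass of matrix multiplications
print(scores)
```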

Energy Efficiency in Inference and the Role of Edge Devices

Energy efficiency becomes paramount when we talk about edge devices. But what exactly is an edge device? An edge device is a piece of hardware that processes data closer to the source of data generation (like a camera or a smartphone) rather than relying on a centralized data-processing warehouse. These devices typically have constraints on power, size, and processing capacity. CPUs in these devices are engineered to provide the necessary computational power for real-time AI applications such as voice recognition or image processing, while also conserving energy to maintain battery life and device longevity. Complementing these advancements, specialized chips for running neural networks are becoming increasingly common on the edge, further enhancing these devices' capabilities.

Division of Labor

In the realm of AI, GPUs play a crucial role in training models by leveraging their parallel processing capabilities to handle the computationally intense tasks of learning and optimization. Once the model is trained, however, the baton is passed to CPUs, especially in edge devices, for efficient and energy-conservative inference. This division of labor between GPUs for training and CPUs for inference is what enables AI to be both powerful in development and practical in deployment.

Pete Slade
November 27, 2023