Yes, model size matters.
One of the biggest unsolved challenges in continual learning is preventing the loss of previously learned information when acquiring new information. Known as “catastrophic forgetting,” this phenomenon is particularly pertinent in scenarios where AI systems must adapt to new data without losing valuable insights from past experiences. Numerous studies in recent years have investigated approaches to this problem, mostly by proposing new training strategies and measuring their impact on model performance in terms of metrics such as accuracy and forgetting.
Yet, compared to the large number of studies on new strategies and evaluation approaches in visual continual learning, there is surprisingly little discussion of the impact of model size, even though the size of a deep learning model (its number of parameters) is known to play a crucial role in its learning capabilities.
In this blog post, I explore the following research question: how do network depth and width impact model performance in an online continual learning setting? I set forth a hypothesis based on existing literature and conduct a series of experiments with models of varying sizes to explore this relationship. This study aims to shed light on whether larger models truly offer an advantage in mitigating catastrophic forgetting, or whether the reality is more nuanced.
Continual learning (CL), also known as lifelong learning or incremental learning, is an approach that seeks to continually learn from non-iid data streams without forgetting previously acquired knowledge. The central challenge is generally known as the stability-plasticity dilemma: the model must be plastic enough to learn new information, yet stable enough to retain what it has already learned.
Traditional CL setups assume new data arrives task by task, each task with a stable data distribution, enabling offline training. However, this requires having access to all of a task's data at once, which can be impractical due to privacy or resource limitations. In this study, I consider the more realistic setting of Online Continual Learning (OCL), where data arrives in small batches that are not accessible after training, requiring models to learn from a single pass over the data stream. This allows the model to learn from data in real time.
Online continual learning can involve adapting to new classes (class-incremental) or changing data characteristics (domain-incremental). Specifically, in class-incremental learning, the goal is to continually expand the model’s ability to recognize an increasing number of classes, maintaining its performance on all classes seen so far despite not having continued access to the old class data.
Popular methods to mitigate catastrophic forgetting in continual learning generally fall into three buckets:

- replay-based methods, which store (or generate) samples from past tasks and revisit them during training;
- regularization-based methods, which constrain updates to parameters that are important for previously learned tasks;
- parameter-isolation (architecture-based) methods, which dedicate different subsets of parameters to different tasks.
Moreover, many methods combine two or more of these techniques, such as Averaged Gradient Episodic Memory (A-GEM), which uses a small episodic memory of past samples to constrain gradient updates.
Among these methods, Experience Replay (ER) is a classic replay-based method that is widely used for online continual learning. Despite its simplicity, recent studies have shown that ER still outperforms many of the newer methods that came after it, especially in the online setting.
It is generally accepted in the literature that deeper models achieve higher performance, provided they can be trained effectively.
He et al. introduced Residual Networks (ResNets), which use skip connections to ease the training of very deep networks and mitigate the vanishing-gradient problem. The table below summarizes the standard ResNet variants, their parameter counts, and their reported ImageNet accuracy.
| | ResNet18 | ResNet34 | ResNet50 | ResNet101 | ResNet152 |
|---|---|---|---|---|---|
| Number of Layers | 18 | 34 | 50 | 101 | 152 |
| Number of Parameters | ~11.7 million | ~21.8 million | ~25.6 million | ~44.5 million | ~60 million |
| Top-1 Accuracy | 69.76% | 73.31% | 76.13% | 77.37% | 78.31% |
| Top-5 Accuracy | 89.08% | 91.42% | 92.86% | 93.68% | 94.05% |
| FLOPs | 1.8 billion | 3.6 billion | 3.8 billion | 7.6 billion | 11.3 billion |
This leads to the question: do larger models perform better in continual learning? While much of the focus in continual learning research has been on developing strategies and establishing benchmarks, the impact of model scale remains a less explored path.
Moreover, recent studies that examine model scale in adjacent contexts, such as the work of Luo et al., have reported conflicting results.
The relative lack of discussion on model size and the conflicting perspectives among existing studies indicate that the answer to this question is far from definitive. In the next section, I describe how I approach this study.
Online continual learning can be defined as follows:
The objective is to learn a function $f_\theta : \mathcal X \rightarrow \mathcal Y$ with parameters $\theta$ that predicts the label $Y \in \mathcal Y$ of the input $\mathbf X \in \mathcal X$. Over time steps $t \in \lbrace 1, 2, \ldots \infty \rbrace$, a distribution-varying stream $\mathcal S$ reveals data sequentially, which is different from classical supervised learning.
At every time step $t$, the stream $\mathcal S$ reveals a small batch of samples $\lbrace \mathbf X_t, Y_t \rbrace$ drawn from the current distribution. The model is expected to predict on these samples and is then updated using only this newly revealed batch (optionally together with a bounded memory of past samples), without revisiting earlier parts of the stream.
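The following toy sketch illustrates the single-pass protocol implied by this definition: the model predicts on each incoming mini-batch before updating on it, and never revisits past batches. The synthetic stream, the linear model, and the labeling rule are placeholders for illustration, not part of the actual experiments.

```python
# Toy illustration of the single-pass online protocol above: a linear model
# sees each mini-batch exactly once, predicts before updating, and never
# revisits past data. Stream and labeling rule are synthetic placeholders.
import torch

torch.manual_seed(0)
model = torch.nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = torch.nn.CrossEntropyLoss()

def stream(n_steps=100, batch_size=10):
    """A stand-in for the distribution-varying stream S."""
    for _ in range(n_steps):
        x = torch.randn(batch_size, 10)
        y = (x[:, 0] > 0).long()   # placeholder labeling rule
        yield x, y

for x_t, y_t in stream():
    with torch.no_grad():          # "anytime" evaluation: predict before any update
        acc = (model(x_t).argmax(dim=1) == y_t).float().mean().item()
    loss = loss_fn(model(x_t), y_t)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()               # update on this batch only; it is never seen again
```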
In the context of class-incremental learning, I adopt the definitions of task-agnostic and boundary-agnostic settings from Soutif et al. 2023, summarized in the table below:
| Available? | Yes | No |
|---|---|---|
| Task labels | Task-aware | Task-agnostic |
| Task boundaries | Boundary-aware | Boundary-agnostic |
In a class-incremental learning setting, Experience Replay (ER) aligns well with task-agnostic and boundary-agnostic settings: because it simply replays a subset of past samples, it helps maintain knowledge of previous classes without needing explicit task labels or boundaries. This allows ER to adapt to new classes as they are introduced while retaining the ability to recognize previously learned ones, making it inherently suitable for these scenarios.
Implementation-wise, ER involves randomly initializing an external memory buffer $\mathcal M$ and then implementing `before_training_exp` and `after_training_exp` callbacks that use the dataloader to create mini-batches with samples from both the training stream and the memory buffer. Each mini-batch is balanced so that all tasks or experiences are equally represented in terms of stored samples.
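To illustrate the core idea only (this is not the Avalanche implementation described above, which handles storage and batch balancing through the callbacks), a minimal reservoir-sampling variant of an ER update step could look like the sketch below. All names here are hypothetical.

```python
# Conceptual sketch of Experience Replay with a reservoir-sampled buffer:
# each update mixes the incoming batch with samples replayed from memory,
# then the new samples are candidates for storage in the fixed-size buffer.
import random
import torch

class ReservoirBuffer:
    """Fixed-size memory M updated with reservoir sampling over the stream."""
    def __init__(self, mem_size=200):
        self.mem_size = mem_size
        self.data = []       # list of (x, y) pairs
        self.n_seen = 0

    def update(self, x_batch, y_batch):
        for x, y in zip(x_batch, y_batch):
            self.n_seen += 1
            if len(self.data) < self.mem_size:
                self.data.append((x, y))
            else:
                j = random.randrange(self.n_seen)
                if j < self.mem_size:
                    self.data[j] = (x, y)

    def sample(self, k):
        return random.sample(self.data, min(k, len(self.data)))

def er_step(model, optimizer, criterion, x_new, y_new, buffer, replay_size=10):
    """One ER update: train on the incoming batch plus replayed samples."""
    replay = buffer.sample(replay_size)
    if replay:
        x_mem = torch.stack([x for x, _ in replay])
        y_mem = torch.stack([y for _, y in replay])
        x_batch = torch.cat([x_new, x_mem])
        y_batch = torch.cat([y_new, y_mem])
    else:
        x_batch, y_batch = x_new, y_new
    optimizer.zero_grad()
    loss = criterion(model(x_batch), y_batch)
    loss.backward()
    optimizer.step()
    buffer.update(x_new, y_new)   # only genuinely new samples enter the buffer
    return loss.item()
```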
For this study, the SplitCIFAR-10 benchmark is used, which splits the CIFAR-10 dataset into 5 experiences of 2 disjoint classes each.
Key metrics established in earlier work in online continual learning are used to evaluate the performance of each model.
The first metric is Average Anytime Accuracy (AAA).
The concept of average anytime accuracy serves as an indicator of a model’s overall performance throughout its learning phase, extending the idea of average incremental accuracy to include continuous assessment scenarios. This metric assesses the effectiveness of the model across all stages of training, rather than at a single endpoint, offering a more comprehensive view of its learning trajectory.
\[\text{AAA} = \frac{1}{T} \sum_{t=1}^{T} (\text{AA})_t\]

The second metric is Average Cumulative Forgetting (ACF), which is built on the Cumulative Accuracy $b_k^t$ of task $k$ after the model has been trained up to task $t$:

\[b_k^t = \frac{1}{|E^k_\Sigma|} \sum_{(x,y) \in E^k_\Sigma} 1_y(\underset{c \in C^k_\Sigma}{\text{arg max }} f^t(x)_c)\]

This equation computes the mean accuracy over the evaluation set $E^k_\Sigma$, which contains all instances $x$ and their true labels $y$ up to task $k$. The model’s prediction for each instance is given by $\underset{c \in C^k_\Sigma}{\text{arg max }} f^t(x)_c$, which selects the class $c$ with the highest predicted logit $f^t(x)_c$. The indicator function $1_y(\hat{y})$ outputs 1 if the prediction matches the true label, and 0 otherwise. The sum of these outputs is averaged over the size of the evaluation set to obtain the cumulative accuracy.

From Cumulative Accuracy, the Average Cumulative Forgetting $F_{\Sigma}^t$ is obtained by computing, for each previous task $k$, the largest drop from its best cumulative accuracy so far to its current value, and then averaging over all tasks learned so far:
\[F_{\Sigma}^t = \frac{1}{t-1} \sum_{k=1}^{t-1} \max_{i=1,...,t} \left( b_k^i - b_k^t \right)\]

The last two metrics are the classical Average Accuracy (AA) and Average Forgetting (AF).
$a_{i,j}$ is the accuracy evaluated on the test set of task $j$ after training the network from task 1 to $i$, while $i$ is the current task being trained. Average Accuracy (AA) is computed by averaging this over the number of tasks.
\[\text{Average Accuracy}\ (AA_i) = \frac{1}{i} \sum_{j=1}^{i} a_{i,j}\]

Average Forgetting measures how much a model’s performance on a previous task $j$ decreases after it has learned a new task $k$. It is calculated by comparing the highest accuracy the model had on task $j$ before learning task $k$, $\max_{l \in \{1, \ldots, k-1\}} (a_{l, j})$, with the accuracy $a_{k, j}$ on task $j$ after learning task $k$.
\[\text{Average Forgetting}\ (F_i) = \frac{1}{i - 1} \sum_{j=1}^{i-1} f_{i,j}\]

\[f_{k,j} = \max_{l \in \{1,...,k-1\}} (a_{l,j}) - a_{k,j}, \quad \forall j < k\]

In the context of class-incremental learning, classical forgetting alone may not provide meaningful insight, since it tends to increase as the classification problem grows to include more classes. Therefore, Average Cumulative Forgetting (ACF) is reported alongside Average Forgetting (AF) in this study.
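To make the definitions concrete, here is a minimal NumPy sketch (not the actual experiment code) that computes AA, AF, AAA, and ACF. The matrix layouts `a[i, j]` and `b[t, k]` described in the comments are assumptions for illustration.

```python
# Minimal sketch of the metrics above. Assumed layouts (0-indexed):
#   a[i, j] = accuracy on task j's test set after training tasks 0..i
#   b[t, k] = cumulative accuracy b_k^t after training tasks 0..t
import numpy as np

def average_accuracy(a, i):
    """AA_i: mean accuracy over tasks 0..i after training task i."""
    return a[i, : i + 1].mean()

def average_forgetting(a, i):
    """AF_i: mean drop from each previous task's best accuracy to its current accuracy."""
    if i == 0:
        return 0.0
    f = a[:i, :i].max(axis=0) - a[i, :i]   # f_{i,j} for every j < i
    return f.mean()

def average_anytime_accuracy(a):
    """AAA: AA_t averaged over every evaluation point t."""
    return np.mean([average_accuracy(a, t) for t in range(a.shape[0])])

def average_cumulative_forgetting(b, t):
    """ACF (F_Sigma^t): mean gap between best and current cumulative accuracy."""
    if t == 0:
        return 0.0
    return np.mean([b[: t + 1, k].max() - b[t, k] for k in range(t)])
```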
To compare learning performance across varying model depths, I use the popular ResNet architectures, specifically ResNet18, ResNet34, and ResNet50. As mentioned earlier, ResNets were designed to make deeper networks trainable, and their performance characteristics are well documented. While custom models offering more variability in size were a consideration, existing popular architectures were chosen for better reproducibility.
While there are newer variants (e.g. ResNeXt), the original ResNet family was chosen for the same reasons of familiarity and reproducibility.
In addition, to observe the effect of model width on performance, I also test a slim version of ResNet18 with reduced channel width, which has been used in previous continual learning works.
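As a quick sanity check of the size differences involved, the parameter counts can be compared directly. The sketch below uses torchvision for the standard ResNets and assumes Avalanche's `SlimResNet18` for the reduced-width variant; constructor arguments and module paths may differ across library versions.

```python
# Compare parameter counts of the models used in this study.
# Note: older torchvision versions use pretrained=False instead of weights=None;
# SlimResNet18 is assumed to come from avalanche.models (signature may vary).
import torchvision.models as tvm
from avalanche.models import SlimResNet18

def count_params(model):
    return sum(p.numel() for p in model.parameters())

models = {
    "SlimResNet18": SlimResNet18(nclasses=10),
    "ResNet18": tvm.resnet18(weights=None, num_classes=10),
    "ResNet34": tvm.resnet34(weights=None, num_classes=10),
    "ResNet50": tvm.resnet50(weights=None, num_classes=10),
}
for name, m in models.items():
    print(f"{name}: {count_params(m) / 1e6:.1f}M parameters")
```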
I use saliency maps to visualize the “attention” of the networks. Saliency maps are useful for understanding which parts of the input image are most influential for the model’s predictions. By visualizing the specific areas of an image that a CNN considers important for classification, they provide insight into the network’s internal representation and decision-making process.
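As a concrete illustration, a vanilla-gradient saliency map can be computed as in the sketch below. Here `model` and `img` are hypothetical placeholders for a trained classifier and a preprocessed input tensor, and this is only one of several ways to compute saliency.

```python
# Minimal vanilla-gradient saliency map for a single image.
import torch

def saliency_map(model, img):
    model.eval()
    x = img.unsqueeze(0).requires_grad_(True)      # add a batch dimension
    logits = model(x)
    score = logits[0, logits[0].argmax()]          # logit of the predicted class
    score.backward()                               # gradient of that logit w.r.t. the input
    # pixel importance = max over colour channels of the absolute gradient
    return x.grad.detach().abs().max(dim=1)[0].squeeze(0)
```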
The continual learning benchmark was implemented using the Avalanche framework, an open-source library for continual learning research. The table below summarizes the experimental setup.
| | Experiment 1 | Experiment 2 | Experiment 3 | Experiment 4 | Experiment 5 | Experiment 6 | Experiment 7 |
|---|---|---|---|---|---|---|---|
| Model | ResNet18 | ResNet34 | ResNet50 | SlimResNet18 | ResNet18 | ResNet34 | ResNet50 |
| Strategy | Experience Replay | Experience Replay | Experience Replay | Experience Replay | Experience Replay | Experience Replay | Experience Replay |
| Benchmark | SplitCIFAR10 | SplitCIFAR10 | SplitCIFAR10 | SplitCIFAR10 | SplitCIFAR10 | SplitCIFAR10 | SplitCIFAR10 |
| Training | Online | Online | Online | Online | Offline | Offline | Offline |
| GPU | V100 | T4 | A100 | T4 | T4 | T4 | T4 |
| Training time (estimate) | 3h | 4.5h | 5h | 1h | <5m | <5m | <5m |
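For reference, a minimal sketch of how such an Experience Replay run can be set up with Avalanche is shown below. It is not the exact experiment code: module paths and class names follow recent Avalanche releases and may differ across versions, and a single epoch with small mini-batches is used here only as an approximation of the online, single-pass setting.

```python
# Minimal Avalanche setup sketch: SplitCIFAR10 + Experience Replay via ReplayPlugin.
import torch
from torchvision import models
from avalanche.benchmarks.classic import SplitCIFAR10
from avalanche.training.supervised import Naive
from avalanche.training.plugins import ReplayPlugin, EvaluationPlugin
from avalanche.evaluation.metrics import accuracy_metrics, forgetting_metrics
from avalanche.logging import InteractiveLogger

benchmark = SplitCIFAR10(n_experiences=5, seed=0)      # 5 experiences, 2 classes each

model = models.resnet18(weights=None, num_classes=10)  # swap in resnet34 / resnet50 / slim variant

eval_plugin = EvaluationPlugin(
    accuracy_metrics(experience=True, stream=True),
    forgetting_metrics(experience=True, stream=True),
    loggers=[InteractiveLogger()],
)

strategy = Naive(
    model,
    torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9),
    torch.nn.CrossEntropyLoss(),
    train_mb_size=10,
    train_epochs=1,                          # single pass over each experience
    eval_mb_size=64,
    plugins=[ReplayPlugin(mem_size=200)],    # Experience Replay via the replay plugin
    evaluator=eval_plugin,
)

for experience in benchmark.train_stream:
    strategy.train(experience)
    strategy.eval(benchmark.test_stream)
```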
Average Anytime Accuracy (AAA) decreases with model size (Chart 1), with a sharper drop from ResNet34 to ResNet50. The decrease in AAA is more significant in online learning than in offline learning.
Looking at average accuracy over the validation stream in the online CL setting (Chart 2), the rate at which accuracy grows with each task degrades as models get larger. Slim-ResNet18 shows the highest accuracy and the strongest growth trend. This could indicate that larger models are worse at generalizing in a class-incremental learning scenario.
| | Average Anytime Acc (AAA) | Final Average Acc |
|---|---|---|
| Slim-ResNet18 | 0.664463 | 0.5364 |
| ResNet18 | 0.610965 | 0.3712 |
| ResNet34 | 0.576129 | 0.3568 |
| ResNet50 | 0.459375 | 0.3036 |
Now we turn to forgetting.
Looking at Average Cumulative Forgetting (ACF) in the online CL setting, ResNet34 performs best (with a slight overlap with ResNet18 at the end), while ResNet50 shows the most forgetting. A noticeable observation in both ACF and AF is that ResNet50 performed better initially, but its forgetting started to increase after a few tasks.
However, the results look different in the offline CL setting. ResNet50 has the lowest Average Cumulative Forgetting (ACF) (despite a slight increase in the middle), followed by ResNet18 and finally ResNet34. This difference in forgetting between the online and offline CL settings is aligned with the accuracy metrics above, where the performance of ResNet50 drops more starkly in the online setting.
Visual inspection of the saliency maps revealed some interesting observations. When it comes to the ability to highlight intuitive areas of interest in the images, there seemed to be a noticeable improvement from ResNet18 to ResNet34, but this was not necessarily the case from ResNet34 to ResNet50. This phenomenon was more salient in the online CL setting.
Saliency maps (online CL setting)
Saliency maps (offline CL setting)
Interestingly, Slim-ResNet18 seems to be doing better than most of them, certainly better than its plain counterpart ResNet18. A further exploration of the effect of model width on performance and representation quality would be an interesting avenue of research.
Saliency maps (Slim-ResNet18)
In this study, I compared key accuracy and forgetting metrics in online continual learning across ResNets of different depths and widths, along with a brief qualitative inspection of the models’ internal representations. The results show that larger models do not necessarily lead to better continual learning performance. Average Anytime Accuracy (AAA) and stream accuracy dropped progressively with model size, hinting that larger models struggle to generalize to newly trained tasks, especially in an online CL setting. Forgetting curves showed similar trends but with more nuance: larger models performed well at first but suffered increased forgetting as more tasks arrived. Interestingly, the problem was not as pronounced in the offline CL setting, which highlights the challenges of training models in the more realistic online continual learning context.
Why do larger models perform worse at continual learning? One possible reason is that models with more parameters are harder to keep stable: their learned features shift more as new data is introduced, making them more prone to overfitting the current task and forgetting previously learned information, which reduces their ability to generalize.
Building on this work, future research could investigate the impact of model size on CL performance by exploring the following questions:
To conclude, this study has empirically explored the role of model size in online continual learning. Specifically, it has shown that model size matters when it comes to continual learning and forgetting, albeit in nuanced ways. These findings contribute to the ongoing discussion on how the scale of deep learning models affects performance and have implications for future areas of research.