Neural scaling laws course W22; Lecture 15
Presenter : Léo Gagnon
Routing networks and motivation
Plan of the presentation
Routing networks
A routing network consists of a set of modules (parameterized functions) from which a router can choose a composition.
**Note that in certain architectures the input can be routed to many modules at the same time.**
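As a toy illustration (not from the paper; the linear modules and all names below are invented for this sketch), a router that picks a single module per input might look like :

```python
import numpy as np

# Toy sketch (not from the paper): a router that picks one module per input.
rng = np.random.default_rng(0)
D, E = 8, 4                                              # feature dimension, number of modules
modules = [rng.normal(size=(D, D)) for _ in range(E)]    # each module: a simple linear map
router_w = rng.normal(size=(D, E))                       # router scores one module per input

def route(x):
    scores = x @ router_w                                # (E,) score per module
    e = int(np.argmax(scores))                           # hard choice of a single module
    return x @ modules[e], e                             # apply only the chosen module

y, chosen = route(rng.normal(size=(D,)))
print("chosen module:", chosen, "output shape:", y.shape)
```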
Motivation 1 : Compositionality and generalisation
The world often has a compositional and modular structure : things are made of parts which can be recombined into other things. When the world/task/distribution changes, these parts often remain unchanged.
By separately learning modules and a way to compose them, routing networks can better adapt to changes in distribution.
Motivation 2 : Disentangle model size and compute
It is increasingly evident that in many cases model performance scales with size (e.g. Kaplan et al. 2020). However, so does compute (FLOPs). This undesirable connection between size and computation motivates a search for architectures in which the two are disentangled.
Routing networks disentangle total number of parameters and compute cost.
Motivation 2’ : Improved parallelism
Since large models do not fit on any single device, in practice it is necessary to distribute the model across several devices (model parallelism). However, this can significantly slow down execution, and communication costs can be prohibitive.
Routing networks enable efficient model parallelism because modules do not interact and can be executed in parallel.
Challenges
Training a routing network is non-stationary from the perspective of both the router and the modules, because the optimal composition strategy depends on the module parameters and vice versa. This gives rise to many challenges :
From a more practical point of view, it is important to balance the load across modules so that each device is used efficiently and all modules learn at the same rate (this helps mitigate collapse and lack of diversity).
Routed language models
Routed language models
We introduce routing in a large Transformer language model by replacing a fraction (R = 0.5) of the FFNs with routing networks : each token in the batch is sent to K experts (we use K = 1 for now) by a routing function.
Digression : Role of the FFN in a Transformer
“We show that feed-forward layers in transformer-based language models operate as key-value memories, where each key correlates with textual patterns in the training examples, and each value induces a distribution over the output vocabulary” -Geva et al. (2021)
Routed language models : Architectures
The paper considers three different choices for the routing process :
Sparse Mixture-of-Experts
A routing function produces a distribution over experts for every token. The token is then processed into a convex combination of the outputs of the top K experts :
We will consider K=1 (as in the picture) for now and discuss this point further later.
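In symbols (a hedged reconstruction of the standard sparse-MoE computation; W_r denotes the router weights and FFN_e the e-th expert) :

$$ p(x) = \mathrm{softmax}(W_r\, x), \qquad y = \sum_{e \,\in\, \mathrm{TopK}(p(x))} p_e(x)\, \mathrm{FFN}_e(x) $$

Some implementations renormalize p over the selected experts; for K = 1 this reduces to scaling the single chosen expert's output by its routing probability.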
Sparse Mixture-of-Experts : Balancing the load (1)
Naïve SMOEs lead to poor load balancing, which needs to be addressed. This is done in three ways.
1) Addition of an auxiliary load-balancing loss
where P_e is the mean routing probability assigned to expert e over the batch and f_e is the fraction of tokens dispatched to expert e in the batch.
Only P_e is differentiable.
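One common instantiation matching this description (the Switch-Transformer auxiliary loss, with scale alpha a hyperparameter) is alpha · E · Σ_e f_e · P_e. A minimal sketch :

```python
import numpy as np

# Sketch of a Switch-Transformer-style auxiliary load-balancing loss (alpha is a hyperparameter).
def load_balancing_loss(router_probs, expert_index, alpha=1e-2):
    """router_probs: (T, E) softmax outputs; expert_index: (T,) expert chosen per token."""
    T, E = router_probs.shape
    P = router_probs.mean(axis=0)                     # mean probability per expert (differentiable)
    f = np.bincount(expert_index, minlength=E) / T    # fraction of tokens per expert (not differentiable)
    return alpha * E * np.sum(f * P)                  # equals alpha when the load is perfectly uniform

rng = np.random.default_rng(0)
logits = rng.normal(size=(32, 8))
probs = np.exp(logits) / np.exp(logits).sum(axis=1, keepdims=True)
print(load_balancing_loss(probs, probs.argmax(axis=1)))
```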
Sparse Mixture-of-Experts : Balancing the load (2)
Naïve SMOEs lead to poor load balancing, which needs to be addressed. This is done in three ways.
2) Balanced assignment optimisation during training
To make sure that the assignments are especially well balanced during training, we add an additional step where the expert distribution generated by the router is iteratively normalized so that every expert has an average assignment probability of 1/E.
This method is called the Sinkhorn algorithm.
BASE uses a hard assignment optimisation instead.
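A minimal sketch of the idea, assuming plain alternating normalization of the router scores (the iteration count, temperature and final hard assignment all differ between papers) :

```python
import numpy as np

# Sketch of Sinkhorn-style balancing of the router's scores.
def sinkhorn(scores, n_iters=10):
    """scores: (T, E) router logits. Alternately normalizes experts (columns) and
    tokens (rows) so that each expert ends up with roughly T/E total assignment mass."""
    pi = np.exp(scores)
    for _ in range(n_iters):
        pi = pi / pi.sum(axis=0, keepdims=True)    # give every expert the same total mass
        pi = pi / pi.sum(axis=1, keepdims=True)    # keep a distribution over experts per token
    return pi

rng = np.random.default_rng(0)
pi = sinkhorn(rng.normal(size=(16, 4)))
print(pi.sum(axis=0))    # per-expert load, close to 16/4 = 4 for every expert
```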
Sparse Mixture-of-Experts : Balancing the load (3)
Naïve SMOEs lead to poor load balancing, which needs to be addressed. This is done in three ways.
3) Expert capacity and overflow
It can still happen that the assignment is not exactly balanced. We can tolerate some amount of uneven assignment by giving some additional capacity to each expert.
If, even then, an expert receives too many tokens, the excess tokens are not processed by the expert and go straight to the next layer.
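A small worked example of the capacity rule (the 1.25 capacity factor is illustrative) :

```python
# Worked example of the expert-capacity rule (the 1.25 capacity factor is illustrative).
tokens_per_batch = 4096
num_experts = 64
capacity_factor = 1.25                   # slack to tolerate some uneven routing

capacity = int(capacity_factor * tokens_per_batch / num_experts)
print(capacity)                          # 80 token slots per expert instead of 64
```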
Routing with Reinforcement Learning
Here the token assignment problem is modeled as a one-step MDP where the observation is the token, the actions are the experts and the reward is the probability that the overall model assigns to the right next token. A policy gradient (REINFORCE) loss is added to the language modeling loss (Xent).
This is similar to SMOEs, but the routing process is optimised directly. However, the high variance of the gradient is problematic, especially as the number of experts grows. The authors experiment with various improvements to naïve REINFORCE (e.g. adding a baseline).
Balancing tricks (2) and (3) of SMOEs are also used.
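A minimal sketch of the per-token REINFORCE term described above (names are illustrative; in practice it is computed over a batch and added to the cross-entropy language-modelling loss) :

```python
import numpy as np

# Sketch of the per-token REINFORCE routing term (illustrative names).
def reinforce_routing_loss(reward, route_prob, baseline=0.0):
    # reward     : probability the LM assigns to the correct next token (treated as a constant)
    # route_prob : probability the router gave to the expert that was actually used
    # baseline   : e.g. a running average of rewards, used to reduce gradient variance
    advantage = reward - baseline
    return -advantage * np.log(route_prob)

print(reinforce_routing_loss(reward=0.3, route_prob=0.25, baseline=0.1))
```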
Input-based Deterministic Hash Routing
Here the token assignment is determined as a fixed function of its ID. The paper uses the ID modulo E.
“Finally, given that our routing approach is learning free, our results perhaps suggest that none of the current approaches are routing particularly well.” -Roller et al. 2021. lol
Balancing trick (3) is used.
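The routing rule itself is a one-liner; for example :

```python
# The deterministic routing rule described above: a fixed function of the token ID.
def hash_route(token_id, num_experts):
    return token_id % num_experts

print([hash_route(t, 8) for t in (15, 16, 17, 1024)])   # -> [7, 0, 1, 0]
```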
Disclaimer : Engineering and tuning
Note that for all the methods described, careful tuning of hyperparameters, initialisation and other engineering aspects is essential for good performance, and especially for stability of training. In fact, many recent papers on routed language models (from Google) focus almost exclusively on engineering (Fedus et al. 2021, Du et al. 2021).
Unified Scaling Laws
Setup
Task : Autoregressive language modelling
Metric of performance : Validation log-likelihood
Dataset : MassiveText (Rae et al. 2022)
Base architecture : GPT-2 (Radford et al. 2019)
Notation
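For reference, as these symbols are used in the rest of the deck :
N : number of parameters of the dense model
E : number of experts
K : number of experts each token is routed to
R : fraction of FFN layers replaced by routing networks
L : validation loss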
Separable Scaling Laws in Model Size and Experts
The starting point is the scaling law of Kaplan et al. (2020) for a dense language model with N parameters :
Then they hypothesize that for a given N, the loss scales similarly with respect to E.
And further that the two power laws are separable and can be combined :
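Spelled out, these hypotheses are (a hedged reconstruction of the forms used in the paper; a, b, d are fitted constants and logarithms are base 10) :

$$ \log L(N) \approx a \log N + d $$
$$ \log L(E)\,\big|_{N \text{ fixed}} \approx b \log E + d $$
$$ \log L(N, E) \approx a \log N + b \log E + d $$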
Power laws are not separable
While the first hypothesis is empirically verified, the separable power law doesn’t hold : the exponent b depends on N. Routing gives diminishing returns when N grows (N = 5M, 15M, 130M, 370M, 1.3B)
Quadratic Interaction in N and E
While the first hypothesis is empirically verified, the separable power law doesn’t hold : the exponent b depends on N. Therefore, they modify their ansatz to account for the interaction between E and N :
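The modified ansatz adds a multiplicative interaction term (a hedged reconstruction; c is the new fitted constant) :

$$ \log L(N, E) \approx a \log N + b \log E + c \,\log N \log E + d $$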
The constant c quantifies the diminishing returns from routing as size increases (and vice-versa). A clear goal for Routed Language Models is to have c as close to 0 as possible.
Bounded scaling in E
Authors observe another source of diminishing returns : low and high values of E.
When E is low, scaling is weakened by the fixed overhead of the routing process (e.g. interference from the balancing loss), and when E is high the different routing methods deteriorate for different reasons (e.g. gradient variance for RL).
To correct for this, the authors apply a (saturating) transformation to E, denoted Ê.
The constant Emax quantifies the diminishing returns from routing at high numbers of experts. A clear goal for Routed Language Models is to have the largest Emax possible.
Final scaling law
In the end, they arrive at the following scaling law :
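The final fit keeps the quadratic form above but replaces E with the saturating transform Ê (a hedged reconstruction; the exact definition of Ê, which involves Emax among other constants, is omitted here) :

$$ \log L(N, \hat{E}) \approx a \log N + b \log \hat{E} + c \,\log N \log \hat{E} + d $$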
Routing is good
Routing is good : for how long?
Q : Given the diminishing returns, at what value of N does routing stop being beneficial?
Let N̄ be the Effective Parameter Count (EPC), i.e. the size of the dense model that would achieve the same loss as the routed model. It is obtained by solving the equation sketched below.
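Equating the dense law with the routed law above (a reconstruction from the fitted formulas; the EPC is written \bar{N}) :

$$ a \log \bar{N} + d \;=\; a \log N + b \log \hat{E} + c \,\log N \log \hat{E} + d \quad\Longrightarrow\quad \log \bar{N} \;=\; \log N + \frac{b + c \log N}{a}\, \log \hat{E} $$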
The validation loss then trivially follows a power law with respect to N̄.
Routing is good : for how long?
Q : Given the diminishing returns, at what value of N does routing stop being beneficial?
We are interested in N_cutoff, the value of N at which routing stops being beneficial (i.e. the EPC equals N for any number of experts). It can be found from the fitted constants as shown below.
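From the EPC expression above, the benefit of routing vanishes when the coefficient of log Ê is zero, i.e. when b + c log N = 0, which with base-10 logarithms gives :

$$ N_{\text{cutoff}} = 10^{-b/c} $$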
More general scaling law
To account for more general architectures (in particular those where K > 1 or R ≠ 0.5), the authors introduce a more general version of the scaling laws, expressed in terms of the variables P and B used below.
More general scaling law : K and R don’t matter
The new scaling law produces similar fits across K and R, i.e. the loss can be predicted from P and B alone. This indicates that K and R have little impact on performance.
More general scaling law : which K,R to choose
Higher K means more parameter efficiency, but also more FLOPs and communication cost : K = 1 is preferred.
Higher R normally means better performance, but with diminishing returns : R = 0.5 is preferred.
Recap and discussion
Recap
Important points of the paper :
Comparison of presented architectures
We summarize the relative behavior of the three routing techniques considered :
Left-out sections
Implication for future research and open questions
This framework gives a way to quantify the performance of different architectures in a disentangled way and informs future research. We can extract a few general lessons and questions for the future.
Questions :