An Inferential Perspective on FL & Remarks on Data Efficiency of Meta-learning
Maruan Al-Shedivat
Based on joint work with:
Jenny Gillenwater, Liam Li, Afshin Rostamizadeh, Ameet Talwalkar, Eric Xing
FLOW seminar 3/24/2021
Outline & Relevant Papers
Part I (25-30 minutes): focus on standard FL
Part II (10-15 minutes): focus on personalized FL
Part I:
An Inferential Perspective on FL
Solve this problem using FedAvg (local SGD):
Federated Learning (FL) is usually formulated as a distributed optimization problem
local client objectives
global objective
Server
request clients
Client Population
Local SGD...
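A sketch of this formulation in illustrative notation (the symbols below are not necessarily the slides' exact ones):

```latex
% Global objective: expectation over clients of the local client objectives.
\min_{\theta}\; F(\theta) := \mathbb{E}_{i \sim P}\big[\, f_i(\theta) \,\big],
\qquad
% Local client objective: expectation over that client's data.
f_i(\theta) := \mathbb{E}_{z \sim D_i}\big[\, \ell(\theta;\, z) \,\big].
```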
Cross-device FL: Modeling Assumptions
10¹–10³ clients per round
communication of the model is often the bottleneck
The cross-device federated setting:
Bonawitz et al. “Towards Federated Learning at Scale: System Design.” arXiv:1902.01046
Solve this problem using FedAvg (local SGD):
Federated Learning (FL) is usually formulated as a distributed optimization problem
local client objectives
global objective
Client-server communication is often slow & expensive. How can we speed up training?
Server
Client Population
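Since communication is the bottleneck, FedAvg amortizes it by running many local SGD steps per round. A minimal NumPy sketch of one round, assuming a generic `local_grad(theta, data)` stochastic-gradient function (all names are illustrative, not from the talk):

```python
import numpy as np

def fedavg_round(theta, client_datasets, local_grad, lr=0.1, local_steps=20):
    """One FedAvg round: each sampled client runs `local_steps` of SGD
    starting from the server model, then the server averages the deltas.

    Client weights are proportional to local dataset size."""
    deltas, weights = [], []
    for data in client_datasets:                # clients sampled this round
        theta_i = theta.copy()
        for _ in range(local_steps):            # local SGD, communication-free
            theta_i -= lr * local_grad(theta_i, data)
        deltas.append(theta - theta_i)          # client update ("delta")
        weights.append(len(data))
    w = np.asarray(weights, dtype=float) / sum(weights)
    avg_delta = sum(wi * di for wi, di in zip(w, deltas))
    return theta - avg_delta                    # server applies the averaged delta
```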
Convergence Issues: Toy Example (Least Squares in 2D)
Least squares: [figure: a 2D least-squares toy problem illustrating FedAvg convergence issues]
Federated Averaging: Fixed Points (for quadratic losses)
Centralized least squares optimum:
FedAvg fixed point (e steps per round):
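A hedged sketch of the contrast for quadratic client objectives, in illustrative notation (η is the local step size; not necessarily the slides' exact symbols):

```latex
% Quadratic client objectives: f_i(theta) = 1/2 (theta - mu_i)^T A_i (theta - mu_i).
% Centralized least-squares optimum:
\theta^{\star}_{\mathrm{central}} = \Big(\sum_i A_i\Big)^{-1} \sum_i A_i \mu_i
% FedAvg fixed point after e local gradient steps with step size eta:
\theta^{\star}_{\mathrm{FedAvg}} = \Big(\sum_i Q_i\Big)^{-1} \sum_i Q_i \mu_i,
\qquad Q_i := I - (I - \eta A_i)^{e}.
% As e grows, Q_i -> I, so the fixed point drifts toward the plain average of local optima.
```

The mismatch between the two weightings is exactly the discrepancy the fixes below try to reduce.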
Takeaways:
Proposed fix: an algorithm similar to SCAFFOLD, which uses stateful clients that are revisited throughout the course of training.
Proposed fix: careful tuning of the client and server learning rate schedules to reduce the discrepancy.
Federated Optimization ⇒ Federated Posterior Inference
Federated learning is often formulated as an optimization problem:
If the loss function is the negative log-likelihood, then the solution of the optimization problem is the maximum likelihood estimator (MLE)
An alternative to the MLE is posterior inference: we would like to infer the posterior distribution over the parameters
expectation over clients
expectation over client data
under the uniform prior, posterior mode ≡ MLE
Federated Optimization ⇒ Federated Posterior Inference
Any posterior distribution decomposes into a product of sub-posteriors:
A high-level algorithm that will attain the global optimum:
Key point: If we can do this efficiently, then we will have a globally consistent algorithm with stateless clients.
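Sketched in illustrative notation, assuming client datasets are conditionally independent given θ:

```latex
% D = D_1 u ... u D_N; p(theta | D_i) are the client sub-posteriors.
p(\theta \mid D) \;\propto\; p(\theta) \prod_{i=1}^{N} p(D_i \mid \theta)
  \;\propto\; \frac{\prod_{i=1}^{N} p(\theta \mid D_i)}{p(\theta)^{\,N-1}}
  \;\propto\; \prod_{i=1}^{N} p(\theta \mid D_i)
  \quad \text{(under a uniform prior)}.
```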
Example: Federated Quadratics
Quadratic objectives are log likelihoods under the Gaussian model:
The global posterior is a product of Gaussians (⇒ also Gaussian):
The global optimum is the posterior mode (coincides with the mean for Gaussians):
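For completeness, the standard product-of-Gaussians identity this relies on (illustrative notation):

```latex
% If each sub-posterior is Gaussian, p(theta | D_i) = N(theta; mu_i, Sigma_i), then
p(\theta \mid D) \;\propto\; \prod_{i=1}^{N} \mathcal{N}(\theta;\, \mu_i, \Sigma_i)
  \;\propto\; \mathcal{N}(\theta;\, \mu, \Sigma),
\qquad
\Sigma = \Big(\sum_i \Sigma_i^{-1}\Big)^{-1},
\quad
\mu = \Sigma \sum_i \Sigma_i^{-1} \mu_i .
```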
Federated Posterior Averaging
If we can infer the posterior mode, we have solved the problem!
Key idea:
local posterior means
local posterior covariances
Note: If posteriors are non-Gaussian, multimodal, etc., the above expression is simply a (federated) Laplace approximation of both local and global posteriors.
Challenges:
(1) how to infer local posteriors efficiently?
(2) how to do aggregation on the server efficiently?
(3) how to communicate them to the server efficiently?
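The Laplace approximation mentioned in the note, written as a sketch in illustrative notation:

```latex
% Approximate each local posterior by a Gaussian centered at its mode:
p(\theta \mid D_i) \;\approx\; \mathcal{N}\big(\theta;\ \mu_i,\ \Sigma_i\big),
\qquad
\mu_i = \arg\min_{\theta} f_i(\theta),
\qquad
\Sigma_i^{-1} = \nabla^2 f_i(\mu_i),
% where f_i is client i's local objective (negative log-likelihood).
```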
Local Posterior Inference: Stochastic Gradient MCMC
Mandt, Hoffman & Blei, “Stochastic Gradient Descent as Approximate Bayesian Inference,” JMLR 2017
Key ideas:
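A minimal sketch of the kind of SG-MCMC / iterate-averaging procedure a client could use to draw approximate posterior samples; the function names and hyperparameters below are illustrative, not the talk's:

```python
import numpy as np

def local_posterior_samples(grad_fn, theta0, lr=0.05, n_samples=10,
                            steps_per_sample=20, burn_in=50):
    """Approximate local posterior sampling via SGD iterate averaging (IASG-style).

    `grad_fn(theta)` is assumed to return a stochastic gradient of the client's
    negative log-likelihood. After a burn-in phase, averages of consecutive SGD
    iterates are treated as approximate samples from the local posterior."""
    theta = np.array(theta0, dtype=float)
    for _ in range(burn_in):                    # move toward the local mode
        theta -= lr * grad_fn(theta)
    samples = []
    for _ in range(n_samples):
        avg = np.zeros_like(theta)
        for _ in range(steps_per_sample):       # one block of SGD steps ...
            theta -= lr * grad_fn(theta)
            avg += theta / steps_per_sample     # ... averaged into one sample
        samples.append(avg)
    return np.stack(samples)                    # shape: [n_samples, d]
```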
Global Posterior Inference: Computing the Matrix Inverse
The global posterior mode is the minimizer of the following objective:
Idea: solve it using SGD instead of matrix inverse!
Precisely the server update done by FedAvg, except client Δ’s are different
It also solves our communication problem: clients only need to send new deltas to the server!
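A sketch of that objective and why SGD on the server suffices (illustrative notation):

```latex
% The global posterior mode
\mu \;=\; \Big(\tfrac{1}{N}\sum_i \Sigma_i^{-1}\Big)^{-1} \Big(\tfrac{1}{N}\sum_i \Sigma_i^{-1}\mu_i\Big)
% is the unique minimizer of the quadratic
Q(\theta) \;:=\; \frac{1}{N}\sum_{i=1}^{N} \tfrac{1}{2}\,(\theta - \mu_i)^{\top} \Sigma_i^{-1} (\theta - \mu_i),
\qquad
\nabla Q(\theta) \;=\; \frac{1}{N}\sum_{i=1}^{N}
  \underbrace{\Sigma_i^{-1} (\theta - \mu_i)}_{\text{client } i\text{'s delta}} .
% A server gradient step on Q is an average of client deltas: the FedAvg-style server update,
% just with differently computed deltas, and no matrix inverse is ever formed on the server.
```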
Final Hurdle: Local Computation of New Deltas
Server updates:
So, we need to be able to compute deltas efficiently…
Note: we cannot even store the local posterior covariance, because it is a d × d matrix (where d is 1M+)
Good news: we can compute deltas using only O(d) memory and O(d) compute per approximate posterior sample!
We propose to compute the delta in two steps:
Estimator based on the previous (l - 1) samples
Rank-1 update based on the l-th sample
The per-sample update requires a constant number of vector-vector multiplies ⇒ requires O(d) compute and O(d) memory on the client.
(see the sketch below)
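The talk's exact two-step recursion is not reproduced here; instead, here is an equivalent sketch of the key trick, computing Δ = Σ̂⁻¹(θ − μ̂) for a shrinkage covariance estimate Σ̂ = ρI + (1 − ρ)·S without ever materializing the d × d matrix. This version uses the Woodbury identity over the l centered samples rather than the per-sample rank-1 recursion, but the scaling idea is the same; ρ and all names are illustrative:

```python
import numpy as np

def compute_delta(theta, samples, rho=0.1):
    """Compute Delta = Sigma_hat^{-1} (theta - mu_hat) in O(d*l) memory/compute,
    where Sigma_hat = rho*I + (1 - rho)*S, S is the sample covariance of
    `samples` (shape [l, d]), and the d x d matrix is never formed."""
    l, d = samples.shape
    mu_hat = samples.mean(axis=0)
    v = theta - mu_hat
    if l < 2:
        return v / rho                      # no covariance information yet
    U = (samples - mu_hat).T                # d x l matrix of centered samples
    c = (1.0 - rho) / (l - 1)               # Sigma_hat = rho*I + c * U @ U.T
    # Woodbury identity: solve an l x l system instead of inverting d x d.
    K = np.eye(l) / c + (U.T @ U) / rho
    w = np.linalg.solve(K, U.T @ v)
    return v / rho - (U @ w) / rho**2
```

A client would then send Δ (the same shape as the model) to the server, just as it sends a model delta in FedAvg.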
Federated Posterior Averaging
If we can infer the posterior mode, we have solved the problem!
Challenges:
(✓) how to infer local posteriors efficiently?
(✓) how to do aggregation on the server efficiently?
(✓) how to communicate them to the server efficiently?
local posterior covariances
local posterior means
Key idea:
Federated Posterior Averaging (FedPA): The Algorithm
On the server:
On the clients:
Similar to FedOpt, we run SGD on the clients, but compute deltas differently.
Identical to (adaptive) FedOpt! [Reddi*, Charles*, et al., ICLR 2021]
Note: our new perspective suggests that known federated optimization algorithms are doing posterior inference under the Laplace approximation and estimating local covariances using the identity matrix.
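A sketch of the round structure this describes (illustrative notation; ServerOpt stands for any FedOpt-style server optimizer, α_s for its step size):

```latex
% Client deltas:
\Delta_i \;\approx\; \hat\Sigma_i^{-1}\,\big(\theta_t - \hat\mu_i\big) \quad \text{(FedPA)},
\qquad
\Delta_i \;=\; \theta_t - \theta_i^{\mathrm{local}} \quad \text{(FedAvg / FedOpt)}.
% Server update, identical in both cases:
\theta_{t+1} \;=\; \mathrm{ServerOpt}\Big(\theta_t,\ \tfrac{1}{M}\textstyle\sum_{i=1}^{M} \Delta_i,\ \alpha_s\Big).
% Setting \hat\Sigma_i = I and \hat\mu_i = \theta_i^{\mathrm{local}} recovers the FedAvg delta,
% i.e., standard federated optimization = posterior inference with identity local covariances.
```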
Back to our Toy Example
the “noise barrier”
FedAvg vs. FedPA: Bias and Variance of the Server Updates
The plots are for 10D federated least squares regression on multiple synthetically generated datasets.
FedAvg: reduce bias by reducing the amount of local computation
FedPA: reduce bias by increasing the amount of local computation
FedAvg vs. FedPA: Federated CIFAR100
stands for multi-epoch
FedAvg vs. FedPA: Federated StackOverflow LR
stands for multi-epoch
Concluding Thoughts for Part I
Part II:
On Data Efficiency of Meta-learning
(for personalized FL)
The goal of FL is not always to learn a global model
Bottom line: models obtained by FedAvg are often worse than purely local training for the majority of the clients ⇒ need model personalization
word prediction
image classification
A simple way to personalize models:
In the few-shot learning literature, this method is known as “Reptile:”
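For reference, a sketch of the Reptile update (illustrative notation; SGD_k denotes k local SGD steps on client/task i's data, ε the outer step size):

```latex
\theta \;\leftarrow\; \theta + \varepsilon \,\big(\mathrm{SGD}_k(\theta;\, D_i) - \theta\big)
% i.e., repeatedly move the shared initialization toward a model fine-tuned on one client/task;
% personalization then amounts to fine-tuning the learned \theta on each client's own data.
```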
Background: Model-agnostic Meta-learning
Intuition:
Finn et al., ICML 2017
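The MAML objective (Finn et al., ICML 2017), sketched with a single inner gradient step (α is the inner step size; S_i and Q_i denote task i's support and query sets):

```latex
\min_{\theta}\;\; \mathbb{E}_{i}\Big[\, L\big(\theta - \alpha \nabla_{\theta} L(\theta;\, S_i);\;\; Q_i\big) \Big]
% Learn an initialization such that one (or a few) gradient steps on a task's support set
% already perform well on that task's query set.
```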
The Question of Interest
How can we characterize the data efficiency of modern meta-learning algorithms?
Why?
How?
Theoretical Analysis: Objective Functions
1. Empirical estimator of the transfer risk
   Task data is NOT split into subsets
   Example methods: FedAvg/Reptile
2. Hold-out estimator of the transfer risk
   Support set: used in the inner loop
   Query set: used in the outer loop
   Example methods: MAML, ProtoNets
(both estimators are sketched below)
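The two estimators, sketched in illustrative notation (A(θ, S) denotes the inner-loop adaptation algorithm applied to initialization θ and data S; n is the number of training tasks):

```latex
% 1. Empirical estimator: adapt and evaluate on the same, unsplit task data.
\hat{R}_{\mathrm{emp}}(\theta) \;=\; \frac{1}{n}\sum_{i=1}^{n} \frac{1}{|D_i|}\sum_{z \in D_i}
  \ell\big(\mathcal{A}(\theta, D_i);\; z\big)
% 2. Hold-out estimator: adapt on the support set, evaluate on the query set.
\hat{R}_{\mathrm{ho}}(\theta) \;=\; \frac{1}{n}\sum_{i=1}^{n} \frac{1}{|Q_i|}\sum_{z \in Q_i}
  \ell\big(\mathcal{A}(\theta, S_i);\; z\big),
\qquad D_i = S_i \cup Q_i .
```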
Theoretical Analysis: Meta-generalization Bounds
Practical Implications
O1: As the # of tasks goes to infinity, the generalization error of hold-out-estimator methods goes to zero.
⇒ MAML vs. ProtoNets must be indistinguishable when trained on many tasks.
[figure: performance of MAML and ProtoNets as a function of the # of training tasks]
O2: FedAvg/Reptile (and other empirical-estimator methods) have an additive term that depends on the # of samples per task ⇒ worse performance unless a sufficient number of samples per task is provided at meta-training time.
O3: Meta-generalization depends on the algorithmic stability constants of the inner and outer loops ⇒ actively selecting labeled data in the inner loop may improve meta-generalization.
To Learn More
Thank you!
Questions?
Additional Results for Part I