1 of 55

An Inferential Perspective on FL & Remarks on Data Efficiency of Meta-learning

Maruan Al-Shedivat

Based on joint work with:

Jenny Gillenwater, Liam Li, Afshin Rostamizadeh, Ameet Talwalkar, Eric Xing

FLOW seminar 3/24/2021

2 of 55

Outline & Relevant Papers

  • Federated Learning via Posterior Averaging: A New Perspective and Practical Algorithms. With Jenny Gillenwater, Afshin Rostamizadeh, Eric Xing. To appear at ICLR 2021.
  • On Data Efficiency of Meta-learning. With Liam Li, Ameet Talwalkar, Eric Xing. To appear at AISTATS 2021.

Part I (25-30 minutes): focus on standard FL

Part II (10-15 minutes): focus on personalized FL

3 of 55

Part I:

An Inferential Perspective on FL

4 of 55


Solve this problem using FedAvg (local SGD):

  • Optimize the global objective over multiple communication rounds.
  • At each round, a subset of clients runs local optimization and communicates with the server.

Federated Learning (FL) is usually formulated as a distributed optimization problem

local client objectives

global objective
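The objective itself did not survive the export; a standard formulation consistent with the "global objective" and "local client objectives" labels above (an assumed reconstruction, not the slide's exact notation) is:

```latex
% Assumed standard form of the FL objective, matching the labels above.
\min_{\theta}\; F(\theta) \;=\; \sum_{i=1}^{N} q_i\, f_i(\theta),
\qquad
f_i(\theta) \;=\; \mathbb{E}_{z \sim D_i}\big[\ell(\theta; z)\big],
\qquad
q_i \ge 0,\;\; \textstyle\sum_i q_i = 1 .
```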

[Diagram: the server requests a cohort of clients from the client population; selected clients run local SGD.]
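To make the FedAvg / local SGD procedure described above concrete, here is a minimal sketch of one training run. The helpers `sample_clients`, `client.batches`, and `client.loss_grad` are assumed names for illustration, not the authors' implementation:

```python
import itertools
import numpy as np

def fedavg_round(theta_server, cohort, local_steps=10, client_lr=0.1, server_lr=1.0):
    """One communication round of FedAvg: local SGD on each client, then averaging."""
    deltas = []
    for client in cohort:                                 # subset of clients for this round
        theta = theta_server.copy()
        for batch in itertools.islice(client.batches, local_steps):
            theta -= client_lr * client.loss_grad(theta, batch)   # local SGD step
        deltas.append(theta_server - theta)               # delta sent back to the server
    # The server treats the averaged delta as a pseudo-gradient (server_lr = 1 recovers plain averaging).
    return theta_server - server_lr * np.mean(deltas, axis=0)

def fedavg(theta0, client_pool, num_rounds=1000, cohort_size=100):
    theta = theta0
    for _ in range(num_rounds):                           # multiple communication rounds
        theta = fedavg_round(theta, sample_clients(client_pool, cohort_size))
    return theta
```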

5 of 55

Cross-device FL: Modeling Assumptions

10^1-10^3 clients per round

Bonawitz et al. “Towards Federated Learning at Scale: System Design.” arXiv:1902.01046

6 of 55

Cross-device FL: Modeling Assumptions

Bonawitz et al. “Towards Federated Learning at Scale: System Design.” arXiv:1902.01046

communication of the model is often the bottleneck

7 of 55

Cross-device FL: Modeling Assumptions

Bonawitz et al. “Towards Federated Learning at Scale: System Design.” arXiv:1902.01046

The cross-device federated setting:

  • A very large number of clients (1M+) ⇒ clients participate in ≤ 1 training round
  • Data distributions on clients are different ⇒ non-IID setting
  • Increasing # of steps per client is cheap compared to increasing # of rounds due to communication costs
  • Increasing # of clients per round often has a negligible overhead

8 of 55

Solve this problem using FedAvg (local SGD):

  • Optimize the global objective over multiple communication rounds.
  • At each round, a subset of clients runs local optimization and communicates with the server.

Federated Learning (FL) is usually formulated as a distributed optimization problem

local client objectives

global objective

Client-server communication is often slow & expensive. How can we speed up training?


[Diagram: server and client population.]

9 of 55

Solve this problem using FedAvg (local SGD):

  • Optimize the global objective over multiple communication rounds.
  • At each round, a subset of clients runs local optimization and communicates with the server.

Federated Learning (FL) is usually formulated as a distributed optimization problem

local client objectives

global objective

Client-server communication is often slow & expensive. How can we speed up training?

  • To speed up training (x10-100), we can make clients spend more time on local training at each round (e.g., do more local SGD steps) ⇒ make more local progress, thereby reducing the total number of communication rounds.


[Diagram: server and client population.]

10 of 55

Solve this problem using FedAvg (local SGD):

  • Optimize the global objective over multiple communication rounds.
  • At each round, a subset of clients runs local optimization and communicates with the server.

Federated Learning (FL) is usually formulated as a distributed optimization problem

local client objectives

global objective

Client-server communication is often slow & expensive. How can we speed up training?

  • To speed up training (x10-100), we can make clients spend more time on local training at each round (e.g., do more local SGD steps) ⇒ make more local progress, thereby reducing the total number of communication rounds.
  • Because of client data heterogeneity, it turns out that more local computation per round results in convergence to inferior models!


[Diagram: server and client population.]

11 of 55

Convergence Issues: Toy Example (Least Squares in 2D)

Least squares:
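The least-squares expression did not survive the export; a generic per-client form for the 2D toy example (an assumption, not the slide's exact notation) is:

```latex
% Assumed per-client least-squares objective for the 2D toy example.
f_i(\theta) \;=\; \tfrac{1}{2}\,\lVert X_i \theta - y_i \rVert_2^2 ,
\qquad \theta \in \mathbb{R}^2 .
```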

12 of 55

Convergence Issues: Toy Example (Least Squares in 2D)

Least squares:

13 of 55

Convergence Issues: Toy Example (Least Squares in 2D)

Least squares:

14 of 55

Federated Averaging: Fixed Points (for quadratic losses)

Centralized least squares optimum:

FedAvg fixed point (e steps per round):

Takeaways:

  • FedAvg (as well as FedProx and other methods) optimizes a surrogate objective function
  • There is a gap between the optimum of the true objective and the surrogate objective

Proposed fix: an algorithm similar to SCAFFOLD, which uses stateful clients that are revisited throughout the course of training.

Proposed fix: careful tuning of the client and server learning rate schedules that reduce the discrepancy.

15 of 55

Federated Optimization ⇒ Federated Posterior Inference

Federated learning is often formulated as an optimization problem:

expectation over clients

expectation over client data
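Only the annotations of the objective survived the export; an expectation form consistent with the two labels above (an assumption) is:

```latex
% Assumed expectation form of the FL objective, matching the labels above.
\min_{\theta}\;
\mathbb{E}_{i \sim \text{clients}}\Big[\,
\mathbb{E}_{z \sim D_i}\big[\ell(\theta; z)\big]\Big].
```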

16 of 55

Federated Optimization ⇒ Federated Posterior Inference

Federated learning is often formulated as an optimization problem:

If the loss function is the negative log likelihood then the solution of the optimization problem is the maximum likelihood estimator (MLE)

expectation over clients

expectation over client data

17 of 55

Federated Optimization ⇒ Federated Posterior Inference

Federated learning is often formulated as an optimization problem:

If the loss function is the negative log likelihood then the solution of the optimization problem is the maximum likelihood estimator (MLE)

An alternative to MLE is posterior inference: we would like to infer the posterior distribution over the parameters

expectation over clients

expectation over client data

under the uniform prior, posterior mode ≡ MLE

18 of 55

Federated Optimization ⇒ Federated Posterior Inference

Any posterior distribution decomposes into a product of sub-posteriors:
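The decomposition itself is missing from the export; a standard form consistent with the statement above and with the uniform-prior remark earlier in the deck (an assumption) is:

```latex
% Assumed form of the decomposition; the second proportionality uses a uniform prior.
P(\theta \mid D) \;\propto\; P(\theta)\prod_{i=1}^{N} P(D_i \mid \theta)
\;\propto\; \prod_{i=1}^{N} P(\theta \mid D_i).
```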

19 of 55

Federated Optimization ⇒ Federated Posterior Inference

Any posterior distribution decomposes into a product of sub-posteriors:

A high-level algorithm that will attain the global optimum:

  1. On each client, (approximately) infer the local posterior distribution
  2. Communicate information about the inferred local posteriors to the server
  3. Multiplicatively aggregate local posteriors into the global on the server
  4. The mode of the global posterior is the global optimum!

Key point: If we can do this efficiently, then we will have a globally consistent algorithm with stateless clients.

20 of 55

Example: Federated Quadratics

Quadratic objectives are negative log-likelihoods under a Gaussian model:

21 of 55

Example: Federated Quadratics

Quadratic objectives are negative log-likelihoods under a Gaussian model:

The global posterior is a product of Gaussians (⇒ also Gaussian):

22 of 55

Example: Federated Quadratics

Quadratic objectives are negative log-likelihoods under a Gaussian model:

The global posterior is a product of Gaussians (⇒ also Gaussian):

The global optimum is the posterior mode (coincides with the mean for Gaussians):
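The slide's equations are not recoverable from the export; the standard product-of-Gaussians identity they refer to, with local sub-posteriors N(μ_i, Σ_i), is:

```latex
% Mode (= mean) and covariance of a product of Gaussian sub-posteriors N(mu_i, Sigma_i).
\mu \;=\; \Big(\textstyle\sum_{i=1}^{N} \Sigma_i^{-1}\Big)^{-1}
\sum_{i=1}^{N} \Sigma_i^{-1} \mu_i ,
\qquad
\Sigma \;=\; \Big(\textstyle\sum_{i=1}^{N} \Sigma_i^{-1}\Big)^{-1}.
```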

23 of 55

Federated Posterior Averaging

If we can infer the posterior mode, we have solved the problem!

local posterior covariances

local posterior means

Key idea:

  • Estimate moments of the local posteriors on the clients.
  • Communicate this to the server that will infer the global posterior mean.

24 of 55

Federated Posterior Averaging

If we can infer the posterior mode, we have solved the problem!

Note: If posteriors are non-Gaussian, multimodal, etc., the above expression is simply a (federated) Laplace approximation of both local and global posteriors.

local posterior covariances

local posterior means

Key idea:

  • Estimate moments of the local posteriors on the clients.
  • Communicate this to the server that will infer the global posterior mean.

25 of 55

Federated Posterior Averaging

If we can infer the posterior mode, we have solved the problem!

Challenges:

(1) how to infer local posteriors efficiently?

(2) how to do aggregation on the server efficiently?

(3) how to communicate them to the server efficiently?

local posterior covariances

local posterior means

Key idea:

  • Estimate moments of the local posteriors on the clients.
  • Communicate this to the server that will infer the global posterior mean.

26 of 55

Local Posterior Inference: Stochastic Gradient MCMC

JMLR 2017

Key ideas:

  • SGD on the log-likelihood ≈ sampling from the posterior
  • Run SGD on the local objective (long enough for the Markov chain to mix)
  • Keep running SGD, collect approximate posterior samples, and use them to estimate the local posterior mean and covariance (see the sketch below)
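A minimal sketch of what the bullets above describe: run local SGD past a burn-in period and treat thinned iterates as approximate posterior samples. The helper names (`loss_grad`, `client_batches`) and the hyperparameter defaults are assumptions, not the authors' code:

```python
import numpy as np

def sample_posterior_sgd(theta_init, client_batches, loss_grad,
                         lr=0.05, burn_in=50, thin=5, num_samples=20):
    """SGD on the local objective; post-burn-in, thinned iterates serve as
    approximate samples from the local posterior (SG-MCMC-style)."""
    theta = theta_init.copy()
    samples = []
    for step, batch in enumerate(client_batches):
        theta = theta - lr * loss_grad(theta, batch)        # plain local SGD step
        if step >= burn_in and (step - burn_in) % thin == 0:
            samples.append(theta.copy())                    # collect an approximate sample
        if len(samples) == num_samples:
            break
    return np.stack(samples)                                # shape (num_samples, d)
```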

27 of 55

Global Posterior Inference: Computing the Matrix Inverse

denote:

The global posterior mean is the minimizer of the following objective:

Idea: solve it using SGD instead of matrix inverse!

Precisely the server update done by FedAvg, except client Δ’s are different

It also solves our communication problem: clients need to only send some new deltas to the server!
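The notation on this slide did not survive the export; an assumed reconstruction of the construction it describes: the global posterior mean minimizes a quadratic whose per-client gradient terms are exactly the new client deltas, so SGD over clients replaces the matrix inverse.

```latex
% Assumed reconstruction: mu minimizes a quadratic; its per-client gradients are the deltas.
\mu \;=\; \arg\min_{\theta}\;
\frac{1}{2}\sum_{i=1}^{N} (\theta - \mu_i)^{\top} \Sigma_i^{-1} (\theta - \mu_i),
\qquad
\nabla_{\theta} \;=\; \sum_{i=1}^{N}
\underbrace{\Sigma_i^{-1}(\theta - \mu_i)}_{\text{client delta } \Delta_i}.
```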

28 of 55

Final Hurdle: Local Computation of New Deltas

Server updates:

So, we need to be able to compute the deltas efficiently… Note: we cannot even store the local covariance, because it is a d × d matrix (where d is 1M+)

29 of 55

Final Hurdle: Local Computation of New Deltas

Server updates:

So, we need to be able to compute the deltas efficiently… Note: we cannot even store the local covariance, because it is a d × d matrix (where d is 1M+)

Good news: we can compute the deltas using only O(d) memory and O(d) compute per approximate posterior sample!

30 of 55

Final Hurdle: Local Computation of New Deltas

We propose to compute the deltas in two steps:

  1. Use a shrinkage estimator of the covariance [Ledoit & Wolf, 2004]:

31 of 55

Final Hurdle: Local Computation of New Deltas

We propose to compute the deltas in two steps:

  1. Use a shrinkage estimator of the covariance [Ledoit & Wolf, 2004], which can be written in the form of recursive rank-1 updates:

Estimator based on previous (l - 1) samples

Rank-1 update based on the l-th sample
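The estimator's exact expression is not recoverable from the export; a standard shrinkage form with a Welford-style rank-1 recursion, consistent with the two labels above (an assumption), is:

```latex
% Assumed shrinkage estimator and online (Welford-style) rank-1 updates of the
% sample mean and covariance after l samples theta_1, ..., theta_l.
\widehat{\Sigma} \;=\; \rho\, I + (1 - \rho)\, \widehat{S}_L ,
\qquad
\widehat{\mu}_l \;=\; \widehat{\mu}_{l-1} + \tfrac{1}{l}\,(\theta_l - \widehat{\mu}_{l-1}),
\qquad
\widehat{S}_l \;=\; \tfrac{1}{l-1}\Big[(l-2)\,\widehat{S}_{l-1}
  + (\theta_l - \widehat{\mu}_{l-1})(\theta_l - \widehat{\mu}_l)^{\top}\Big].
```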

32 of 55

Final Hurdle: Local Computation of New Deltas

We propose to compute the deltas in two steps:

  1. Use a shrinkage estimator of the covariance [Ledoit & Wolf, 2004].
  2. Compute the deltas exactly and online using dynamic programming (20 lines of code!) as we keep sampling from the local posterior using SG-MCMC.

The per-sample update requires a constant number of vector-vector multiplies, i.e., O(d) compute and O(d) memory on the client.
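To make the delta computation concrete, here is a minimal sketch that computes Δ = Σ̂⁻¹(θ − μ̂) for the shrinkage estimate (the delta form reconstructed earlier). For simplicity it applies the Woodbury identity to the whole batch of samples rather than the paper's streaming dynamic program, so the cost is O(d·L²) for L samples; names and defaults are assumptions:

```python
import numpy as np

def fedpa_delta(theta_server, samples, rho=0.1):
    """Delta = Sigma_hat^{-1} (theta_server - mu_hat), with the shrinkage estimate
    Sigma_hat = rho * I + (1 - rho) * S built from approximate posterior samples
    (rows of `samples`). Uses the Woodbury identity to avoid forming any d x d matrix."""
    samples = np.asarray(samples)                  # shape (L, d)
    L, _ = samples.shape
    mu_hat = samples.mean(axis=0)                  # local posterior mean estimate
    v = theta_server - mu_hat
    if L < 2 or rho >= 1.0:
        return v / max(rho, 1e-12)                 # identity-covariance fallback (FedAvg-like delta)
    B = (samples - mu_hat).T / np.sqrt(L - 1)      # low-rank factor: S = B @ B.T
    U = np.sqrt(1.0 - rho) * B                     # Sigma_hat = rho * I + U @ U.T
    small = rho * np.eye(L) + U.T @ U              # (L, L) system instead of (d, d)
    return (v - U @ np.linalg.solve(small, U.T @ v)) / rho
```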

33 of 55

Federated Posterior Averaging

If we can infer the posterior mode, we have solved the problem!

Challenges:

(✓) how to infer local posteriors efficiently?

(✓) how to do aggregation on the server efficiently?

(✓) how to communicate them to the server efficiently?

local posterior covariances

local posterior means

Key idea:

  • Estimate moments of the local posteriors on the clients.
  • Communicate this to the server that will infer the global posterior mean.

34 of 55

Federated Posterior Averaging (FedPA): The Algorithm

On the server:

  1. Distribute the initial state to clients
  2. Collect & average deltas from clients
  3. Take a gradient step:

On the clients:

  1. Run SGD-based MCMC
  2. As new samples arrive, keep computing deltas
  3. Send the final deltas to the server

Similar to FedOpt, we run SGD on the clients, but compute deltas differently.

Identical to (adaptive) FedOpt! [Reddi*, Charles*, et al., ICLR 2021]

Note: our new perspective suggests that known federated optimization algorithms are doing posterior inference under the Laplace approximation and estimating local covariances using the identity matrix
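Putting the pieces together, a minimal sketch of one FedPA round, reusing `sample_posterior_sgd` and `fedpa_delta` from the earlier sketches (`client.batches` and `client.loss_grad` are assumed accessors; this is not the authors' implementation):

```python
import numpy as np

def fedpa_round(theta_server, cohort, server_lr=1.0, rho=0.1):
    """One FedPA round: clients sample their local posteriors and return deltas;
    the server averages them and takes a gradient step (same rule as FedAvg/FedOpt)."""
    deltas = []
    for client in cohort:                                              # sampled cohort for this round
        samples = sample_posterior_sgd(theta_server, client.batches, client.loss_grad)
        deltas.append(fedpa_delta(theta_server, samples, rho=rho))     # ~ Sigma^{-1}(theta - mu)
    return theta_server - server_lr * np.mean(deltas, axis=0)          # server update step
```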

35 of 55

Back to our Toy Example

the “noise barrier”

36 of 55

FedAvg vs. FedPA: Bias and Variance of the Server Updates

The plots are for 10D federated least squares regression on multiple synthetically generated datasets.

FedAvg: bias is reduced by reducing the amount of local computation

FedPA: bias is reduced by increasing the amount of local computation

37 of 55

FedAvg vs. FedPA: Federated CIFAR100

  • Task: 100 class image classification, 500 clients (model: ResNet-18).
  • We “burn-in” FedPA by running it in the FedAvg regime for 400 rounds.
  • Starting round 400, we switch to FedPA computation of client deltas.

stands for multi-epoch

38 of 55

FedAvg vs. FedPA: Federated StackOverflow LR

  • Task: 500 class multi-label classification, bag of words features, 300K+ clients.
  • We “burn-in” FedPA by running it in the FedAvg regime for 800 rounds.
  • Starting round 800, we switch to FedPA computation of client deltas.

stands for multi-epoch

39 of 55

Concluding Thoughts for Part I

  • Federated learning can be approached as a probabilistic inference problem, which allows us to design new efficient FL algorithms + re-interpret well-known FedAvg
  • Bayesian ML/DL is typically used for quantification of predictive uncertainty. Turns out, it is also quite useful in distributed, communication-limited settings.

40 of 55

Part II:

On Data Efficiency of Meta-learning

(for personalized FL)

41 of 55

The goal of FL is not always to learn a global model

A simple way to personalize models:

  • Learn a model using FL (e.g., FedAvg)
  • At test time, fine-tune the model on the available labeled data

Bottom line: models obtained by FedAvg are often worse than purely local training for the majority of the clients ⇒ need model personalization

word prediction

image classification

42 of 55

The goal of FL is not always to learn a global model

Bottom line: models obtained by FedAvg are often worse than purely local training for the majority of the clients ⇒ need model personalization

word prediction

image classification

A simple way to personalize models:

  • Learn a model using FL (e.g., FedAvg)
  • At test time, fine-tune the model on the available labeled data

In the few-shot learning literature, this method is known as “Reptile:”

43 of 55

Background: Model-agnostic Meta-learning

  • Idea: learn a good initialization for stochastic gradient descent (SGD)

Finn et al., ICML 2017

44 of 55

Background: Model-agnostic Meta-learning

  • Idea: learn a good initialization for stochastic gradient descent (SGD)
  • Given a new task, produce a model using a gradient step (see the sketch below)
  • Meta-training is learning an initialization:
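The inline symbols on this slide were lost in the export; the standard MAML formulation they refer to (Finn et al., 2017) is:

```latex
% One-step MAML: adapt the initialization theta to task tau, then meta-learn theta.
\phi_{\tau} \;=\; \theta - \alpha\, \nabla_{\theta}\, \mathcal{L}_{\tau}(\theta)
\qquad\text{(adaptation)},
\qquad
\min_{\theta}\; \mathbb{E}_{\tau}\big[\mathcal{L}_{\tau}(\phi_{\tau})\big]
\qquad\text{(meta-training)}.
```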

Intuition:

Finn et al., ICML 2017

45 of 55

The Question of Interest

Why?

  • It matters in many practical settings (e.g., personalized federated learning, where the number of clients, the amount of data, and compute are all limited)

How can we characterize the data efficiency of modern meta-learning algorithms?

How?

  1. Theoretically: we derive generalization bounds for modern meta-learning
  2. Empirically: we confirm the predictions of our theory through experiments

46 of 55

Theoretical Analysis: Objective Functions

1. Empirical estimator of the transfer risk: task data is NOT split into subsets. Example methods: FedAvg/Reptile.

2. Hold-out estimator of the transfer risk: task data is split into a support set (used in the inner loop) and a query set (used in the outer loop). Example methods: MAML, ProtoNets.
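For concreteness, generic forms of the two estimators (an assumption, not the paper's exact notation; A(θ, S) denotes the inner-loop adaptation of initialization θ on data S, and L(φ, S) the loss of model φ on S):

```latex
% Assumed generic forms of the two transfer-risk estimators over T training tasks.
\widehat{R}_{\mathrm{emp}}(\theta) \;=\; \frac{1}{T}\sum_{t=1}^{T}
\mathcal{L}\big(\mathcal{A}(\theta, D_t),\, D_t\big),
\qquad
\widehat{R}_{\mathrm{ho}}(\theta) \;=\; \frac{1}{T}\sum_{t=1}^{T}
\mathcal{L}\big(\mathcal{A}(\theta, D_t^{\mathrm{s}}),\, D_t^{\mathrm{q}}\big).
```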

47 of 55

Theoretical Analysis: Meta-generalization Bounds

  1. The following bound holds for methods that optimize the empirical estimator (FedAvg/Reptile):

48 of 55

Theoretical Analysis: Meta-generalization Bounds

  • The following bound holds for methods that optimize the empirical estimator (FedAvg/Reptile):
  • The following bound holds for methods that optimize the hold-out estimator (MAML, ProtoNets):

49 of 55

Practical Implications

O1: As the # of tasks goes to infinity, the generalization error of hold-out-estimator methods goes to zero.

MAML vs. ProtoNets must be indistinguishable when trained on many tasks.

Performance of MAML and ProtoNets as a function of the # of training tasks:

  • The methods are identical when trained on a very large number of tasks (i.e., classical benchmarks).
  • Clear difference between the two in the regime of limited supervision.

50 of 55

Practical Implications

O1: As the # of tasks goes to infinity, the generalization error of hold-out-estimator methods goes to zero.

MAML vs. ProtoNets must be indistinguishable when trained on many tasks.

O2: FedAvg/Reptile (and other empirical-estimator methods) have an additive term in their bound that depends on the # of samples per task ⇒ worse performance unless a sufficient number of samples per task is provided at meta-training time.

51 of 55

Practical Implications

O1: As the # of tasks goes to infinity, the generalization error of hold-out-estimator methods goes to zero.

MAML vs. ProtoNets must be indistinguishable when trained on many tasks.

O2: FedAvg/Reptile (and other empirical-estimator methods) have an additive term in their bound that depends on the # of samples per task ⇒ worse performance unless a sufficient number of samples per task is provided at meta-training time.

O3: Meta-generalization depends on the algorithmic stability constants of the inner and outer loops ⇒ actively selecting labeled data in the inner loop may improve meta-generalization.

52 of 55

Practical Implications

53 of 55

To Learn More

  • Federated Learning via Posterior Averaging: A New Perspective and Practical Algorithms. With Jenny Gillenwater, Afshin Rostamizadeh, Eric Xing. To appear at ICLR 2021.
  • On Data Efficiency of Meta-learning. With Liam Li, Ameet Talwalkar, Eric Xing. To appear at AISTATS 2021.

Thank you!

Questions?

54 of 55

(blank)

55 of 55

Additional Results for Part I