1 of 46

SCAFFOLD: Stochastic Controlled Averaging for Federated Learning

Sai Praneeth Karimireddy, Satyen Kale, Mehryar Mohri, Sashank Reddi, Sebastian Stich, Ananda Theertha Suresh

2 of 46

Tight analysis of FedAvg when clients are heterogeneous (non-iid data)

Explain degradation of FedAvg via the ‘drift’ in the client updates

Prove SCAFFOLD is resilient to heterogeneity and client sampling


3 of 46

Federated Learning: Setting


[Figure: a central server (e.g. Google) holds the model; many clients (e.g. hospitals, phones) each hold their own local data.]

[McMahan et al. 2016]

4 of 46

Federated Learning: Setting


In each round,

  • A subset of the clients is chosen


[McMahan et al. 2016]

5 of 46

Federated Learning: Setting


In each round,

  • A subset of the clients is chosen

  • A copy of the server model is sent to each client

[Figure: the server model x is copied to each sampled client as a local model y.]

[McMahan et al. 2016]

6 of 46

Federated Learning: Setting


In each round,

  • A subset of the clients is chosen

  • A copy of the server model is sent to each client

  • Each client updates its copy of the model using its local data


[McMahan et al. 2016]

7 of 46

Federated Learning: Setting


In each round,

  • A subset of the clients is chosen

  • A copy of the server model is sent to each client

  • Each client updates its copy of the model using its local data

  • The client updates are aggregated

  • The server model is updated


[McMahan et al. 2016]

8 of 46

Federated Learning: Characteristics



  • High overhead per round

  • Only a few clients participate in each round

  • The data of the clients is heterogeneous

  • Small/medium number of total clients (cross-silo)


9 of 46

Federated Learning: Formalism


Model parameters

(weighted) average over clients

Expectation over client data

Loss function wrt parameters and client data
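Written out, the objective that these labels annotate is the standard federated objective (a reconstruction from the labels above; the weights w_i and client data distributions D_i are notation introduced here for readability):

```latex
\min_{x}\; f(x) \;:=\; \sum_{i=1}^{N} w_i\, f_i(x),
\qquad
f_i(x) \;:=\; \mathbb{E}_{\zeta_i \sim \mathcal{D}_i}\big[\, \ell(x;\, \zeta_i) \,\big]
% x: model parameters
% \sum_i w_i (\cdot): (weighted) average over clients
% \mathbb{E}_{\zeta_i \sim \mathcal{D}_i}: expectation over client i's data
% \ell: loss as a function of the parameters and a data sample
```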

10 of 46

Algorithms for Federated Learning


11 of 46

Solving FL: SGD

  • Equivalent to synchronous centralized large-batch training

  • Very slow (only 1 update per round)


  • On each sampled client i, compute a large-batch stochastic gradient; the server averages them across clients

Mini-batch gradient with batch size K per client, averaged over all sampled clients (see the sketch below)
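A minimal runnable sketch of this baseline on a toy least-squares problem; the data layout, step size, and helper names are assumptions made for illustration, not the authors' code.

```python
import numpy as np

rng = np.random.default_rng(0)

def minibatch_grad(x, A, b, K):
    """Mini-batch least-squares gradient with batch size K on one client's data (A, b)."""
    idx = rng.choice(len(b), size=min(K, len(b)), replace=False)
    return A[idx].T @ (A[idx] @ x - b[idx]) / len(idx)

def sgd_round(x, clients, sampled, lr, K):
    """One round of 'FL as large-batch SGD': average the sampled clients' mini-batch
    gradients at the current server model x, then take a SINGLE server step."""
    g = np.mean([minibatch_grad(x, A, b, K) for A, b in (clients[i] for i in sampled)], axis=0)
    return x - lr * g   # only one model update per communication round

# Toy usage: 10 clients with (heterogeneous) linear-regression data.
clients = [(rng.normal(size=(50, 5)), rng.normal(size=50) + i) for i in range(10)]
x = np.zeros(5)
for _ in range(100):
    x = sgd_round(x, clients, sampled=range(10), lr=0.1, K=10)
```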

12 of 46

Solving FL: FedAvg

  • Potentially faster (performs K updates)

  • Different from centralized updates
  • May not converge


  • On each sampled client i, perform K steps of SGD (see the sketch below)

Repeat K times

Average over all sampled clients

  • Server model is a (weighted) average of client models

[McMahan et al. 2016]
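For contrast, a sketch of one FedAvg round under the same toy least-squares setup as the SGD sketch above (again, the per-sample gradient and step size are illustrative assumptions, not the authors' code):

```python
import numpy as np

rng = np.random.default_rng(0)

def fedavg_local(x_server, A, b, lr, K):
    """FedAvg client: start from the server model and take K local SGD steps on own data."""
    y = x_server.copy()
    for _ in range(K):
        j = rng.integers(len(b))               # pick one local sample
        y -= lr * A[j] * (A[j] @ y - b[j])     # stochastic gradient step
    return y

def fedavg_round(x, clients, sampled, lr, K):
    """Server: broadcast x, collect the locally updated models, average them."""
    local_models = [fedavg_local(x, A, b, lr, K) for A, b in (clients[i] for i in sampled)]
    return np.mean(local_models, axis=0)       # (weighted) average of client models
```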

13 of 46

Solving FL: SCAFFOLD

  • Potentially faster (performs K updates)

  • Mimics centralized updates!


Correction term!

  • On each sampled client i, perform K steps of corrected SGD (see the sketch below)

Repeat K times

  • Average as before

New!
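A sketch of the corrected client update in the same toy setup; c_i is client i's control variate and c the server's, maintained as described on the later slides (the variable names are my own, not the authors' code):

```python
def scaffold_local(x_server, A, b, c_i, c, lr, K, rng):
    """SCAFFOLD client: K local SGD steps where the correction (c - c_i) is added to
    every stochastic gradient, so the local trajectory mimics a centralized update
    instead of drifting towards the client's own optimum."""
    y = x_server.copy()
    for _ in range(K):
        j = rng.integers(len(b))
        g = A[j] * (A[j] @ y - b[j])   # client stochastic gradient
        y -= lr * (g - c_i + c)        # corrected step
    return y
```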

14 of 46

When does FedAvg fail?


15 of 46

FedAvg degrades with heterogeneous clients. Why?

15

16 of 46

Client updates: SGD vs. FedAvg


[Figure: surfaces of two client loss functions and the combined objective, with the optimum marked.]

17 of 46

Client updates: SGD


The limit point of SGD is the optimum.

18 of 46

Client updates: FedAvg


19 of 46

Client updates: FedAvg


  • The updates move away even if we start at the optimum.
  • FedAvg does not converge to the optimum!
  • A small learning rate is required to get close

20 of 46

Drift in client updates: SGD vs. FedAvg


21 of 46

Convergence Rates: SGD

  • For strongly convex functions
  • For non-convex functions


Notation:

  • R communication rounds
  • K local steps
  • Total N clients

  • L-smooth, μ-strongly convex
  • σ - variance within a client

“the noise is in the noise and SGD don't care” [Chaturapruek et al. 2015]

22 of 46

Convergence Rates: SGD vs. FedAvg

  • For strongly convex functions


Assume: (B,G)-similar gradients

Generalizes [Li et al. 2019] and [Khaled et al. 2019]

23 of 46

Convergence Rates: SGD vs. FedAvg

  • For strongly convex functions
  • For non-convex functions


FedAvg

SGD

Assume: (B,G)-similar gradients

Tightest known rates; uses both server and client step-sizes

24 of 46

Lower bound: FedAvg


Assume: (B,G)-similar gradients

Necessary!

Theorem: For any G, we can find functions with (2, G)-similar gradients such that FedAvg with K > 1 and arbitrary step-sizes always has an error lower-bounded by a term depending on G, i.e. the heterogeneity term cannot be avoided.

25 of 46

Quick demo: SGD vs. FedAvg

  • Linear regression
  • Concrete dataset (UCI)
  • 10 clients (no sampling)
  • K = 10 local steps

> FedAvg needs a smaller learning rate

> Slower than SGD


[Plot: loss vs. communication rounds; lower is better.]

26 of 46

SCAFFOLD: stochastic controlled averaging


27 of 46

Main Idea: Use control variates


  • Guess the direction of each client's update (control variate c_i)
  • Guess the direction of the server update (control variate c)
  • Use the correction (c − c_i) in every local step

28 of 46

Main Idea: Corrected updates


Correction terms

Mimics centralized updates!

29 of 46

Main Idea: Updating control variates



30 of 46

SCAFFOLD: Algorithm

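The algorithm box did not survive extraction; below is a compact sketch of one full SCAFFOLD round consistent with the description in these slides, using the cheap control-variate refresh c_i ← c_i − c + (x − y_i)/(Kη). The global step size lr_g and the toy data layout are assumptions made for illustration.

```python
import numpy as np

def scaffold_round(x, c, c_list, clients, sampled, lr_l, lr_g, K, rng):
    """One SCAFFOLD round on toy least-squares clients: corrected local SGD on each
    sampled client, control-variate refresh, then server aggregation of the deltas."""
    dy, dc = [], []
    for i in sampled:
        A, b = clients[i]
        y = x.copy()
        for _ in range(K):                            # K corrected local steps
            j = rng.integers(len(b))
            g = A[j] * (A[j] @ y - b[j])
            y -= lr_l * (g - c_list[i] + c)
        c_new = c_list[i] - c + (x - y) / (K * lr_l)  # cheap control-variate update
        dy.append(y - x)
        dc.append(c_new - c_list[i])
        c_list[i] = c_new                             # client keeps its state across rounds
    x = x + lr_g * np.mean(dy, axis=0)                # server model update
    c = c + (len(sampled) / len(clients)) * np.mean(dc, axis=0)  # server control variate
    return x, c

# Usage with the 10 toy clients from the earlier sketches:
#   x, c = np.zeros(5), np.zeros(5); c_list = [np.zeros(5) for _ in clients]
#   for _ in range(100):
#       x, c = scaffold_round(x, c, c_list, clients, range(10), 0.1, 1.0, 10, rng)
```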

31 of 46

SCAFFOLD: Quick demo

  • Linear regression
  • Concrete dataset (UCI)
  • 10 clients (no sampling)
  • K = 10 local steps

> SCAFFOLD works with the same learning rate as SGD

> Faster than SGD!


[Plot: loss vs. communication rounds; lower is better.]

32 of 46

SCAFFOLD: Client sampling

  • The updates of every client mimic centralized updates.
  • Sampling only a few clients works, as long as the control variates are accurate.
  • Hence, SCAFFOLD is very robust to client sampling.

Different view: SAGA is a special case of SCAFFOLD with client sampling


33 of 46

SCAFFOLD: Variance-reduced convergence rates

  • For strongly convex functions

  • For non-convex functions


> Better than FedAvg

> Variance-reduced rates!

Notation:

  • R communication rounds
  • S out of N clients sampled

  • L-smooth, μ-strongly convex

34 of 46

SCAFFOLD: Why take more than 1 step?

  • Each update mimics a centralized update => local steps should help.

  • In the worst case this is not true [Arjevani & Shamir, 2015] :(

  • But improvement is possible if the clients have similar Hessians!


35 of 46

SCAFFOLD: Why take more than 1 step?


𝛿-BHD (Bounded Hessian Dissimilarity): the Hessian of each client loss f_i stays within 𝛿 of the Hessian of the average loss f, and each f_i is 𝛿-weakly convex.

Note that 𝛿 ≤ 2L always suffices, since each f_i is L-smooth (formal statement below).
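One way to state the assumption formally (a sketch; the exact constants and the gradient-based variant used in the paper may differ slightly):

```latex
% \delta-Bounded Hessian Dissimilarity (BHD), for all x and all clients i:
\|\nabla^2 f_i(x) - \nabla^2 f(x)\| \;\le\; \delta,
\qquad
\nabla^2 f_i(x) \;\succeq\; -\delta I \quad (\text{i.e. } f_i \text{ is } \delta\text{-weakly convex}).
% Since each f_i is L-smooth, \|\nabla^2 f_i\| \le L and \|\nabla^2 f\| \le L,
% so \delta = 2L always satisfies the first condition.
```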

36 of 46

SCAFFOLD: Why take more than 1 step?

Assume: 𝛿-BHD, quadratics

  • For strongly convex functions

  • For non-convex functions


Notation:

  • R communication rounds
  • All N clients participate
  • K local steps

  • L-smooth, μ-strongly convex

37 of 46

SCAFFOLD: Why take more than 1 step?

Assume: 𝛿-BHD

  • For strongly convex functions

  • For non-convex functions


> Best to take

> We replaced L with 𝛿 in the rates (typically 𝛿 << L)

> First rate to characterize improvement due to local steps!

38 of 46

SCAFFOLD: Why take more than 1 step?


Quick demo on scalar quadratics

  • SCAFFOLD is unaffected by G

  • Larger K is better

  • K=2 is 2 times faster

  • K=10 is only 4 times better

39 of 46

Experiments


40 of 46

Experimental Setup

  • Extended MNIST (balanced) dataset
  • Multi-class logistic regression (47 classes)
  • Partitioned into N clients
  • Sorted by labels and then ‘slightly shuffled’ before splitting (see the sketch below)
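A rough sketch of one way to build such a split; the shuffle fraction, helper name, and seed are illustrative assumptions (the slide only says the data is sorted by label and then slightly shuffled):

```python
import numpy as np

def partition_non_iid(labels, num_clients, shuffle_frac=0.1, seed=0):
    """Sort example indices by label, lightly shuffle a fraction of positions,
    then split into contiguous shards -- one shard of indices per client.
    shuffle_frac = 0 gives fully label-sorted (maximally dissimilar) clients."""
    rng = np.random.default_rng(seed)
    order = np.argsort(labels)                                            # sort by label
    swap = rng.choice(len(order), size=int(shuffle_frac * len(order)), replace=False)
    order[swap] = rng.permutation(order[swap])                            # 'slightly shuffled'
    return np.array_split(order, num_clients)

# e.g. client_indices = partition_non_iid(train_labels, num_clients=400)
```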


41 of 46

Performance of SCAFFOLD


[Plot: performance vs. communication rounds. Similarity = 0, 1 epoch per round, 20 sampled clients out of 400 total.]

42 of 46

Effect of similarity


[Plot: test accuracy vs. communication rounds. 10 epochs per round, 20 sampled clients out of 100 total.]

43 of 46

Effect of number of clients


SCAFFOLD with 5 clients is better than FedAvg with 50!

[Plot: accuracy vs. communication rounds. Total #clients = 400, total #categories = 47, 1 epoch per round, similarity = 0.]

44 of 46

Takeaways

  • Degradation of FedAvg is due to the client drift. If you use FedAvg, use separate server and client step-sizes.
  • Why you should use SCAFFOLD:
    • Parameter-free
    • Provably converges faster than SGD and FedAvg
    • Resilient to heterogeneity and client sampling
  • Main limitation: requires maintaining client state


45 of 46

Some new work

  • Mime: Mimicking Centralized Stochastic Algorithms in Federated Learning. Cross-device setting? Momentum? Adam? Look out for it on arXiv.

  • Byzantine-Robust Learning on Heterogeneous Datasets via Resampling. Robustness for non-iid data? Look out for it on arXiv.

  • Secure Byzantine-Robust Machine Learning. Lie He, SPK, Martin Jaggi. Combining privacy and robustness in federated learning. [arXiv 2006.04747]


46 of 46

Thank You.

Questions?
