1 of 55

GNN For EF Tracking

Santosh Parajuli

University of Illinois Urbana-Champaign

18 January 2024


4 of 55


Quantization Aware Training for GNN Part

  • Quantization Aware Training (QAT) with Brevitas involves training a neural network with the awareness that the model will be quantized (reducing precision of weights and activations) during inference.

  • The main goal of quantization is to enable more efficient deployment of the model on hardware platforms with lower computational precision capabilities.

5 of 55


Preliminary Results

  • Quick test with hidden dimension 48 (with a hard cut pT > 1 GeV).

[Plots: without QAT vs. with QAT (bit width = 8)]

6 of 55


Preliminary Results

  • Quick test with hidden dimension 48 (with a hard cut pT > 1 GeV).

[Plots: without QAT vs. with QAT]

7 of 55


Preliminary Results

  • Quick test with hidden dimension 48 (with a hard cut pT > 1 GeV).

[Plots: without QAT vs. with QAT]

8 of 55


Preliminary Results

  • Quick test with hidden dimension 48 (with a hard cut pT > 1 GeV).

[Plots: without QAT vs. with QAT]

9 of 55


Next

  • Explore more places to apply the Quant functions (output hidden dimension, layer_norm, ...)
  • Train without any hard pT cut on the full 9000 ttbar events
  • Pruning

10 of 55

Hidden Dimension Optimization

Santosh Parajuli

University of Illinois Urbana-Champaign

5 February 2024

11 of 55


Quantization Aware Training for GNN Part

  • Quantization Aware Training (QAT) with Brevitas involves training a neural network with the awareness that the model will be quantized (reducing precision of weights and activations) during inference.

  • The main goal of quantization is to enable more efficient deployment of the model on hardware platforms with lower computational precision capabilities.

ReLU→QuantReLU

Linear→QuantLinear (a minimal Brevitas sketch follows below)
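A minimal sketch of what this module swap looks like with Brevitas for one MLP block. The hidden dimension (48), bit width (6), and layer structure are illustrative values taken from later slides, not the exact InteractionGNN2 configuration.

```python
# Minimal sketch (not the actual acorn/InteractionGNN2 code): one MLP block with
# the nn.Linear -> QuantLinear and nn.ReLU -> QuantReLU swap applied via Brevitas.
import torch.nn as nn
from brevitas.nn import QuantLinear, QuantReLU

BIT_WIDTH = 6  # weight/activation precision explored in the later studies

def make_quant_mlp(in_dim: int, hidden_dim: int = 48, out_dim: int = 48) -> nn.Sequential:
    return nn.Sequential(
        # weights fake-quantized to BIT_WIDTH bits during training
        QuantLinear(in_dim, hidden_dim, bias=True, weight_bit_width=BIT_WIDTH),
        nn.LayerNorm(hidden_dim),
        # activations fake-quantized to BIT_WIDTH bits
        QuantReLU(bit_width=BIT_WIDTH),
        QuantLinear(hidden_dim, out_dim, bias=True, weight_bit_width=BIT_WIDTH),
    )
```

Training then proceeds as for the float model; Brevitas inserts fake-quantization nodes so the network learns weights that remain accurate at the reduced precision.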

12 of 55

Number of Hidden Dimensions | Number of Parameters | GNN Edgewise Efficiency (%) | Total Purity (%)
16 | 32K | 83.86 | 72.4
32 | 125K | 93.61 | 80.4
48 | 280K | 94.6 | 82.9
64 | 496K | 95.32 | 83.0
80 | 774K | 95.15 | 83.0
96 | 1.1M | 95.54 | 83.4
128 | 2M | 95.48 | 85.74
128 (CTD checkpoint) | 2M | 99.0 (trained on 7800 events) | 93.8 (trained on 7800 events)

  • Train/val/test: 80/10/10 ttbar events (in each case)
  • InteractionGNN2 with no hard cuts

Choosing a hidden dimension of 48 to start with: a trade-off between efficiency and model size.

13 of 55

Number of Hidden Dimensions | Number of Message-Passing Iterations | Number of Parameters | GNN Edgewise Efficiency (%) | Total Purity (%)
48 (batchnorm) | 8 | 283K | 88.9 | 76.18
48 (layernorm) | 8 | 280K | 94.6 | 82.9
48 (layernorm) | 4 | 150K | 92.78 | 80.2
48 (layernorm) | 2 | 84.9K | 82.8 | 66.1
64 (layernorm) | 8 | 496K | 95.32 | 83.0
64 (layernorm) | 4 | 265K | 92.7 | 80.2
64 (layernorm) | 2 | 150K | 83.82 | 68.6

  • Train/val/test: 80/10/10 ttbar events (in each case)
  • InteractionGNN2 with no hard cuts

Choosing a hidden dimension of 48 with 8 graph iterations: a trade-off between efficiency and model size.

14 of 55

Number of Hidden Dimensions | Recurrent Edge/Node Network | Number of Parameters | GNN Edgewise Efficiency (%) | Total Purity (%)
48 | False | 280K | 94.6 | 82.9
48 | True | 52.4K | 92.53 | 76.78
64 | False | 496K | 95.32 | 83.0
64 | True | 92.4K | 93.83 | 80.2
96 | False | 1.1M | 95.54 | 83.4
96 | True | 206K | 95.98 | 83.3
128 | False | 2M | 95.48 | 85.74
128 | True | 364K | 96.31 | 84.65

  • Train/val/test: 80/10/10 ttbar events (in each case)
  • InteractionGNN2 with no hard cuts

Recurrent Edge/Node Network:

  • Edge/node network parameters are shared across the different iterations of message passing.
  • The recurrent edge/node network lets the model iteratively refine edge/node representations, taking information from previous iterations into account.
  • Sharing parameters across iterations removes redundancy and reduces the overall parameter count (see the sketch below).
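A minimal sketch of the parameter-count effect of this sharing, assuming simple two-layer MLPs for the edge/node networks (illustrative sizes, not the real InteractionGNN2 modules):

```python
# Shared vs. per-iteration edge networks: the recurrent variant's parameter
# count does not grow with the number of message-passing iterations.
import torch.nn as nn

def make_mlp(hidden_dim: int) -> nn.Sequential:
    return nn.Sequential(nn.Linear(2 * hidden_dim, hidden_dim),
                         nn.LayerNorm(hidden_dim),
                         nn.ReLU())

hidden_dim, n_iterations = 48, 8

# Non-recurrent: a separate edge network for each message-passing iteration
per_iteration_nets = nn.ModuleList([make_mlp(hidden_dim) for _ in range(n_iterations)])

# Recurrent: a single edge network reused at every iteration
shared_net = make_mlp(hidden_dim)

count = lambda m: sum(p.numel() for p in m.parameters())
print(count(per_iteration_nets), count(shared_net))  # 8x fewer parameters when shared
```

The reduction in the table (e.g. 280K → 52.4K at 48 hidden dimensions) is less than 8×, presumably because the input encoders and the output network are not shared.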

15 of 55

QAT with Brevitas

Bit Width | Number of Parameters | GNN Edgewise Efficiency (%) | Total Purity (%)
PyTorch original (unquantized) | 280K | 94.6 | 82.9
8 (no bias quantization) | 277K | 93.86 | 81.75
6 | 277K | 93.53 | 81.44

  • 48 hidden dimensions; 8 graph iterations

16 of 55

QAT with Brevitas


[Plots: PyTorch original vs. 8-bit vs. 6-bit models]

17 of 55

QAT with Brevitas


Choosing a hidden dimension of 48, 8 graph iterations, and a weight_bit_width of 8 to start with.

[Plots: PyTorch original vs. 8-bit vs. 6-bit models]

18 of 55


Pruning in GNN Part

Recurrent Edge Network:

Edge network's parameters are shared across different iterations of message passing or graph updates.

The recurrent edge network allows the model to iteratively refine edge representations, taking into account information from previous iterations.

19 of 55

Quantization Aware Training with Brevitas

Santosh Parajuli

University of Illinois Urbana-Champaign

5 March 2024

20 of 55


Quantization Aware Training for GNN Part

  • Quantization Aware Training (QAT) with Brevitas involves training a neural network with the awareness that the model will be quantized (reducing precision of weights and activations) during inference.

  • The main goal of quantization is to enable more efficient deployment of the model on hardware platforms with lower computational precision capabilities.

ReLU→QuantReLU

Linear→QuantLinear

21 of 55

QAT with Brevitas

Bit Width | GNN Edgewise Efficiency (%) | Total Purity (%)
PyTorch original (unquantized) | 94.44 | 82.16
12 | 94.65 | 84.29
8 | 94.89 | 82.21
6 | 94.63 | 82.84
4 | 93.71 | 81.99
2 | 80.09 | 70.83

  • 48 hidden dimensions; 8 graph iterations; 280K parameters.

For reference, the 128-hidden-dimension (2M parameter) model:

Model | GNN Edgewise Efficiency (%) | Total Purity (%)
128 (CTD checkpoint), trained on 7800 events | 99.0 | 93.8
128, trained on 100 events | 95.48 | 85.74

!!! Only 80 events for training; 10 test, 10 val.

22 of 55

QAT with Brevitas


  • 280K parameters; 48 hidden dimensions; 8 graph iterations; 6-bit width — this is the model to be pruned

23 of 55


Pruning in GNN Part

  • L1 unstructured pruning eliminates less important weights by setting them to zero based on their magnitude.

  • The least important weights are identified by their L1 norm, which for individual (unstructured) weights is simply the absolute value of each weight in a layer (see the sketch below).
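A minimal sketch of this with torch.nn.utils.prune, applied to the Linear layers of a hypothetical `model` (the 20% amount is illustrative; Brevitas QuantLinear layers subclass nn.Linear, so the same check should catch them):

```python
# L1 unstructured pruning: zero out the smallest-magnitude weights layer by layer.
import torch.nn as nn
import torch.nn.utils.prune as prune

def prune_linear_layers(model: nn.Module, amount: float = 0.2) -> None:
    for module in model.modules():
        if isinstance(module, nn.Linear):
            # Zero out the `amount` fraction of this layer's weights with the
            # smallest absolute value; a mask keeps them at zero during training.
            prune.l1_unstructured(module, name="weight", amount=amount)

def sparsity(model: nn.Module) -> float:
    zeros = total = 0
    for module in model.modules():
        if isinstance(module, nn.Linear):
            zeros += int((module.weight == 0).sum())
            total += module.weight.numel()
    return zeros / max(total, 1)
```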

24 of 55

Iterative Pruning


  • 48 hidden dimensions; 8 graph iterations; 6-bit width
  • Validation loss (val_loss) used as the metric for iterative pruning (a loop sketch follows after the table)

Pruning in progress:

Pruning Percentage (in hidden layers) | Unpruned Parameters | GNN Edgewise Efficiency (%) | Total Purity (%)
0.0 | 280K | 94.6 | 82.8
20 | 224K | 94.8 | 82.3
36 | 179K | 95.0 | 82.1
48 | 143K | 94.5 | 83.5
59 | 114K | 94.6 | 83.2
67 | 91K | 94.7 | 82.4
73 | 73K | 94.2 | 83.6
79 | 58K | 94.6 | 82.5
83 | 46K | 94.3 | 82.8
86 | 37K | 91.1 | 77.7
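The pruning percentages above are consistent with removing roughly 20% of the remaining weights at each round. A minimal sketch of such a prune-and-fine-tune loop driven by validation loss is below; `prune_linear_layers` is sketched earlier, while `train_one_round`, `validate`, and the 5% tolerance are hypothetical and purely illustrative.

```python
# Iterative pruning loop: prune, fine-tune, and stop once validation loss degrades.
def iterative_prune(model, train_loader, val_loader,
                    amount_per_round: float = 0.2, max_rounds: int = 10,
                    tolerance: float = 0.05):
    best_val = validate(model, val_loader)
    for round_idx in range(max_rounds):
        prune_linear_layers(model, amount=amount_per_round)  # prune 20% of the remaining weights
        train_one_round(model, train_loader)                 # fine-tune with the masks applied
        val_loss = validate(model, val_loader)
        print(f"round {round_idx}: val_loss = {val_loss:.4f}")
        if val_loss > best_val * (1.0 + tolerance):          # stop once val loss degrades
            break
        best_val = min(best_val, val_loss)
    return model
```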

25 of 55

Iterative Pruning


  • 48 hidden dimensions; 8 graph iterations; 6-bit width
  • Validation loss (val_loss) used as the metric for iterative pruning

!!! Only 80 events for training; 10 test, 10 val.

26 of 55

Quantization Aware Training with Brevitas

Santosh Parajuli

University of Illinois Urbana-Champaign

21 March 2024

27 of 55

Number of Hidden Dimensions | Number of Parameters | GNN Edgewise Efficiency (%) | Total Purity (%)
16 | 32K | 96.66 | 88.53
32 | 125K | 98.03 | 91.04
48 | 280K | 98.12 | 91.06
96 | 1.1M | 98.15 | 91.27
128 | 2M | 98.16 | 92.35
128 | 2M | 99.0 (trained on 7800 events) | 93.8 (trained on 7800 events)

  • Train/val/test: 800/100/100 ttbar events (in each case)
  • InteractionGNN2 with no hard cuts

Choosing a hidden dimension of 48 to start with: a trade-off between efficiency and model size.

For comparison, the same 128-hidden-dimension (2M parameter) model trained on only 100 events reached 95.48 efficiency and 85.74 purity.

28 of 55

Hidden Dimension Optimization


  • 280K parameters; 48 hidden dimensions; 8 graph iterations — this is the model to quantize

800 events for training. 100 test, 100 val

2023 ttbar : pileup 200

29 of 55


Quantization Aware Training for GNN Part

  • Quantization Aware Training (QAT) with Brevitas involves training a neural network with the awareness that the model will be quantized (reducing precision of weights and activations) during inference.

  • The main goal of quantization is to enable more efficient deployment of the model on hardware platforms with lower computational precision capabilities.

ReLU→QuantReLU

Linear→QuantLinear

30 of 55

QAT with Brevitas

Bit Width | GNN Edgewise Efficiency (%) | Total Purity (%)
PyTorch original (unquantized) | 98.20 | 91.06
12 | 97.87 | 91.02
8 | 97.99 | 91.46
6 | 98.02 | 90.95
4 | 97.15 | 89.66
2 | 90.00 | 71.60

  • 48 hidden dimensions; 8 graph iterations; 280K parameters.

For reference, the 128-hidden-dimension (2M parameter) model:

Model | GNN Edgewise Efficiency (%) | Total Purity (%)
128, trained on 7800 events | 99.0 | 93.8
128, trained on 800 events | 98.16 | 92.35

800 events for training; 100 test, 100 val.
2023 ttbar: pileup 200.

31 of 55

QAT with Brevitas


  • 280K parameters; 48 hidden dimensions; 8 graph iterations; 6-bit width (hidden activations, weights, and biases in linear layers) — this is the model to prune

800 events for training. 100 test, 100 val

2023 ttbar : pileup 200

32 of 55


Pruning in GNN Part

  • L1 unstructured pruning eliminates less important weights by setting them to zero based on their magnitude.

  • The least important weights are identified by their L1 norm, which for individual (unstructured) weights is simply the absolute value of each weight in a layer.

33 of 55

Iterative Pruning

Pruning Percentage (in hidden layers) | Unpruned Parameters | GNN Edgewise Efficiency (%) | Total Purity (%)
0.0 | 280K | 98.02 | 90.95
20 | 224K | 98.00 | 90.21
36 | 179K | 98.21 | 90.13
48 | 143K | 98.09 | 90.50
59 | 114K | 98.10 | 90.56
67 | 91K | 97.99 | 90.94
73 | 75K | 97.42 | 90.86
85 | 40K | 97.78 | 90.73
90 | 26.6K | 97.84 | 90.85
92.2 | 21.8K | 97.32 | 90.18
93.5 | 18K | 97.53 | 90.85
94.6 | 15K | 96.78 | 88.90
97.8 | 6K | 53.29 | 42.01

  • 48 hidden dimensions; 8 graph iterations; 6-bit width
  • Validation loss (val_loss) used as the metric for iterative pruning

34 of 55

Iterative Pruning


  • 48 hidden dimensions; 8 graph iterations; 6-bit width
  • Validation loss (val_loss) used as the metric for iterative pruning

800 events for training. 100 test, 100 val

2023 ttbar : pileup 200

35 of 55

Iterative Pruning


  • 48 hidden dimensions; 8 graph iterations; 6-bit width
  • Validation loss (val_loss) used as the metric for iterative pruning

[Plots: 75K-parameter pruned model vs. 280K-parameter unpruned model]

36 of 55

7800 events for training


37 of 55

Iterative Pruning


  • 48 hidden dimensions; 8 graph iterations; 6-bit width
  • Validation loss (val_loss) used as the metric for iterative pruning

[Plots: 35K-parameter pruned model vs. 280K-parameter unpruned model]

38 of 55

Iterative Pruning


  • 48 hidden dimensions; 8 graph iterations; 6-bit width
  • Validation loss (val_loss) used as the metric for iterative pruning

[Plots: 35K-parameter pruned model vs. 280K-parameter unpruned model]

39 of 55

Iterative Pruning


  • 48 hidden dimensions; 8 graph iterations; 6-bit width
  • Validation loss (val_loss) used as the metric for iterative pruning

[Plots: 35K-parameter pruned model vs. 280K-parameter unpruned model]

40 of 55

Iterative Pruning


  • 48 hidden dimensions; 8 graph iterations; 6-bit width
  • Validation loss (val_loss) used as the metric for iterative pruning

[Plots: 35K-parameter pruned model vs. 280K-parameter unpruned model]

41 of 55

GNN For EF Tracking

Santosh Parajuli

University of Illinois Urbana-Champaign

28 May 2024

42 of 55

Pruning on the CTD 2023 model (except we use layernorm here)

[Plots: 2M non-zero parameters vs. 169K non-zero parameters (already 90% pruned)]

43 of 55

Pruning on the CTD 2023 model (except we use layernorm here)

[Efficiency plots: 2M non-zero parameters vs. 169K non-zero parameters (already 90% pruned)]

44 of 55

Pruning on the CTD 2023 model (except we use layernorm here)

[Total purity plots: 2M non-zero parameters vs. 169K non-zero parameters (already 90% pruned)]

45 of 55

Pruning on the CTD 2023 model (except we use layernorm here)

[Target purity plots: 2M non-zero parameters vs. 169K non-zero parameters (already 90% pruned)]

46 of 55

Pruning on the CTD 2023 model (except we use layernorm here)

[Masked purity plots: 2M non-zero parameters vs. 169K non-zero parameters (already 90% pruned)]

47 of 55

To-Do for faster inference

  1. Export and optimize: export the pruned model to a format that supports sparse inference, e.g. ONNX, and then use a framework such as TensorRT to exploit the sparsity.

  2. Use PyTorch's sparse tensor support (a sketch of both steps follows below).
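A hedged sketch of both to-do items: ONNX export of the pruned model and collecting the pruned weights as PyTorch sparse tensors. `model`, `example_inputs`, and the file name are placeholders, not the actual acorn setup.

```python
# Export the pruned model to ONNX (for TensorRT) and gather sparse weight copies.
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

def export_onnx(model: nn.Module, example_inputs, path: str = "gnn_pruned.onnx") -> None:
    # Make the pruning masks permanent so the exported weights contain real zeros
    for module in model.modules():
        if isinstance(module, nn.Linear) and prune.is_pruned(module):
            prune.remove(module, "weight")
    torch.onnx.export(model, example_inputs, path, opset_version=17)
    # The resulting .onnx file can then be handed to TensorRT (e.g. via trtexec).

def sparse_weights(model: nn.Module) -> dict:
    # CSR copies of the (mostly zero) weight matrices, usable with torch.sparse.mm
    return {name: m.weight.detach().to_sparse_csr()
            for name, m in model.named_modules() if isinstance(m, nn.Linear)}
```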


48 of 55

GNN For EF Tracking

Santosh Parajuli

University of Illinois Urbana-Champaign

25 July 2024

49 of 55

Rel24 Sample

[Plots: 2M-parameter model (128 hidden dim) vs. 280K-parameter model (48 hidden dim, quantized to 6-bit)]

50 of 55

Rel24 Sample

[Plots: 2M-parameter model (128 hidden dim) vs. 280K-parameter model (48 hidden dim, quantized to 6-bit)]

51 of 55

Rel24 Sample

[Plots: 2M-parameter model (128 hidden dim) vs. 280K-parameter model (48 hidden dim, quantized to 6-bit)]

52 of 55

Rel24 Sample

[Plots: 2M-parameter model (128 hidden dim) vs. 280K-parameter model (48 hidden dim, quantized to 6-bit)]

53 of 55

GNN For EF Tracking

Santosh Parajuli

University of Illinois Urbana-Champaign

20 Aug 2024

54 of 55

Knowledge Distillation

  1. Train the student model; the teacher model enters the loss here (a distillation-loss sketch follows below).
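A minimal sketch of a distillation loss for the edge classifier, assuming the student is trained against both the true edge labels and the frozen teacher's edge scores. The 50/50 weighting and the use of plain BCE on the soft targets are illustrative choices, not necessarily the configuration used here.

```python
# Knowledge distillation loss: hard labels + soft targets from the teacher.
import torch
import torch.nn.functional as F

def distillation_loss(student_logits: torch.Tensor,
                      teacher_logits: torch.Tensor,
                      edge_labels: torch.Tensor,
                      alpha: float = 0.5) -> torch.Tensor:
    hard = F.binary_cross_entropy_with_logits(student_logits, edge_labels.float())
    with torch.no_grad():
        soft_targets = torch.sigmoid(teacher_logits)  # teacher's edge scores as soft labels
    soft = F.binary_cross_entropy_with_logits(student_logits, soft_targets)
    return alpha * hard + (1.0 - alpha) * soft
```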


55 of 55

Thank You
