1 of 27

Differentiable Signal Processing with Jax

Student Lecture

Leslie Li

11/30/2023

2 of 27

Overview

  • Signal Processing
    • What signal processing can achieve
    • Why make it differentiable?
    • Example: DDSP
  • The JAX library
    • Signal Processing in JAX: numpy/scipy with GPU
    • Using autodiff with JAX
    • Limitations & Issues
    • Building (data-driven) deep models with JAX: flax and optax
  • Demo: differentiable cochlear processing (if time)

3 of 27

Signal processing: motivating example

[Figure: Fourier Transform; from Paiva, 2002]

4 of 27

Signal processing - how is it useful in machine learning?

  • Perform transformations
    • Fourier Transform (FFT, STFT)
    • Envelope extraction (Hilbert Transform)
  • Feature extraction
    • Filtering/Convolution
    • Complete pipeline: e.g. MFCCs

Then deep learning came.

Do signal processing methods become obsolete?

5 of 27

Why make signal processing differentiable?

  • Feature extraction pipelines
    • Informed by regularities in the data and by human perception
    • Not complex enough to handle variability in the data
  • Deep, data-driven, end-to-end models
    • Sensitive to noise and less robust in generalization
    • Large, expensive, and slow

How can we combine the knowledge built into signal processing methods with the flexibility of deep learning?

6 of 27

Example 1: DDSP (differentiable digital signal processing)

Background: vocoders

Figure taken from Quatieri, Discrete-Time Speech Signal Processing

7 of 27

Example 1: DDSP (differentiable digital signal processing)

  • Demo (Bell Labs, ca. 1940)

→ It's 2023 now; can we make this sound better?

8 of 27

Example 1: DDSP (differentiable digital signal processing)

  • Combines deep (red) blocks with signal processing (yellow) blocks
  • Used in speech & audio synthesis
  • Audio demos: https://storage.googleapis.com/ddsp/index.html

[Figure: DDSP architecture, with pitch and timbre as control signals; Engel et al., 2020 (ICLR)]

9 of 27

Example 2: differentiable CAR-FAC

  • Model of the cochlea:
    • Frequency Decomposition
  • Brute-force way: Fourier Transform
    • Independent filters
  • But the filters in the ear are not independent

→ Cascade of filters

→ Use backpropagation to update the filter parameters

(ongoing work by Google)

Figures taken from Schnupp, 2011 and Lyon, 2011

10 of 27

Overview

  • Signal Processing
    • What signal processing can achieve
    • Why make it differentiable?
    • Example: DDSP
  • The JAX library
    • Signal Processing in JAX: numpy/scipy with GPU
    • Using autodiff with JAX
    • Limitations & Issues
    • Building (data-driven) deep models with JAX: flax and optax
  • Demo: differentiable cochlear processing

11 of 27

What is JAX?

12 of 27

Signal Processing with GPU

  • JAX has its own implementations of common signal processing functions
  • Similar syntax to numpy/scipy
  • Also potentially faster, due to GPU acceleration, etc. (see the sketch below)

[Diagram: computational graph of z = convolve(x, y)]
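A minimal sketch of this kind of drop-in replacement, assuming a 1-D signal and filter; the array sizes and mode are arbitrary placeholder choices:

# Minimal sketch: convolution in JAX with numpy/scipy-like syntax.
import numpy as np
import jax.numpy as jnp
from jax.scipy.signal import convolve

x = jnp.asarray(np.random.randn(16000))   # signal
y = jnp.asarray(np.hanning(64))           # filter
z = convolve(x, y, mode="same")           # same call pattern as scipy.signal.convolve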

13 of 27

Making signal processing differentiable: autodiff with JAX

  • Operations in JAX are differentiable
  • cf. a convolutional neural network layer
  • N.B.: you can make this work with, e.g., PyTorch, too
    • Advantages of JAX: unified syntax, more signal processing functions implemented, etc.

[Diagram: z = convolve(x, y) feeds a loss function s; back-propagation gives ds/dz, ds/dx (to update the signal), and ds/dy (to update the filter)]
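A minimal sketch of the idea in the diagram above; the MSE loss, target, and update step are illustrative assumptions, not part of the slide:

# Minimal sketch: differentiating through a convolution with jax.grad.
import jax
import jax.numpy as jnp
from jax.scipy.signal import convolve

def loss(y, x, target):
    z = convolve(x, y, mode="same")       # filtered signal
    return jnp.mean((z - target) ** 2)    # scalar loss s

x = jnp.ones(1024)             # signal
y = jnp.ones(65) / 65.0        # filter (moving average)
target = jnp.zeros(1024)

ds_dy = jax.grad(loss, argnums=0)(y, x, target)   # ds/dy: used to update the filter
ds_dx = jax.grad(loss, argnums=1)(y, x, target)   # ds/dx: used to update the signal
y = y - 0.1 * ds_dy                               # one gradient step on the filter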

14 of 27

Differentiable signal processing example: Improved loss function (1)

  • Calculating the loss in the time domain may be misleading
  • Compute loss in different domains:
    • E.g., Frequency domain

  • Auraloss

Steinmetz & Reiss, 2020. https://github.com/csteinmetz1/auraloss

[Diagram: a DNN with weights w maps x to ŷ; ŷ and the target y are passed through STFTs before the loss L; back-propagation gives dL/dŷ and dL/dw]
15 of 27

Differentiable signal processing example: Improved loss function (2)

  • Calculating the loss in the time domain may be misleading
  • Compute loss in different domains:
    • Cortical modulation (STRF) domain

Vuong et al., ICASSP 2021. arxiv.org/abs/2102.07330

[Diagram: the loss is computed after signal processing/mathematical operations (the STRF model), and back-propagation runs through them]

16 of 27

Connecting JAX with deep learning: FLAX

  • Torch-like syntax (minimal sketch below):
    • Define a model as an nn.Module
    • Write NN layers in linear order
    • Backward pass using jax.grad()
    • Update using an optimizer
      • Optax
  • Community

For more, see https://github.com/google/flax/tree/main/examples/vae/
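A minimal sketch of that workflow, assuming a toy MLP and dummy data (loosely following the Flax examples linked above):

# Minimal sketch: a Flax model with an Optax optimizer. The MLP, loss, and
# dummy data are placeholder assumptions, not the lecture's actual model.
import jax
import jax.numpy as jnp
import flax.linen as nn
import optax

class MLP(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.relu(nn.Dense(128)(x))   # NN layers written in linear order
        return nn.Dense(1)(x)

model = MLP()
x = jnp.ones((32, 16))                              # dummy batch
y = jnp.zeros((32, 1))                              # dummy targets
params = model.init(jax.random.PRNGKey(0), x)       # initialize parameters

tx = optax.adam(1e-3)
opt_state = tx.init(params)

def loss_fn(params, x, y):
    pred = model.apply(params, x)
    return jnp.mean((pred - y) ** 2)

grads = jax.grad(loss_fn)(params, x, y)             # backward pass via jax.grad()
updates, opt_state = tx.update(grads, opt_state)    # optimizer step (Optax)
params = optax.apply_updates(params, updates)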

17 of 27

Connecting JAX with deep learning: FLAX

  • Backward pass using jax.grad()
    • Inherent in JAX
    • Physics/DSP-based models can be added here

18 of 27

Connecting JAX with deep learning: FLAX

  • Community:
    • Big projects written in JAX, or translated into JAX
    • Active forum?

19 of 27

Some other caveats

  • JAX defaults to 32-bit computation
    • Change to 64-bit using config.update("jax_enable_x64", True)
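A minimal sketch; the flag should be set at startup, before any arrays are created:

# Minimal sketch: enabling 64-bit computation in JAX.
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)
print(jnp.ones(1).dtype)   # float64 instead of the default float32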

20 of 27

Issues with JAX, Ex. 1: acceleration and function purity

  • Just-in-time compilation: @jit
  • Very fast!
  • Impure functions may behave erratically or throw errors
    • E.g., print() inside a function
    • E.g., using global variables
  • Tips:
    • When getting started, write small functions in JAX and compare them closely with non-JAX versions
    • Also compare JAX code with and without jitting
    • Expect the first call to be slow, because that is when compilation happens

For more, see https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html
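A minimal sketch of the print() gotcha under @jit; the function bodies are arbitrary:

# Minimal sketch: @jit and function purity. Side effects such as print()
# run only while the function is being traced, not on every call.
import jax
import jax.numpy as jnp

@jax.jit
def pure_fn(x):
    return jnp.tanh(x) * 2.0            # pure: output depends only on the input

@jax.jit
def impure_fn(x):
    print("tracing with", x)            # executes at trace time, prints a tracer
    return jnp.tanh(x) * 2.0

impure_fn(jnp.ones(3))   # prints once, during tracing/compilation
impure_fn(jnp.ones(3))   # prints nothing: the cached compiled version is reused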

21 of 27

Issues with JAX, Ex. 2: immutable arrays

  • A jnp.array is built to be handled just like an np.array, but it is ultimately immutable
  • As a result, changing values in a jnp array needs a small workaround (the .at[...].set(...) syntax)
  • Similarly, some truth-value evaluations do not work with jnp arrays

  • Converting an np.array to a jnp.array is easy:

arr = np.zeros([2, 2])

arr = jnp.array(arr)

  • But this conversion is slow, so plan accordingly
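A minimal sketch of the in-place-update workaround mentioned above:

# Minimal sketch: jnp arrays are immutable, so in-place-style updates go
# through the .at[...].set(...) syntax and return a new array.
import jax.numpy as jnp

a = jnp.zeros((2, 2))
# a[0, 0] = 1.0            # raises an error: JAX arrays are immutable
a = a.at[0, 0].set(1.0)    # functional update: returns an updated copy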

22 of 27

JAX for differentiable programming: pros and cons

Pros

  • All computations can be GPU-accelerated (or TPU-accelerated)
    • Scalable
    • May be much faster, depending on the application
    • Worth using even when differentiability is not in the picture
  • Easy differentiability & customization of the computational graph
    • Also possible in PyTorch, but more straightforward in JAX

Cons

  • Learning curve for pure functions and immutable arrays

Other main differences:

    • Under the hood, JAX differs more from other Python libraries than PyTorch does
    • Although PyTorch has a longer history, the JAX community is growing, and support is sufficient

For more, see https://www.reddit.com/r/MachineLearning/comments/shsfkm/d_current_state_of_jax_vs_pytorch/

23 of 27

Demo: differentiable hearing model

24 of 27

Step 1: Implement the physical model

  • Implement a non-JAX forward model first
  • Translate into a JAX version
  • Compare outputs
    • Non-JAX version: ~100 ms
    • JAX version: 1+ min
  • Find alternative implementations for suboptimal operations
    • Optimized JAX: ~300 ms

[Code screenshot: the operation that was slow when x = 1600]
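A minimal sketch of how such comparisons can be timed fairly in JAX; the forward function here is a placeholder, not the actual cochlear model:

# Minimal sketch: timing a jitted JAX function. JAX dispatches work
# asynchronously, so call .block_until_ready() before stopping the clock.
import time
import jax
import jax.numpy as jnp

@jax.jit
def forward(x):
    return jnp.fft.rfft(x).real.sum()    # placeholder for the forward model

x = jnp.ones(16000)
forward(x).block_until_ready()           # first call includes compilation

t0 = time.perf_counter()
forward(x).block_until_ready()           # steady-state timing
print(f"forward pass: {time.perf_counter() - t0:.6f} s")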

25 of 27

Step 2: obtaining gradients

  • Try calling jax.grad() on the forward pass
    • Received NaNs
    • From JAX community: turn on NaN debugging
    • The NaN debugger pointed to jnp.angle(): taking the phase of some complex values
    • Replaced zeros with small arbitrary values to avoid the undefined gradient at zero
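A minimal sketch of the NaN-debugging flag and the epsilon workaround; the helper function and epsilon value are illustrative assumptions:

# Minimal sketch: catch NaNs as they appear, and keep jnp.angle() away from
# exactly-zero complex values (the epsilon value is an arbitrary choice).
import jax
import jax.numpy as jnp

jax.config.update("jax_debug_nans", True)   # raise an error where a NaN is produced

def safe_angle(z, eps=1e-12):
    # the gradient of the phase is undefined at z = 0, so nudge away from zero
    return jnp.angle(z + eps)

g = jax.grad(lambda r: safe_angle(r + 1j * r))(0.0)
print(g)   # finite (large) gradient instead of NaN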

26 of 27

Step 3: Connecting the physical model to deep models

  • Find a similar example in the Flax examples repository
  • Two ways to combine the physical model and deep model:
    • Implement the physical model as part of the deep model, using
      • self.param() or self.variable() inside the nn.Module
      • The code will look nicer, but takes extra effort to write
    • Keep the two models separate, get gradients respectively, and update each part
      • Easy to write
  • Find a dataloader
    • JAX does not have its own dataloader!
    • Compatible with either the PyTorch or the TensorFlow dataloader
    • But tensors must be converted to jnp arrays (takes extra time…); see the sketch below
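A minimal sketch of that conversion step, using a toy PyTorch DataLoader with random data:

# Minimal sketch: iterating a PyTorch DataLoader and converting each batch
# to jnp arrays for a JAX/Flax training step. The dataset here is dummy data.
import numpy as np
import jax.numpy as jnp
import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.randn(256, 16), torch.randn(256, 1))
loader = DataLoader(dataset, batch_size=32)

for xb, yb in loader:
    x = jnp.asarray(np.asarray(xb))   # torch.Tensor -> numpy -> jnp (the extra copy)
    y = jnp.asarray(np.asarray(yb))
    # ... run the JAX/Flax training step on (x, y)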

27 of 27

https://www.reddit.com/r/learnmachinelearning/comments/16vgfed/is_jax_a_better_choice_to_focus_on_over_pytorch/