Differentiable Signal Processing with JAX
Student Lecture
Leslie Li
11/30/2023
Overview
Signal processing: motivating example
Figure from Paiva, 2002
Fourier Transform
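A minimal sketch of the Fourier transform in JAX. It uses `jax.numpy.fft` on a test tone whose frequency (440 Hz) and sampling rate (8 kHz) are arbitrary illustrative choices, and it recovers the tone's frequency from the magnitude spectrum.

```python
import jax.numpy as jnp

# 1 second of a 440 Hz sine sampled at 8 kHz (illustrative values).
sr = 8000
t = jnp.arange(sr) / sr
x = jnp.sin(2 * jnp.pi * 440.0 * t)

# Real-input FFT: bin k corresponds to frequency k * sr / N.
X = jnp.fft.rfft(x)
freqs = jnp.fft.rfftfreq(x.shape[0], d=1.0 / sr)

# The largest magnitude bin gives the dominant frequency.
peak_hz = freqs[jnp.argmax(jnp.abs(X))]
print(peak_hz)
```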
Signal processing: how is it useful in machine learning?
Then deep learning came along.
Have signal processing methods become obsolete?
Why make signal processing differentiable?
How can we combine the domain knowledge built into signal processing methods with the flexibility of deep learning?
Example 1: DDSP (differentiable digital signal processing)
Background: vocoders
Figure taken from Quatieri, Discrete-Time Speech Signal Processing
→ It’s 2023 now, can we make this better-sounding?
Engel et al., 2020 (ICLR)
[Figure labels: pitch, timbre]
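To make the idea concrete, here is a simplified additive (harmonic) synthesizer in JAX, in the spirit of DDSP. It is not the DDSP library's API. The function `harmonic_synth`, the fixed harmonic amplitudes standing in for timbre, and all parameter values are assumptions for illustration. Because the synthesizer is written with differentiable operations, `jax.grad` can fit the harmonic amplitudes to a target sound.

```python
import jax
import jax.numpy as jnp

def harmonic_synth(f0_hz, amps, sr=16000):
    """Hypothetical helper: sum of harmonics k * f0, weighted by amps[k-1].

    f0_hz: (T,) per-sample fundamental frequency (pitch).
    amps:  (K,) harmonic amplitudes (a crude stand-in for timbre).
    """
    k = jnp.arange(1, amps.shape[0] + 1)               # harmonic numbers 1..K
    phase = 2 * jnp.pi * jnp.cumsum(f0_hz) / sr        # instantaneous phase, (T,)
    harmonics = jnp.sin(phase[:, None] * k[None, :])   # (T, K)
    # Zero out harmonics above Nyquist to avoid aliasing.
    mask = (f0_hz[:, None] * k[None, :]) < sr / 2
    return jnp.sum(harmonics * mask * amps[None, :], axis=-1)

sr = 16000
f0 = jnp.full((1600,), 220.0)                           # constant 220 Hz pitch
target = harmonic_synth(f0, jnp.array([1.0, 0.5, 0.25]), sr)

def loss(amps):
    return jnp.mean((harmonic_synth(f0, amps, sr) - target) ** 2)

grad_fn = jax.jit(jax.grad(loss))
amps = jnp.zeros(3)
for _ in range(200):
    amps = amps - 0.5 * grad_fn(amps)                   # plain gradient descent
```

Real DDSP models predict such controls (and a filtered-noise component) with a neural network; this sketch only shows that gradients flow through the synthesizer.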
Example 2: differentiable CAR-FAC
→ Cascade of filters
→ Use backpropagation to update the filter parameters
(ongoing work by Google)
Figures taken from Schnupp, 2011 and Lyon, 2011
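A heavily simplified sketch of "a cascade of filters trained by backpropagation". The real CAR-FAC uses asymmetric resonators with fast-acting compression; here each stage is a plain one-pole low-pass filter (an assumption chosen for brevity, not Google's implementation). Each stage's output is kept as a tap, and `jax.value_and_grad` updates the per-stage coefficients so the taps match a target cascade.

```python
import jax
import jax.numpy as jnp

def one_pole_lowpass(x, a):
    """y[n] = (1 - a) * x[n] + a * y[n-1], computed with lax.scan over time."""
    def step(y_prev, x_n):
        y_n = (1.0 - a) * x_n + a * y_prev
        return y_n, y_n
    _, y = jax.lax.scan(step, jnp.zeros((), x.dtype), x)
    return y

def cascade(x, coeffs):
    """Pass x through the stages in series and keep each stage's output (tap)."""
    taps = []
    for a in coeffs:                    # small, fixed number of stages
        x = one_pole_lowpass(x, a)
        taps.append(x)
    return jnp.stack(taps)              # shape (num_stages, T)

x = jax.random.normal(jax.random.PRNGKey(0), (512,))
target = cascade(x, jnp.array([0.4, 0.6, 0.8]))   # taps of a "true" cascade

def loss(coeffs):
    return jnp.mean((cascade(x, coeffs) - target) ** 2)

loss_and_grad = jax.jit(jax.value_and_grad(loss))
coeffs = jnp.array([0.3, 0.3, 0.3])
initial_loss, _ = loss_and_grad(coeffs)
for _ in range(50):
    _, g = loss_and_grad(coeffs)
    coeffs = coeffs - 0.05 * g          # gradient step on the filter parameters
final_loss = loss(coeffs)
```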
What is JAX?
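A minimal sketch of the three JAX transformations used throughout this lecture: `jax.grad` (automatic differentiation), `jax.jit` (XLA compilation for CPU, GPU, or TPU), and `jax.vmap` (automatic vectorization), applied to a NumPy-style function written with `jax.numpy`.

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sum(jnp.sin(x) ** 2)

x = jnp.linspace(0.0, 1.0, 5)

df = jax.grad(f)        # d/dx sum(sin^2 x) = 2 sin x cos x = sin 2x
fast_f = jax.jit(f)     # same function, compiled with XLA
batched = jax.vmap(f)   # maps f over a leading batch axis

print(df(x))
print(fast_f(x))
print(batched(jnp.stack([x, 2 * x])))
```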
Signal processing on the GPU
[Diagram: signal x and filter y → convolve → output z]
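A minimal sketch of the convolution above in JAX. `jnp.convolve` follows the `numpy.convolve` interface, and wrapping it in `jax.jit` lets XLA run it on whichever backend is available (GPU if present, otherwise CPU). The signal length and the 5-tap moving-average filter are illustrative choices.

```python
import jax
import jax.numpy as jnp

@jax.jit
def filter_signal(x, h):
    # Same semantics as numpy.convolve(x, h, mode="same").
    return jnp.convolve(x, h, mode="same")

x = jax.random.normal(jax.random.PRNGKey(0), (16000,))   # signal
h = jnp.ones(5) / 5.0                                     # 5-tap moving average
z = filter_signal(x, h)

print(jax.devices())   # lists GPU devices if JAX was installed with GPU support
```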
Making signal processing differentiable: autodiff with JAX
[Diagram: x and y → convolve → z → loss function → s. Back-propagation gives ds/dz, then ds/dx (update signal) and ds/dy (update filter).]
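The diagram's backward pass can be sketched with `jax.grad` and `argnums=(0, 1)`, which returns ds/dx and ds/dy in one call. The mean-squared-error loss against a zero target is an assumption made only to have a concrete scalar s.

```python
import jax
import jax.numpy as jnp

def loss(x, y, target):
    z = jnp.convolve(x, y, mode="same")   # signal x, filter y
    return jnp.mean((z - target) ** 2)    # scalar loss s

x = jax.random.normal(jax.random.PRNGKey(0), (256,))
y = jnp.array([0.25, 0.5, 0.25])
target = jnp.zeros(256)

# ds/dx (to update the signal) and ds/dy (to update the filter).
ds_dx, ds_dy = jax.grad(loss, argnums=(0, 1))(x, y, target)

lr = 0.1
x_new = x - lr * ds_dx    # update signal
y_new = y - lr * ds_dy    # update filter
```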
Differentiable signal processing example: Improved loss function (1)
Steinmetz & Reiss, 2020. https://github.com/csteinmetz1/auraloss
[Diagram: x → DNN → ŷ; ŷ and the target y each pass through an STFT → loss function → L (loss). Back-propagation gives dL/dŷ, then dL/dw for the DNN weights.]
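auraloss itself is a PyTorch library. The following is a hand-written JAX sketch of a single-resolution STFT loss in the same spirit, not auraloss's code. It combines spectral convergence with an L1 distance between log magnitudes, a common pairing. The FFT size, hop, Hann window, and `eps` are assumptions. The DNN is omitted: `dL_dyhat` is the gradient that back-propagation would pass on to the network's weights.

```python
import jax
import jax.numpy as jnp

def stft_mag(x, n_fft=512, hop=128, eps=1e-7):
    """Magnitude STFT via framing + Hann window + rfft (no padding)."""
    n_frames = 1 + (x.shape[-1] - n_fft) // hop
    idx = jnp.arange(n_fft)[None, :] + hop * jnp.arange(n_frames)[:, None]
    frames = x[idx] * jnp.hanning(n_fft)            # (n_frames, n_fft)
    X = jnp.fft.rfft(frames, axis=-1)
    # eps inside the sqrt keeps the gradient finite at zero magnitude.
    return jnp.sqrt(jnp.real(X) ** 2 + jnp.imag(X) ** 2 + eps)

def stft_loss(y_hat, y, eps=1e-7):
    """Spectral convergence + log-magnitude L1 at one STFT resolution."""
    S_hat, S = stft_mag(y_hat), stft_mag(y)
    sc = jnp.linalg.norm(S - S_hat) / (jnp.linalg.norm(S) + eps)
    log_mag = jnp.mean(jnp.abs(jnp.log(S + eps) - jnp.log(S_hat + eps)))
    return sc + log_mag

key1, key2 = jax.random.split(jax.random.PRNGKey(0))
y = jax.random.normal(key1, (4096,))                     # target signal
y_hat = y + 0.1 * jax.random.normal(key2, (4096,))       # stand-in for DNN output
L, dL_dyhat = jax.value_and_grad(stft_loss)(y_hat, y)
```

A multi-resolution version would average `stft_loss` over several `(n_fft, hop)` pairs.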
Differentiable signal processing example: Improved loss function (2)
Vuong et al., ICASSP 2021. arxiv.org/abs/2102.07330
[Diagram: signal processing/mathematical operations inside the loss, with back-propagation through them]
Connecting JAX with deep learning: Flax
For more, see https://github.com/google/flax/tree/main/examples/vae/
Some other caveats
Issues with JAX, Ex. 1: acceleration and function purity
For more, see https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html
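Two gotchas from the linked page, sketched below. Under `jax.jit`, Python side effects run only while the function is being traced, and a global variable is read once at trace time and baked into the compiled function. Later calls with the same input shapes and dtypes reuse the cached trace.

```python
import jax
import jax.numpy as jnp

trace_log = []

def impure_double(x):
    trace_log.append("called")   # Python side effect: runs only during tracing
    return 2 * x

fast_double = jax.jit(impure_double)
fast_double(1.0)
fast_double(2.0)                 # same shape/dtype: cached trace, no retrace
print(len(trace_log))            # 1, not 2

scale = 2.0

def scaled(x):
    return scale * x             # global read at trace time, stored as a constant

fast_scaled = jax.jit(scaled)
print(fast_scaled(1.0))          # 2.0
scale = 10.0
print(fast_scaled(1.0))          # still 2.0: the cached trace is reused
```

Passing `scale` as an argument instead of reading a global keeps the function pure and avoids the stale value.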
Issues with JAX, Ex. 2: immutable arrays
import numpy as np
import jax.numpy as jnp
arr = np.zeros([2, 2])
arr = jnp.array(arr)  # copy into an (immutable) JAX array
JAX for differentiable programming: pros and cons
Pros
Cons
Other main differences:
For more, see https://www.reddit.com/r/MachineLearning/comments/shsfkm/d_current_state_of_jax_vs_pytorch/
Demo: differentiable hearing model
Step 1: Implement the physical model
Slow when x=1600
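The slide does not say what makes the model slow at x = 1600. One common cause, assumed here, is a long Python `for` loop: under `jax.jit` it is unrolled into one operation per iteration, so tracing and compile time grow with the loop length. The sketch compares that with `jax.lax.scan`, which traces the loop body once. The leaky-integrator recurrence is a hypothetical stand-in for the physical model.

```python
import jax
import jax.numpy as jnp

def leaky_integrator_loop(x, a):
    """Python loop: jit unrolls it into len(x) separate steps."""
    y, out = 0.0, []
    for n in range(x.shape[0]):
        y = a * y + x[n]
        out.append(y)
    return jnp.stack(out)

def leaky_integrator_scan(x, a):
    """Same recurrence with lax.scan: the body is traced once."""
    def step(y, x_n):
        y = a * y + x_n
        return y, y
    _, ys = jax.lax.scan(step, jnp.zeros((), x.dtype), x)
    return ys

x = jnp.sin(jnp.linspace(0.0, 10.0, 1600))
slow = jax.jit(leaky_integrator_loop)
fast = jax.jit(leaky_integrator_scan)

y_fast = fast(x, 0.9)
# slow(x, 0.9) gives the same values but takes longer to trace and compile.
```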
Step 2: obtaining gradients
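A sketch of obtaining gradients of a physical model's parameter, using the same hypothetical leaky-integrator model as a stand-in. `jax.grad` differentiates straight through `lax.scan`. The central finite difference is only a sanity check that the autodiff gradient is correct.

```python
import jax
import jax.numpy as jnp

def leaky_integrator(x, a):
    def step(y, x_n):
        y = a * y + x_n
        return y, y
    _, ys = jax.lax.scan(step, jnp.zeros((), x.dtype), x)
    return ys

x = jnp.sin(jnp.linspace(0.0, 10.0, 1600))
target = leaky_integrator(x, 0.8)        # output for a "true" parameter value

def loss(a):
    return jnp.mean((leaky_integrator(x, a) - target) ** 2)

a0 = 0.5
g = jax.grad(loss)(a0)                   # autodiff gradient d(loss)/da

# Sanity check: central finite difference.
h = 1e-3
g_fd = (loss(a0 + h) - loss(a0 - h)) / (2 * h)
```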
Step 3: Connecting the physical model to deep models
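A sketch of connecting the physical model to a deep model, written in plain JAX so it does not depend on Flax. A small hypothetical MLP predicts the model's coefficient from a conditioning vector. The loss is computed on the physical model's output, and gradients flow back through the physical model into the network weights. The network size, the constant features, and the learning rate are assumptions for illustration.

```python
import jax
import jax.numpy as jnp

def leaky_integrator(x, a):
    def step(y, x_n):
        y = a * y + x_n
        return y, y
    _, ys = jax.lax.scan(step, jnp.zeros((), x.dtype), x)
    return ys

def init_params(key, n_in=4, n_hidden=16):
    k1, k2 = jax.random.split(key)
    return {
        "w1": 0.1 * jax.random.normal(k1, (n_in, n_hidden)),
        "b1": jnp.zeros(n_hidden),
        "w2": 0.1 * jax.random.normal(k2, (n_hidden,)),
        "b2": jnp.zeros(()),
    }

def predict_coeff(params, features):
    h = jnp.tanh(features @ params["w1"] + params["b1"])
    # Sigmoid keeps the coefficient in (0, 1), i.e. a stable recursion.
    return jax.nn.sigmoid(h @ params["w2"] + params["b2"])

def loss(params, features, x, target):
    a = predict_coeff(params, features)          # deep model -> physical parameter
    return jnp.mean((leaky_integrator(x, a) - target) ** 2)

x = jnp.sin(jnp.linspace(0.0, 10.0, 400))
features = jnp.ones(4)                           # placeholder conditioning input
target = leaky_integrator(x, 0.8)

params = init_params(jax.random.PRNGKey(0))
loss_and_grad = jax.jit(jax.value_and_grad(loss))
loss0, _ = loss_and_grad(params, features, x, target)
for _ in range(200):
    l, grads = loss_and_grad(params, features, x, target)
    params = jax.tree_util.tree_map(lambda p, g: p - 1e-2 * g, params, grads)
```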
https://www.reddit.com/r/learnmachinelearning/comments/16vgfed/is_jax_a_better_choice_to_focus_on_over_pytorch/