1 of 23

Probabilistic Foundations of Machine Learning

I am: Yaniv / Professor Yacoby

I use: he/they

2 of 23

Announcements

  • HW3 Tip:
    • Copy code from the MLE chapters into DeepNote
    • Print things out
    • Play with it
    • Adapt it, little by little, for HW3

3 of 23

Today: Optimization

So far, we:

  • Used joint distributions / DGMs to represent our modeling assumptions
  • Translated these into NumPyro

Next, we want to perform the MLE. But:

  • How do we find the parameters that maximize the probability of the data?

Today:

  • A simple, general-purpose optimization algorithm

4 of 23

Our generative process, in code.
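
A minimal sketch of what such a generative process might look like in NumPyro, assuming a simple Bernoulli example (N binary observations with unknown success probability theta); the model used in lecture may differ:

    import numpyro
    import numpyro.distributions as dist

    def model(N, theta, x=None):
        # each of the N observations is an independent Bernoulli draw
        with numpyro.plate("data", N):
            numpyro.sample("x", dist.Bernoulli(probs=theta), obs=x)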

5 of 23

Performing the MLE:
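
One way to set this up, assuming the Bernoulli sketch above: write the quantity we want to maximize (the log likelihood of the data) as a function of theta, or equivalently its negative, which we will minimize. The name nll is just illustrative:

    import numpyro.distributions as dist

    def nll(theta, x):
        # negative log likelihood of the observed data x under Bernoulli(theta)
        return -dist.Bernoulli(probs=theta).log_prob(x).sum()

The rest of the lecture is about how to actually find the theta that minimizes this.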

6 of 23

Checking for Convergence
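
One simple heuristic (a sketch; the course-provided code may use a different criterion): track the loss at each iteration and declare convergence once it stops changing appreciably.

    def has_converged(losses, tol=1e-6):
        # losses: list of loss values, one per iteration
        return len(losses) > 1 and abs(losses[-1] - losses[-2]) < tol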

7 of 23

Our goal: to minimize our “loss”
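
For MLE with i.i.d. data, the "loss" is the negative log likelihood, so minimizing the loss is the same as maximizing the probability of the data:

  $$ \mathcal{L}(\theta) = -\log p(x_1, \dots, x_N \mid \theta) = -\sum_{n=1}^{N} \log p(x_n \mid \theta) $$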

8 of 23

How can we minimize a function? Let’s get some intuition:

Idea: look for places where the derivative of the loss is 0.
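
A tiny example (illustrative, not from the slides): for $f(x) = (x - 2)^2$, the derivative is $f'(x) = 2(x - 2)$, which equals 0 only at $x = 2$, and $x = 2$ is indeed the minimizer.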

9 of 23

Global vs. Local Optima

Looking at the figure:

  • Different types of optima
  • Derivative is 0 at all optima
  • But we only want global minima
  • So is this still useful?

Answer: Yes. How?

  • Get all optima
  • Plug each into loss
    • See which one is smallest (worked example below)
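
A small worked example (illustrative, not from the slides): take $f(x) = x^4 - 2x^2$. Setting $f'(x) = 4x^3 - 4x = 0$ gives three candidates, $x = -1, 0, 1$. Plugging them in: $f(\pm 1) = -1$ and $f(0) = 0$, so the global minima are at $x = \pm 1$ (and $x = 0$ is only a local maximum).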

10 of 23

Initial Algorithm

…shall we try it?

11 of 23

Analytic MLE, Step 1: What’s our model?
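
Assuming the running Bernoulli example (binary observations $x_n \in \{0, 1\}$ with unknown success probability $\theta$), the model is:

  $$ x_n \overset{\text{i.i.d.}}{\sim} \mathrm{Bern}(\theta), \quad n = 1, \dots, N, \qquad \theta \in [0, 1] $$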

12 of 23

Analytic MLE, Step 2: What’s our joint data likelihood?
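
For the Bernoulli sketch, independence lets the joint data likelihood factorize into a product over observations:

  $$ p(x_1, \dots, x_N \mid \theta) = \prod_{n=1}^{N} \theta^{x_n} (1 - \theta)^{1 - x_n} $$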

13 of 23

Analytic MLE, Step 3: What’s our loss function?
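
Taking the negative log of the joint likelihood (for the Bernoulli sketch) turns the product into a sum and gives the loss we will minimize:

  $$ \mathcal{L}(\theta) = -\sum_{n=1}^{N} \big[ x_n \log \theta + (1 - x_n) \log(1 - \theta) \big] $$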

14 of 23

Analytic MLE, Step 4: Minimize our loss
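
For the Bernoulli sketch, setting the derivative of the loss to zero and solving for $\theta$ recovers the sample mean:

  $$ \frac{d\mathcal{L}}{d\theta} = -\sum_{n=1}^{N} \left( \frac{x_n}{\theta} - \frac{1 - x_n}{1 - \theta} \right) = 0 \quad\Rightarrow\quad \hat{\theta}_{\text{MLE}} = \frac{1}{N} \sum_{n=1}^{N} x_n $$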

15 of 23

Pros and Cons of Analytic MLE

Pros:

  • Fast: just compute a formula!

Cons:

  • Analytic solution doesn’t exist for most models
  • It’s a lot of work…

16 of 23

Alternative: Use Numeric Optimization Algorithm

  • Algorithm: Gradient Descent
  • Gradient = the derivative, generalized to higher dimensions
  • The gradient points in the direction of steepest ascent
  • Idea: take steps in the direction opposite the gradient, i.e., downhill (update rule below)
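
Written out, the update rule repeated until convergence is (with learning rate $\alpha > 0$):

  $$ \theta_{t+1} = \theta_t - \alpha \, \nabla_{\theta} \mathcal{L}(\theta_t) $$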

17 of 23

Gradient Descent

  • Repeat: 3 steps
  • How are derivatives computed? Automatically, by JAX (see the sketch below)
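
A minimal illustration of what "JAX computes the derivatives" means in code, assuming the nll function sketched earlier (names are illustrative):

    import jax

    # jax.grad transforms nll into a new function that returns d(nll)/d(theta)
    grad_nll = jax.grad(nll)
    g = grad_nll(0.3, x)  # gradient of the loss at theta = 0.3 for data x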

18 of 23

Simulations (see chapter)

19 of 23

Gradient Descent in JAX

Note: you should not implement or use this yourself; use the code we provided.
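
For reference only, here is a bare-bones sketch of what such a loop might look like; this is not the course-provided implementation:

    import jax

    def gradient_descent(loss_fn, theta_init, x, lr=0.01, num_steps=1000):
        grad_fn = jax.grad(loss_fn)  # derivative of the loss w.r.t. theta, via autodiff
        theta = theta_init
        for _ in range(num_steps):
            theta = theta - lr * grad_fn(theta, x)  # step downhill
        # note: nothing here keeps theta inside a valid range -- see the
        # "Constraining Parameters" slide
        return theta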

20 of 23

Challenges with Numeric Optimization

  • It’s iterative (slow)
  • It’s approximate (gets stuck in local optima)
  • It’s sensitive to initialization (gets stuck in local optima)
  • It adds diagnostic challenges:
    • When the model fits the data poorly, is the model the problem, or is the optimization just too hard?

21 of 23

Constraining Parameters to Valid Ranges

  • For our initial Bernoulli example, the parameter θ must lie in [0, 1]
  • Gradient descent might step outside these bounds
  • How can we fix this?
  • Idea:
    • Define θ = σ(ψ), where σ is the sigmoid function and ψ is unconstrained
    • Then, optimize with respect to ψ
  • That is, our original (constrained) problem is:

    $$ \hat{\theta} = \arg\min_{\theta \in [0, 1]} \mathcal{L}(\theta) $$

  • Instead, we solve the unconstrained problem and map the answer back:

    $$ \hat{\psi} = \arg\min_{\psi \in \mathbb{R}} \mathcal{L}(\sigma(\psi)), \qquad \hat{\theta} = \sigma(\hat{\psi}) $$

  • This is how NumPyro implements C.unit_interval (see the sketch below)
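
A sketch of this reparameterization in JAX, reusing the Bernoulli negative log likelihood from earlier (σ is the sigmoid, jax.nn.sigmoid; names are illustrative):

    import jax
    import numpyro.distributions as dist

    def nll_unconstrained(psi, x):
        theta = jax.nn.sigmoid(psi)  # theta = sigma(psi) always lies in (0, 1)
        return -dist.Bernoulli(probs=theta).log_prob(x).sum()

    # run gradient descent on psi (unconstrained), then map back:
    # theta_hat = jax.nn.sigmoid(psi_hat)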

22 of 23

In pairs: continue implementing the IHH-ER model in NumPyro!

23 of 23

That’s all for today!

Questions?

  • Tell your partner what you appreciated about working with them today!