1 of 10

Identifying the Global Minimum of High-Dimensional Non-Convex Functions

2 of 10

Motivation: Structural Biology

  • For many open questions in structural biology (e.g. DNA-protein and DNA-RNA complexes), the number of known structures is very low
    • Hard to train an accurate supervised model (e.g. AlphaFold)
  • However, we have accurate energy functions (e.g. Amber Force Field)
    • If we could effectively minimize these functions, we could find the structure

3 of 10

Can NNs identify the global minimum? No.

  • Before we can “extract” the global minimum from an NN, the NN needs to correctly identify it
  • Experiment: trained a transformer to learn a 100-D Lennard-Jones energy function from 10,000 data points; it could not correctly identify the global minimum (the target function is sketched after this list)
    • Neural tangent kernel (NTK) theory suggests that neural networks do a poor job of assigning y-values outside the training range
  • So an NN cannot identify the global minimum through the y-value alone
  • Also tested an autoencoder composed with an input-convex neural network (ICNN) → it didn’t work well
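A minimal sketch of the kind of target function involved, assuming the input vector is a flattened set of particle coordinates interacting through the standard Lennard-Jones pair potential; the particle count, dimensionality, and parameters ε and σ below are illustrative assumptions, not the experiment’s actual setup.

```python
import numpy as np

def lennard_jones_energy(x, n_particles, dim, epsilon=1.0, sigma=1.0):
    """Total Lennard-Jones energy of a configuration given as a flat coordinate vector."""
    pos = x.reshape(n_particles, dim)
    energy = 0.0
    for i in range(n_particles):
        for j in range(i + 1, n_particles):
            r = np.linalg.norm(pos[i] - pos[j])
            sr6 = (sigma / r) ** 6
            energy += 4.0 * epsilon * (sr6 ** 2 - sr6)   # 4ε[(σ/r)^12 − (σ/r)^6]
    return energy

# Example: one possible reading of a 100-D input — 50 particles in 2-D.
x = np.random.rand(100) * 5.0
print(lennard_jones_energy(x, n_particles=50, dim=2))
```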

4 of 10

Protein Trajectory

Let’s model the protein’s trajectory with an SDE

  • Assumptions
    • Inertia is negligible
      • Masses of all particles are the same (and can be ignored)
    • Noise is independent of state
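Under these assumptions (negligible inertia, state-independent noise), a standard choice is the overdamped Langevin SDE sketched below, where f is the energy function and W_t is standard Brownian motion; the specific noise scale (written with an inverse temperature β) is an assumption for concreteness, not from the slide.

$$ dx_t = -\nabla f(x_t)\,dt + \sqrt{2\beta^{-1}}\,dW_t $$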

5 of 10

How can we identify the global minimum?

  • Consider the probability of reaching a final state x_f from an initial state x_i, where S[x(s)] is the action (the MSE deviation of the path x(s) from gradient descent); this expression is sketched after the list
  • For every x_i, the expression is maximized when x_f = the global minimum
    • This is the final structure
  • The integral is dominated by probable paths
    • Approximation: pick the most probable path
  • Thus, we want to find an x_f such that there is a “probable path” from every x_i to this x_f
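A sketch of the expression these bullets refer to, assuming the standard path-integral (Onsager–Machlup) form for the overdamped SDE of the previous slide; the prefactor and the normalization are assumptions.

$$ P(x_f \mid x_i) \;\propto\; \int_{x(0)=x_i}^{x(T)=x_f} \mathcal{D}[x(s)]\, e^{-S[x(s)]}, \qquad S[x(s)] = \frac{\beta}{4} \int_0^T \bigl\| \dot{x}(s) + \nabla f\bigl(x(s)\bigr) \bigr\|^2 \, ds $$

The action accumulates the squared deviation of the path’s velocity from the gradient-descent flow dx/ds = -∇f(x), matching the “MSE deviation from gradient descent” description.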

6 of 10

Forcing Gradient Descent Trajectories to Converge: Options

  • Converging in the real (input) space (e.g. via a contractive function) is too restrictive
  • Need to converge in the latent space → after one step of gradient descent on f(x), the latent representation gets closer to the final point
  • In the latent space
    • Fix final point
      • Regularize direction
        • NN disregards the loss
      • Fix direction, let NN learn magnitude
        • Too restrictive, NN can’t learn
    • Don’t fix final point
      • Contractive Neural Networks (Lipschitz < 1)
        • Encoder and Decoder are not in sync
        • Next architecture: move toward removing the encoder and finding the initial latent vector via gradient descent

7 of 10

Forcing Gradient Descent Trajectories to Converge: Try 1

Consistency between decoder and encoder for latent space point T(L)

[Diagram: x → Encoder → L → Decoder → x (reconstruction); L → T → T(L) → Decoder → x - f’(x); the Encoder applied to D(T(L)) should return T(L)]

After every step of gradient descent, the point in the latent space needs to converge. We apply a contractive function, T, with Lipschitz(T) < 1.


Loss = MSE(x, D(E(x))) + MSE(x - f’(x), D(T(E(x)))) + MSE(E(D(T(L))), T(L))

  • The Lipschitz function T was too contractive → the decoder learned a different function for L than for T(L)
  • Adding a regularizer could not fix this
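A minimal PyTorch sketch of this Try 1 setup, assuming MLP encoder E and decoder D modules and a linear map T whose spectral norm is rescaled below 1; the architecture and the rescaling scheme are assumptions, only the loss terms come from the slide.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ContractiveMap(nn.Module):
    """Linear map T with its spectral norm rescaled below 1, so Lipschitz(T) < 1."""
    def __init__(self, dim, max_lip=0.95):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(dim, dim) / dim ** 0.5)
        self.bias = nn.Parameter(torch.zeros(dim))
        self.max_lip = max_lip

    def forward(self, z):
        sigma = torch.linalg.matrix_norm(self.weight, ord=2)   # spectral norm of the weight
        w = self.weight * (self.max_lip / torch.clamp(sigma, min=self.max_lip))
        return z @ w.T + self.bias

def try1_loss(E, D, T, x, grad_f):
    """Loss = MSE(x, D(E(x))) + MSE(x - f'(x), D(T(E(x)))) + MSE(E(D(T(L))), T(L))."""
    L = E(x)
    TL = T(L)
    x_next = x - grad_f(x)                    # one step of gradient descent on f
    recon = F.mse_loss(D(L), x)               # decoder inverts encoder at L
    step = F.mse_loss(D(TL), x_next)          # one contraction in latent space = one GD step
    consistency = F.mse_loss(E(D(TL)), TL)    # keep encoder and decoder in sync at T(L)
    return recon + step + consistency
```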

8 of 10

Forcing Gradient Descent Trajectories to Converge: Try 2

Loss = MSE(D(E(x)), x) + MSE(D(E(x - f’(x))), x - f’(x)) - α

[Diagram: x → Encoder → L1 → Decoder → x; x - f’(x) → Encoder → L2 → Decoder → x - f’(x); both latent vectors are compared against the fixed central point P]

After every step of gradient descent, the latent space vector needs to be closer to a predefined “central” point P.

L1 = E(x)

L2 = E(x-f’(x))

  • Problem → L2 was not getting closer to P
  • Also tried regularizing the norm (i.e. setting P to the zero vector)
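A minimal sketch of the Try 2 loss; the slide does not show the exact α term, so the hinge penalty below (activating when L2 is not closer to P than L1) is an assumption for illustration, as are the module names.

```python
import torch
import torch.nn.functional as F

def try2_loss(E, D, x, grad_f, P, alpha=0.1):
    """Reconstruction at x and at x - f'(x), plus a pull of the latent vector toward P."""
    x_next = x - grad_f(x)                 # one step of gradient descent on f
    L1, L2 = E(x), E(x_next)
    recon = F.mse_loss(D(L1), x) + F.mse_loss(D(L2), x_next)
    d1 = (L1 - P).norm(dim=-1)             # distance of E(x) from the central point
    d2 = (L2 - P).norm(dim=-1)             # distance of E(x - f'(x)) from the central point
    pull = F.relu(d2 - d1).mean()          # penalize L2 failing to move closer to P (assumed form)
    return recon + alpha * pull
```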

9 of 10

Forcing Gradient Descent Trajectories to Converge: Try 3

[Diagram: x → Encoder → L1 → Decoder → x; x - f’(x) → Encoder → L2 → Decoder → x - f’(x); L2 is constrained to lie on the interpolation path from L1 toward the central point P]

After every step of gradient descent, the latent space vector needs to be closer to a predefined “central” point P.

L1 = E(x)

L2 = E(x-f’(x))

  • L2 = slerp(L1, P, t) (sketched below)
  • t is learned by an NN
  • The gradient w.r.t. t was ≈ 0 → little alignment between the loss gradient at L2 and the direction from L1 to P
  • Given a point in the latent space, the trajectory is hardcoded
  • distance(L1, P) is identical across training points → the latent space doesn’t generalize
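A minimal sketch of the Try 3 construction, assuming a small network predicts the interpolation fraction t from L1 and the loss ties E(x - f’(x)) and the decoded slerp point back to x - f’(x); the exact combination of terms is an assumption, only L2 = slerp(L1, P, t) with a learned t comes from the slide.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def slerp(a, b, t):
    """Spherical linear interpolation from a toward b by fraction t in [0, 1]."""
    a_n = a / a.norm(dim=-1, keepdim=True)
    b_n = b / b.norm(dim=-1, keepdim=True)
    omega = torch.acos((a_n * b_n).sum(-1, keepdim=True).clamp(-1 + 1e-7, 1 - 1e-7))
    return (torch.sin((1 - t) * omega) * a + torch.sin(t * omega) * b) / torch.sin(omega)

class StepFraction(nn.Module):
    """Small network predicting the slerp fraction t from the current latent vector."""
    def __init__(self, latent_dim, hidden=64):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(latent_dim, hidden), nn.ReLU(),
                                 nn.Linear(hidden, 1), nn.Sigmoid())   # t in (0, 1)

    def forward(self, L1):
        return self.net(L1)

def try3_loss(E, D, step_net, x, grad_f, P):
    """L1 = E(x); the latent point for x - f'(x) is constrained onto the slerp path toward P."""
    x_next = x - grad_f(x)
    L1 = E(x)
    t = step_net(L1)                       # learned interpolation fraction
    L2 = slerp(L1, P, t)                   # hardcoded trajectory in latent space
    return (F.mse_loss(D(L1), x)
            + F.mse_loss(D(L2), x_next)
            + F.mse_loss(E(x_next), L2))   # tie the encoder of x - f'(x) to the slerp point
```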

10 of 10

Broader Thoughts

  • The simplest approach would be to
    • (1) train an NN to mimic the energy function
    • and then (2) “project” this function onto the space of convex functions (for an input-convex neural network, the relevant weights just need to be non-negative; see the sketch after this list)
    • While training, one can minimize the L2 distance between the two networks’ weights
    • But distance in weight space has a complex relationship with distance in function space
  • Current approach → the mapping from the convex function to the energy function is an NN, instead of the identity
    • NNs are very expressive, so a lot of regularization is needed
  • Alternative approach → collect random local minima and train an NN to produce better local minima
    • Problem: the number of local minima increases exponentially with dimension, so the range of variation among sampled local minima (the signal for the NN) decreases with dimension
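A minimal sketch of the input-convex constraint mentioned above, assuming a simple fully connected ICNN where the weights acting on the hidden path are clamped to be non-negative at forward time; layer sizes and the clamping scheme are illustrative assumptions.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SimpleICNN(nn.Module):
    """f(x) is convex in x: the z-path weights are non-negative and ReLU is convex and non-decreasing."""
    def __init__(self, in_dim, hidden=64, depth=3):
        super().__init__()
        self.x_layers = nn.ModuleList([nn.Linear(in_dim, hidden) for _ in range(depth)])
        self.z_layers = nn.ModuleList([nn.Linear(hidden, hidden, bias=False) for _ in range(depth - 1)])
        self.out = nn.Linear(hidden, 1, bias=False)

    def forward(self, x):
        z = F.relu(self.x_layers[0](x))
        for x_lin, z_lin in zip(self.x_layers[1:], self.z_layers):
            w = z_lin.weight.clamp(min=0.0)          # non-negative weights preserve convexity in x
            z = F.relu(F.linear(z, w) + x_lin(x))
        return F.linear(z, self.out.weight.clamp(min=0.0))
```

Clamping at forward time is one simple way to keep the network in the convex family; projecting the weights after each optimizer step would be another.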