1 of 30

Bridging Machine Learning and Scientific Computing

mike@juliacomputing.com

2 of 30

A powerful, high-level language with high performance.

Pythonic, mathematical syntax that looks like notation.

Performance consistently within 2x of tuned C code.

Most of Julia is written in Julia!

function mandel(z)
    c = z
    maxiter = 80
    for n = 1:maxiter
        if abs(z) > 2
            return n-1
        end
        z = z^2 + c
    end
    return maxiter
end
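As a quick usage sketch, the function can be evaluated over a coarse grid of the complex plane. Points that reach `maxiter` without escaping are treated as inside the set. The `ascii_mandel` helper, the grid bounds and the `#`/`.` rendering are our own illustrative choices and are not part of the original slide.

```julia
# Escape-time count for the complex point z (same function as above).
function mandel(z)
    c = z
    maxiter = 80
    for n = 1:maxiter
        if abs(z) > 2
            return n - 1
        end
        z = z^2 + c
    end
    return maxiter
end

# Hypothetical helper: '#' marks grid points that never escape within 80 iterations.
function ascii_mandel(; rows = 12, cols = 40)
    lines = String[]
    for y in range(1.2, -1.2; length = rows)
        row = [mandel(complex(x, y)) == 80 ? '#' : '.' for x in range(-2.0, 0.5; length = cols)]
        push!(lines, join(row))
    end
    return join(lines, '\n')
end

println(ascii_mandel())
```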

4 of 30

Scientific Computing

5 of 30

Machine Learning

MLJ

6 of 30

Machine Learning

  • High-level and flexible (Python)
  • High overhead; focus on tensor operations and manual vectorisation
  • Relatively simple programs (network architectures)
  • Mutation support considered advanced/unusual

Scientific Computing

  • Low-level and manual (Fortran)
  • Low overhead; focus on scalar operations
  • Regularly run over programs of millions of lines of code
  • Research on auto-vectorisation, shared-memory parallelism, checkpointing, etc.

8 of 30

Languages that are both high level and high performance.

The key:

9 of 30

Tapenade

Ruthless pragmatism and scalability. Output can be highly optimised using existing optimising compilers.

λ the Ultimate Backpropagator

Elegant recursive formalism, including nested AD (closure), convenience (callee-derives) and bags of expressive power.

11 of 30

User Function:

function pow(x, n)
    r = 1
    while n > 0
        n -= 1
        r *= x
    end
    return r
end

Primal:

1: (%2, %3)
  br 2 (%3, 1)
2: (%4, %5)
  %6 = %4 > 0
  br 4 unless %6
  br 3
3:
  %7 = %4 - 1
  %8 = %5 * %2
  br 2 (%7, %8)
4:
  return %5

Adjoint:

1: (%1)
  br 2 (%1, 0)
2: (%2, %4)
  br 4 unless @6
  br 3
3:
  %10 = %2 * @2
  %11 = %2 * @5
  %14 = %4 + %11
  br 2 (%10, %14)
4:
  return (%4, 0)

pow(5, 3) == 125
gradient(pow, 5, 3) == (75, 0)
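The adjoint IR above can be mimicked by hand. The following sketch records each loop-carried value of `r` on the forward pass, which plays the role of the `@`-references to saved primal values, and then walks the loop backwards to reproduce `gradient(pow, 5, 3) == (75, 0)`. The name `pow_pullback` and the explicit `stack` are hypothetical; this is an illustration of the idea, not the code Zygote actually generates.

```julia
# Hand-written reverse-mode version of pow, mirroring the primal/adjoint IR above.
function pow_pullback(x, n)
    r = one(x)
    stack = typeof(r)[]            # primal values of r that the adjoint will need
    while n > 0
        n -= 1
        push!(stack, r)            # save r before `r *= x` overwrites it
        r *= x
    end
    # Pullback: given the output gradient rbar, return (gradient of x, gradient of n).
    # It consumes the stack, so it should be called only once.
    function back(rbar)
        xbar = zero(rbar)
        while !isempty(stack)
            rprev = pop!(stack)
            xbar += rbar * rprev   # d(rprev * x)/dx = rprev   (cf. %11 and %14)
            rbar = rbar * x        # d(rprev * x)/drprev = x   (cf. %10)
        end
        return (xbar, 0)           # n only counts iterations, so its gradient is 0
    end
    return r, back
end

y, back = pow_pullback(5, 3)   # y == 125
back(1)                        # (75, 0)
```

Storing the intermediate values of `r` is what lets the backward loop run without recomputing the forward pass; checkpointing trades that memory for recomputation.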

13 of 30

function foo(x)
    a = bar(x)
    b = baz(a)
    return b
end

function J(::typeof(foo), x)
    a, da = J(bar, x)
    b, db = J(baz, a)
    return b, function(b̄)
        ā = db(b̄)
        x̄ = da(ā)
        return x̄
    end
end
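To make the transform concrete, the sketch below fills in `bar` and `baz` with arbitrary choices (`bar(x) = sin(x)` and `baz(a) = 2 * a` are assumptions for illustration only) and writes their `J` methods by hand. It follows the simplified convention of this slide, where each pullback maps an output gradient straight to an input gradient rather than returning a tuple.

```julia
# Hypothetical concrete definitions, chosen only for illustration.
bar(x) = sin(x)
baz(a) = 2 * a

# Hand-written J methods: each returns the primal value and a pullback closure.
J(::typeof(bar), x) = (bar(x), abar -> abar * cos(x))
J(::typeof(baz), a) = (baz(a), bbar -> 2 * bbar)

function foo(x)
    a = bar(x)
    b = baz(a)
    return b
end

# J for foo, structured like the code the source transform would generate.
function J(::typeof(foo), x)
    a, da = J(bar, x)
    b, db = J(baz, a)
    return b, function (bbar)
        abar = db(bbar)   # pull the gradient back through baz
        xbar = da(abar)   # then through bar
        return xbar
    end
end

y, back = J(foo, 0.5)     # y == 2 * sin(0.5)
back(1.0)                 # equals 2 * cos(0.5), the derivative of foo at 0.5
```

Because each `J` call hands back a closure, the chain rule falls out of ordinary function composition: the pullbacks are simply applied in reverse order.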

15 of 30

J(::typeof(sin), x) = sin(x), ȳ -> ȳ*cos(x)

@adjoint sin(x) = sin(x), ȳ -> (ȳ*cos(x),)

Core compiler pass is ~200 lines of code

All semantics added via custom adjoints: mutation, data structures, checkpointing, etc.
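A custom adjoint is just a primal value paired with a pullback. The sketch below writes two such rules as plain Julia functions, with each pullback returning one gradient per argument as a tuple, and checks the `sin` rule against a central finite difference. `J_sin`, `J_mul` and `central_diff` are hypothetical names, not Zygote API.

```julia
# Hypothetical rules in (value, pullback) form; each pullback returns a tuple
# with one gradient per argument of the primal function.
J_sin(x) = (sin(x), ybar -> (ybar * cos(x),))
J_mul(a, b) = (a * b, cbar -> (cbar * b, cbar * a))

# Central finite difference, used only as an independent numerical check.
central_diff(f, x; h = 1e-6) = (f(x + h) - f(x - h)) / (2 * h)

y, back = J_sin(0.3)
(xbar,) = back(1.0)
isapprox(xbar, central_diff(sin, 0.3); atol = 1e-8)   # true

_, mulback = J_mul(3.0, 4.0)
mulback(1.0)                                          # (4.0, 3.0)
```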

16 of 30

nestlevel() = 0

@adjoint nestlevel() = nestlevel()+1, _ -> nothing

julia> function f(x)
           println(nestlevel(), " levels of nesting")
           return x
       end

julia> f(1);
0 levels of nesting

julia> grad(f, 1);
1 levels of nesting

julia> grad(x -> x*grad(f, x), 1);
2 levels of nesting

17 of 30

@adjoint checkpoint(f, x...) =
    f(x...), Δ -> J(f, x...)[2](Δ)

@adjoint hook(f, x) = x, Δ -> (nothing, f(Δ))

hook(-, x) # reverse the gradient of x

@adjoint function forwarddiff(f, x)
    y, J = forward_jacobian(f, x)
    y, Δ -> (nothing, J'Δ)
end
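The checkpoint and hook adjoints can be imitated outside Zygote. This sketch reuses the earlier convention that `J(f, x)` returns a value and a pullback; `J_checkpoint` and `J_hook` are hypothetical names. The checkpointed version keeps nothing from `f`'s forward pass except the input and re-runs `J(f, x)` only when the gradient is requested, trading recomputation for memory. The hook leaves the value untouched and transforms the incoming gradient.

```julia
# J-style rule for sin (simplified convention: the pullback returns the gradient directly).
J(::typeof(sin), x) = (sin(x), ybar -> ybar * cos(x))

# Checkpointing: the forward pass keeps only the input x; the pullback
# re-runs J(f, x) to rebuild f's pullback when the gradient is needed.
function J_checkpoint(f, x)
    y = f(x)
    back(ybar) = J(f, x)[2](ybar)
    return y, back
end

# Gradient hook: identity on the way forward, applies g to the gradient on the way back.
J_hook(g, x) = (x, xbar -> g(xbar))

y, back = J_checkpoint(sin, 0.5)
back(1.0) == cos(0.5)          # true: same gradient as the non-checkpointed rule
_, hookback = J_hook(-, 2.0)
hookback(1.0)                  # -1.0: the gradient is reversed
```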

19 of 30

Differentiation à la Carte

  • Mixed-mode AD (forward, reverse, Taylor series, …)
  • Forward-over-reverse (Hessians)
  • Cross-language AD
  • Support for Complex and other number types
  • Easy custom gradients
  • Checkpointing
  • Gradient hooks
  • Custom types (colours!)
  • Hardware backends: CPU, CUDA, TPU, …
  • Deeply nested AD (WIP)
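Forward mode, the first entry in this list, can be illustrated with dual numbers. The sketch below is a generic textbook construction, not Zygote's or ForwardDiff.jl's implementation. `Dual` and `derivative` are hypothetical names, and `pow` is the earlier example modified to start from `one(x)` instead of the integer `1` so that it works with the new type.

```julia
# A minimal dual number: `val` is the primal value, `der` the tangent (derivative).
struct Dual
    val::Float64
    der::Float64
end

# Propagation rules: product rule for *, chain rule for sin.
Base.:*(a::Dual, b::Dual) = Dual(a.val * b.val, a.der * b.val + a.val * b.der)
Base.sin(a::Dual) = Dual(sin(a.val), cos(a.val) * a.der)
Base.one(::Dual) = Dual(1.0, 0.0)

# Seed the input with tangent 1 and read off the output tangent.
derivative(f, x) = f(Dual(x, 1.0)).der

# The pow example from earlier, starting from one(x) instead of the integer 1.
function pow(x, n)
    r = one(x)
    while n > 0
        n -= 1
        r *= x
    end
    return r
end

derivative(x -> pow(x, 3), 5.0)   # 75.0
derivative(sin, 0.0)              # 1.0
```

Forward mode carries one tangent alongside each value, so it needs no stored trace; reverse mode instead gets all input gradients from a single backward pass, which is why mixing the two is useful.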

20 of 30

Data Structures & Mutation

22 of 30

Deep learning in 5 lines.

24 of 30

Some Bonus Features

28 of 30

@adjoint function pycall(f, x...; kw...)
    x = map(py, x)
    y = pycall(f, x...; kw...)
    y.detach().numpy(), function (ȳ)
        y.backward(gradient = py(ȳ))
        (nothing, map(x -> x.grad.numpy(), x)...)
    end
end

29 of 30

Future Challenges

  • Mutation of values is hard
  • Need adjoints to cover the entire standard library
  • Compiler improvements
    • More functional-style optimisations
    • Better heuristics for AD-generated code
  • Fast code vs. dynamic semantics
  • Differentiating Julia’s concurrency and parallelism constructs
  • Reducing overheads: currently ~50ns per operation
    • Great compared to ML frameworks but far from optimal

30 of 30

Unifying Machine Learning and Scientific Computing

mike@juliacomputing.com