1 of 35

Growing a Compiler

Getting to machine learning from a general purpose compiler

CGO C4ML 2019

Keno Fischer and Jameson Nash

With Tim Besard, James Bradbury, Valentin Churavy, Dhairya Gandhi, Mike Innes, Neethu Joy, Tejan Karmali, Matt Kelley, Avik Pal, Chris Rackauckas, Marco Rudilosso, Elliot Saba, Viral Shah, and Deniz Yuret

2 of 35

Neural ODEs

Combining ML and Differential Equation libraries

DiffEqFlux.jl — A Julia Library for Neural Differential Equations (arXiv:1902.02376)

3 of 35

Composability of Libraries

  • Numbers (Integer, Floating Point, Complex, Quaternion, Rational, Fixed Point, non-standard precision, Intervals, Unitful, Symbolic, Dual, Grassmannian, Log-Domain, Polynomial, ...)
  • Arrays (Matrices, Tensors, Offset, Named, Remote, ...)
  • Linear Algebra (Factorizations, Arithmetic, Infinite, Finite Fields, ...)
  • Differential Equations
  • Convex Optimization
  • Parallel Computing
  • Data Science
  • Graphics
  • Machine Learning
  • Image Processing Libraries

4 of 35

Composability of Libraries

(Same list as on the previous slide.)

Solved

5 of 35

Composability of Compiler Transforms?

Transforms / Programming Models:

  • Automatic Differentiation
  • SPMD
  • Task Parallelism
  • Data Query
  • Interval Constraint Programming
  • Disciplined Convex Programming
  • Tracing/Profiling/Debugging
  • Verification/Formal Methods
  • ...

Optimizations:

  • Scalar
  • Tensor
  • Parallel/Distributed
  • Data Layout
  • Data Access (Polyhedral)
  • Relational (Query Compilers)
  • Symbolic
  • ...

Code Generators:

  • CPU
  • GPU
  • TPU / ML Accelerators
  • Virtual Machines (WebAssembly)
  • FPGA / Silicon
  • Quantum Computers
  • Homomorphic Encryption
  • ...

6 of 35

Composability of Compiler Transforms?

(Same three columns as on the previous slide.)

Not Solved

7 of 35

Inference Deep Dive

Extracting Information From Dynamic Programs

8 of 35

Extensive Standard Library

julia> @which sin(1.0)
sin(x::T) where T<:Union{Float32, Float64} in Base.Math at special/trig.jl:30

julia> using Base.Math: @horner

julia> begin
           # Coefficients in 13th-order polynomial approximation on [0; π/4]
           #   sin(x) ≈ x + S1*x³ + S2*x⁵ + S3*x⁷ + S4*x⁹ + S5*x¹¹ + S6*x¹³
           # D for double, S for sin, number is the order of x-1
           const DS1 = -1.66666666666666324348e-01
           const DS2 = 8.33333333332248946124e-03
           const DS3 = -1.98412698298579493134e-04
           const DS4 = 2.75573137070700676789e-06
           const DS5 = -2.50507602534068634195e-08
           const DS6 = 1.58969099521155010221e-10

           function sin_kernel_f64(y)
               y² = y*y
               y⁴ = y²*y²
               r = @horner(y², DS2, DS3, DS4) + y²*y⁴*@horner(y², DS5, DS6)
               y³ = y²*y
               return y + y³*(DS1 + y²*r)
           end
       end;

Avoiding “boilerplate” syntax: no types, interfaces, or lifetime constraints are mentioned. The polynomial evaluation (Horner’s method) is also specified concisely and efficiently with a macro.
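Conceptually (a simplified sketch of what Base.Math.@horner generates, not its exact implementation), the macro unrolls the polynomial into a nest of muladd calls at macro-expansion time, so no loop or coefficient array survives into the kernel:

# @horner(y², DS2, DS3, DS4) behaves like:
muladd(y², muladd(y², DS4, DS3), DS2)  # == DS2 + y²*(DS3 + y²*DS4)

This is exactly the pair of Base.Math.muladd calls visible in the lowered code two slides later.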

9 of 35

Dynamic semantics + Static Analysis:

Just-Ahead-of-Time analysis to extract static properties of code
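For example (a minimal sketch, assuming the sin_kernel_f64 definition above has been entered), one can ask the compiler directly what it was able to infer, without ever running the code:

# Ask inference for the return type, given a Float64 argument:
Base.return_types(sin_kernel_f64, (Float64,))  # returns [Float64]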

10 of 35

julia> methods(Base.Math.muladd)
# 12 methods for generic function "muladd":
[1] muladd(a::Float16, b::Float16, c::Float16) in Base at float.jl:406
[2] muladd(x::Float64, y::Float64, z::Float64) in Base at float.jl:404
[3] muladd(x::Float32, y::Float32, z::Float32) in Base at float.jl:403
...
[10] muladd(x::T, y::T, z::T) where T<:Number in Base at promotion.jl:397
[11] muladd(x::Number, y::Number, z::Number) in Base at promotion.jl:348
[12] muladd(x, y, z) in Base.Math at math.jl:1011

julia> methods(+)
# 161 methods for generic function "+":
[1] +(x::Bool, y::Bool) in Base at bool.jl:96
[2] +(a::Float16, b::Float16) in Base at float.jl:392
[3] +(x::Float32, y::Float32) in Base at float.jl:394
[4] +(x::Float64, y::Float64) in Base at float.jl:395
[5] +(c::BigInt, x::BigFloat) in Base.MPFR at mpfr.jl:408
[6] +(a::BigInt, b::BigInt, c::BigInt, d::BigInt, e::BigInt) in Base.GMP at gmp.jl:436
...

julia> code_lowered(sin_kernel_f64)
1 ─ nothing
@ REPL[29]:12 within `sin_kernel_f64'
    y² = y * y
@ REPL[29]:13 within `sin_kernel_f64'
    y⁴ = y² * y²
@ REPL[29]:14 within `sin_kernel_f64'
    t@_3 = y²
┌ @ base/math.jl:101 within `@horner'
    %5 = t@_3
    %6 = Base.Math.muladd(t@_3, Main.DS4, Main.DS3)
    r@_4 = Base.Math.muladd(%5, %6, Main.DS2)
    %8 = r@_4
    %9 = y²
    %10 = y⁴
    t@_5 = y²
┌ @ base/math.jl:101 within `@horner'
    r@_6 = Base.Math.muladd(t@_5, Main.DS6, Main.DS5)
    %13 = r@_6
    %14 = %9 * %10 * %13
    r@_9 = %8 + %14
@ REPL[29]:15 within `sin_kernel_f64'
    y³ = y² * y
@ REPL[29]:16 within `sin_kernel_f64'
    %17 = y³
    %18 = y² * r@_9
    %19 = Main.DS1 + %18
    %20 = %17 * %19
    %21 = y + %20
└── return %21

Everything is a virtual function call.

Oh no!

11 of 35

julia> code_lowered(sin_kernel_f64)
(same output as on the previous slide)

julia> @code_warntype sin_kernel_f64(0.3)
Variables
  #self#::Core.Compiler.Const(sin_kernel_f64)
  y::Float64
  t@_3::Float64
  r@_4::Float64
  t@_5::Float64
  r@_6::Float64
  y²::Float64
  y⁴::Float64
  r@_9::Float64
  y³::Float64

Body::Float64
1 ─ nothing
    (y² = y * y)::Float64
    (y⁴ = y² * y²)::Float64
    (t@_3 = y²)::Float64
    %5 = t@_3::Float64
    %6 = Base.Math.muladd(t@_3, Main.DS4, Main.DS3)::Float64
    (r@_4 = Base.Math.muladd(%5, %6, Main.DS2))::Float64
    %8 = r@_4::Float64
    %9 = y²::Float64
    %10 = y⁴::Float64
    (t@_5 = y²)::Float64
    (r@_6 = Base.Math.muladd(t@_5, Main.DS6, Main.DS5))::Float64
    %13 = r@_6::Float64
    %14 = (%9 * %10 * %13)::Float64
    (r@_9 = %8 + %14)::Float64
    (y³ = y² * y)::Float64
    %17 = y³::Float64
    %18 = (y² * r@_9)::Float64
    %19 = (Main.DS1 + %18)::Float64
    %20 = (%17 * %19)::Float64
    %21 = (y + %20)::Float64
└── return %21::Float64

But the compiler can devirtualize and simplify it.

And simple is fast.

12 of 35

Simple Auto-Differentiation: Forward Derivatives

julia> struct Dual{T<:Real} <: Real
           x::T   # the value
           dx::T  # the derivative at that point
       end;

# Some convenience constructors describing the derivative of a constant
julia> Dual(x) = Dual(x, zero(x));

julia> Dual{T}(x) where {T} = Dual{T}(x, zero(x));

# Compose with existing types of numbers
julia> Base.promote_rule(::Type{Dual{T}}, ::Type{S}) where {T, S} =
           Dual{Base.promote_type(T, S)};

# Define the base cases for arithmetic
julia> Base.:+(a::Dual{T}, b::Dual{T}) where {T} =
           Dual{T}(a.x + b.x, a.dx + b.dx);

julia> Base.:*(a::Dual{T}, b::Dual{T}) where {T} =
           Dual{T}(a.x * b.x, a.x*b.dx + a.dx*b.x);

Small abstraction to define a new subtype of real numbers
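A quick sanity check of the promotion glue (a hedged sketch, assuming the definitions above have been entered): mixing a Dual with an ordinary number routes through Base's generic fallback +(x::Number, y::Number) = +(promote(x, y)...), so constants need no special handling:

Dual(1.0, 2.0) + 1    # == Dual{Float64}(2.0, 2.0); the Int promotes to Dual(1.0, 0.0)
Dual(1.0, 2.0) * 3.0  # == Dual{Float64}(3.0, 6.0); product rule on the dx field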

13 of 35

Simple Auto-Differentiation: Forward Derivatives

julia> wrt(x) = Dual(x, typeof(x)(1))  # helper for taking the derivative “with respect to” this parameter

julia> sin_kernel_f64(wrt(.3)) |> dump  # first derivative
Dual{Float64}
  x: Float64 0.29552020666133955
  dx: Float64 0.955336489125606

julia> sin_kernel_f64(wrt(wrt(.3))) |> dump  # first and second derivatives
Dual{Dual{Float64}}
  x: Dual{Float64}
    x: Float64 0.29552020666133955
    dx: Float64 0.955336489125606
  dx: Dual{Float64}
    x: Float64 0.955336489125606
    dx: Float64 -0.2955202066613394

julia> sincos(0.3)  # ground truth
(0.29552020666133955, 0.955336489125606)

And it works!
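A convenience wrapper (our addition, not on the slide) makes the pattern reusable; its results match the dx fields in the dump above:

derivative(f, x) = f(wrt(x)).dx

derivative(sin_kernel_f64, 0.3)                      # 0.955336489125606 (= cos(0.3), first derivative)
derivative(y -> derivative(sin_kernel_f64, y), 0.3)  # -0.2955202066613394 (second derivative)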

14 of 35

julia> @code_warntype sin_kernel_f64(wrt(0.3))
Body::Dual{Float64}
1 ─ nothing
    (y² = y * y)::Dual{Float64}
    (y⁴ = y² * y²)::Dual{Float64}
    (t@_3 = y²)::Dual{Float64}
    %5 = t@_3::Dual{Float64}
    %6 = Base.Math.muladd(t@_3, Main.DS4, Main.DS3)::Dual{Float64}
    %7 = Base.Math.muladd(%5, %6, Main.DS2)::Dual{Float64}
    %8 = y²::Dual{Float64}
    %9 = y⁴::Dual{Float64}
    (t@_4 = y²)::Dual{Float64}
    %11 = Base.Math.muladd(t@_4, Main.DS6, Main.DS5)::Dual{Float64}
    %12 = (%8 * %9 * %11)::Dual{Float64}
    (r = %7 + %12)::Dual{Float64}
    (y³ = y² * y)::Dual{Float64}
    %15 = y³::Dual{Float64}
    %16 = (y² * r)::Dual{Float64}
    %17 = (Main.DS1 + %16)::Dual{Float64}
    %18 = (%15 * %17)::Dual{Float64}
    %19 = (y + %18)::Dual{Float64}
└── return %19::Dual{Float64}

Custom user types are no different from “builtin” ones:
it’s still fully inferred!

15 of 35

julia> @code_native sin_kernel_f64(wrt(wrt(.3)))
        .text
        vmovupd (%rsi), %ymm21
        vmulsd %xmm21, %xmm21, %xmm1
        vpermilpd $1, %xmm21, %xmm9 # xmm9 = xmm21[1,0]
        vmulsd %xmm9, %xmm21, %xmm2
        vaddsd %xmm2, %xmm2, %xmm2
        vextractf32x4 $1, %ymm21, %xmm11
        vmulsd %xmm11, %xmm21, %xmm3
        vpermilpd $1, %xmm11, %xmm8 # xmm8 = xmm11[1,0]
        vmulsd %xmm8, %xmm21, %xmm4
        vmulsd %xmm11, %xmm9, %xmm5
        vaddsd %xmm5, %xmm4, %xmm4
        vaddsd %xmm3, %xmm3, %xmm3
        vaddsd %xmm4, %xmm4, %xmm16
        vmulsd %xmm1, %xmm1, %xmm12
        vmulsd %xmm2, %xmm1, %xmm6
        vaddsd %xmm6, %xmm6, %xmm10
        vmulsd %xmm3, %xmm1, %xmm6
        vmulsd %xmm16, %xmm1, %xmm7
        vmulsd %xmm3, %xmm2, %xmm5
        vaddsd %xmm7, %xmm5, %xmm5
        vaddsd %xmm6, %xmm6, %xmm14
        vaddsd %xmm5, %xmm5, %xmm13
        movabsq $140702207837800, %rax # imm = 0x7FF7C91E0668
        vmovsd (%rax), %xmm5 # xmm5 = mem[0],zero
        vmulsd %xmm5, %xmm1, %xmm15
        vxorpd %xmm20, %xmm20, %xmm20
        vmulsd %xmm20, %xmm1, %xmm7
        vmulsd %xmm5, %xmm2, %xmm4
        vaddsd %xmm4, %xmm7, %xmm17
        vmulsd %xmm20, %xmm2, %xmm4
        vaddsd %xmm4, %xmm7, %xmm18
        vmulsd %xmm5, %xmm3, %xmm4
        vmulsd %xmm20, %xmm3, %xmm19
        vmulsd %xmm5, %xmm16, %xmm5
        vaddsd %xmm5, %xmm19, %xmm5
        vaddsd %xmm4, %xmm7, %xmm4
        vaddsd %xmm5, %xmm18, %xmm5
        movabsq $140702207837808, %rax # imm = 0x7FF7C91E0670
        vaddsd (%rax), %xmm15, %xmm6
        vaddsd %xmm20, %xmm17, %xmm0
        vaddsd %xmm20, %xmm4, %xmm4
        vaddsd %xmm20, %xmm5, %xmm17
        vmulsd %xmm6, %xmm1, %xmm15
        vmulsd %xmm0, %xmm1, %xmm22
        vmulsd %xmm6, %xmm2, %xmm5
        vaddsd %xmm22, %xmm5, %xmm22
        vmulsd %xmm4, %xmm1, %xmm23
        vmulsd %xmm17, %xmm1, %xmm5
        vmulsd %xmm4, %xmm2, %xmm4
        vaddsd %xmm5, %xmm4, %xmm4
        vmulsd %xmm3, %xmm6, %xmm5
        vaddsd %xmm23, %xmm5, %xmm5
        vmulsd %xmm3, %xmm0, %xmm0
        vmulsd %xmm16, %xmm6, %xmm6
        vaddsd %xmm6, %xmm0, %xmm0
        vaddsd %xmm4, %xmm0, %xmm0
        movabsq $140702207837816, %rax # imm = 0x7FF7C91E0678
        vaddsd (%rax), %xmm15, %xmm23
        vaddsd %xmm20, %xmm22, %xmm22
        vaddsd %xmm20, %xmm5, %xmm17
        vaddsd %xmm20, %xmm0, %xmm15
        movabsq $140702207837824, %rax # imm = 0x7FF7C91E0680
        vmovsd (%rax), %xmm0 # xmm0 = mem[0],zero
        vmulsd %xmm0, %xmm1, %xmm5
        vmulsd %xmm0, %xmm2, %xmm6
        vaddsd %xmm6, %xmm7, %xmm6
        vmulsd %xmm0, %xmm3, %xmm4
        vmulsd %xmm0, %xmm16, %xmm0
        vaddsd %xmm0, %xmm19, %xmm0
        vaddsd %xmm4, %xmm7, %xmm4
        vaddsd %xmm0, %xmm18, %xmm0
        movabsq $140702207837832, %rax # imm = 0x7FF7C91E0688
        vaddsd (%rax), %xmm5, %xmm5
        vaddsd %xmm20, %xmm6, %xmm19
        vaddsd %xmm20, %xmm4, %xmm24
        vaddsd %xmm20, %xmm0, %xmm18
        vmulsd %xmm12, %xmm1, %xmm7
        vmulsd %xmm10, %xmm1, %xmm0
        vmulsd %xmm2, %xmm12, %xmm6
        vaddsd %xmm0, %xmm6, %xmm26
        vmulsd %xmm14, %xmm1, %xmm25
        vmulsd %xmm13, %xmm1, %xmm4
        vmulsd %xmm14, %xmm2, %xmm6
        vaddsd %xmm4, %xmm6, %xmm13
        vmulsd %xmm3, %xmm12, %xmm6
        vaddsd %xmm25, %xmm6, %xmm14
        vmulsd %xmm3, %xmm10, %xmm4
        vmulsd %xmm16, %xmm12, %xmm0
        vaddsd %xmm0, %xmm4, %xmm0
        vaddsd %xmm13, %xmm0, %xmm10
        vmulsd %xmm5, %xmm7, %xmm4
        vaddsd %xmm23, %xmm4, %xmm4
        vmulsd %xmm19, %xmm7, %xmm0
        vmulsd %xmm26, %xmm5, %xmm6
        vaddsd %xmm6, %xmm0, %xmm0
        vaddsd %xmm0, %xmm22, %xmm12
        vmulsd %xmm24, %xmm7, %xmm6
        vmulsd %xmm18, %xmm7, %xmm7
        vmulsd %xmm24, %xmm26, %xmm0
        vaddsd %xmm7, %xmm0, %xmm0
        vmulsd %xmm14, %xmm5, %xmm7
        vaddsd %xmm7, %xmm6, %xmm6
        vaddsd %xmm6, %xmm17, %xmm17
        vmulsd %xmm14, %xmm19, %xmm7
        vmulsd %xmm10, %xmm5, %xmm5
        vaddsd %xmm5, %xmm7, %xmm5
        vaddsd %xmm5, %xmm0, %xmm0
        vaddsd %xmm0, %xmm15, %xmm10
        vmulsd %xmm21, %xmm1, %xmm5
        vmulsd %xmm9, %xmm1, %xmm7
        vmulsd %xmm21, %xmm2, %xmm0
        vaddsd %xmm7, %xmm0, %xmm13
        vmulsd %xmm11, %xmm1, %xmm7
        vmulsd %xmm8, %xmm1, %xmm0
        vmulsd %xmm11, %xmm2, %xmm6
        vaddsd %xmm0, %xmm6, %xmm0
        vmulsd %xmm21, %xmm3, %xmm6
        vaddsd %xmm6, %xmm7, %xmm8
        vmulsd %xmm9, %xmm3, %xmm7
        vmulsd %xmm21, %xmm16, %xmm6
        vaddsd %xmm7, %xmm6, %xmm6
        vaddsd %xmm6, %xmm0, %xmm9
        vmulsd %xmm4, %xmm1, %xmm6
        vmulsd %xmm12, %xmm1, %xmm7
        vmulsd %xmm4, %xmm2, %xmm0
        vaddsd %xmm7, %xmm0, %xmm0
        vmulsd %xmm17, %xmm1, %xmm7
        vmulsd %xmm10, %xmm1, %xmm1
        vmulsd %xmm17, %xmm2, %xmm2
        vaddsd %xmm1, %xmm2, %xmm1
        vmulsd %xmm3, %xmm4, %xmm2
        vaddsd %xmm7, %xmm2, %xmm2
        vmulsd %xmm12, %xmm3, %xmm3
        vmulsd %xmm16, %xmm4, %xmm4
        vaddsd %xmm4, %xmm3, %xmm3
        vaddsd %xmm1, %xmm3, %xmm1
        movabsq $140702207837840, %rax # imm = 0x7FF7C91E0690
        vaddsd (%rax), %xmm6, %xmm3
        vaddsd %xmm20, %xmm0, %xmm0
        vaddsd %xmm20, %xmm2, %xmm2
        vaddsd %xmm20, %xmm1, %xmm1
        vmulsd %xmm5, %xmm3, %xmm4
        vmulsd %xmm0, %xmm5, %xmm6
        vmulsd %xmm13, %xmm3, %xmm7
        vaddsd %xmm7, %xmm6, %xmm6
        vmulsd %xmm2, %xmm5, %xmm7
        vmulsd %xmm1, %xmm5, %xmm1
        vmulsd %xmm2, %xmm13, %xmm2
        vaddsd %xmm1, %xmm2, %xmm1
        vmulsd %xmm8, %xmm3, %xmm2
        vmulsd %xmm8, %xmm0, %xmm0
        vmulsd %xmm9, %xmm3, %xmm3
        vaddsd %xmm3, %xmm0, %xmm0
        vaddsd %xmm7, %xmm2, %xmm2
        vaddsd %xmm1, %xmm0, %xmm0
        vunpcklpd %xmm0, %xmm2, %xmm0 # xmm0 = xmm2[0],xmm0[0]
        vunpcklpd %xmm6, %xmm4, %xmm1 # xmm1 = xmm4[0],xmm6[0]
        vinsertf128 $1, %xmm0, %ymm1, %ymm0
        vaddpd %ymm21, %ymm0, %ymm0
        vmovupd %ymm0, (%rdi)
        movq %rdi, %rax
        vzeroupper
        retq


Eliminated runtime overhead of dynamic definition!

And it’s fast!

16 of 35

julia> function mysum(a)
           s = zero(eltype(a))
           for x in a
               s += x
           end
           return s
       end;
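Nothing in mysum mentions a concrete type, so the same definition works for any iterable with a numeric eltype (the value 55 below is confirmed by the constant-folded result on the next slide):

mysum(1:10)        # 55
mysum([0.5, 1.5])  # 2.0, same code, different element type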

julia> @code_warntype mysum(1:10)
Variables
  #self#::Core.Compiler.Const(mysum)
  a::UnitRange{Int64}
  s::Int64
  @_4::Union{Nothing, Tuple{Int64,Int64}}
  x::Int64

Body::Int64
@ REPL[86]:2 within `mysum'
1 ─ %1 = Main.eltype(a)::Core.Compiler.Const(Int64)
    (s = Main.zero(%1))::Int64
@ REPL[86]:3 within `mysum'
    %3 = a::UnitRange{Int64}
    (@_4 = Base.iterate(%3))::Union{Nothing, Tuple{Int64,Int64}}
    %5 = (@_4 === nothing)::Bool
    %6 = Base.not_int(%5)::Bool
└── goto #4 if not %6
2 ┄ %8 = @_4::Tuple{Int64,Int64}
    (x = Core.getfield(%8, 1))::Int64
    %10 = Core.getfield(%8, 2)::Int64
@ REPL[86]:4 within `mysum'
    (s = s + x)::Int64
    (@_4 = Base.iterate(%3, %10))::Union{Nothing, Tuple{Int64,Int64}}
    %13 = (@_4 === nothing)::Bool
    %14 = Base.not_int(%13)::Bool
└── goto #4 if not %14
3 ─ goto #2
@ REPL[86]:6 within `mysum'
4 ┄ return s::Int64

Just a “simple” for-loop.

But dynamic, like what

17 of 35

julia> sigma(n) = mysum(1:n);

julia> @code_llvm sigma(10)
define i64 @julia_sigma_12697(i64) {
top:
  %1 = icmp sgt i64 %0, 0
  br i1 %1, label %L7.L12_crit_edge, label %L29

L7.L12_crit_edge:                        ; preds = %top
  %2 = shl nuw i64 %0, 1
  %3 = add nsw i64 %0, -1
  %4 = zext i64 %3 to i65
  %5 = add nsw i64 %0, -2
  %6 = zext i64 %5 to i65
  %7 = mul i65 %4, %6
  %8 = lshr i65 %7, 1
  %9 = trunc i65 %8 to i64
  %10 = add i64 %2, %9
  %11 = add i64 %10, -1
  br label %L29

L29:                                     ; preds = %L7.L12_crit_edge, %top
  %value_phi9 = phi i64 [ 0, %top ], [ %11, %L7.L12_crit_edge ]
  ret i64 %value_phi9
}

Not even a loop, like woah!

julia> f() = sigma(10);

julia> @code_llvm f()
define i64 @julia_f_12704() {
top:
  ret i64 55
}

Gone, all gone!

18 of 35

Machine Learning

Differentiable Programming

Fashionable Modelling with Flux (arXiv:1811.01457)

Building a Language and Compiler for Machine Learning (julialang:ml-language-compiler)

19 of 35

Zygote.jl - AD is a compiler problem

function foo(W, Y, x)
    Z = W * Y
    a = Z * x
    b = Y * x
    c = tanh.(b)
    r = a + c
    return r
end

function ∇foo(W, Y, x)
    Z = W * Y
    a = Z * x
    b = Y * x
    c, 𝒥tanh = ∇tanh.(b)
    a + c, function (Δr)
        Δc = Δr; Δa = Δr
        (Δtanh, Δb) = 𝒥tanh(Δc)
        (ΔY, Δx) = (Δb * x', Y' * Δb)
        (ΔZ = Δa * x',  Δx += Z' * Δa)
        (ΔW = ΔZ * Y',  ΔY += W' * ΔZ)
        (nothing, ΔW, ΔY, Δx)
    end
end

Note: Simplified to assume *,+ are compiler primitives (Not the case in the original implementation)

Note: Reverse-AD model (actual implementation uses mixed)

Don't Unroll Adjoint: Differentiating SSA-Form Programs (arXiv:1810.07951)
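None of this machinery is visible at the call site. A hedged sketch of the user-facing entry point (assuming Zygote.jl is loaded and foo is defined as above; sum(...) just turns the vector output into a scalar loss):

using Zygote

W, Y = randn(3, 3), randn(3, 3)
x = randn(3)

# Zygote derives the ∇foo-style adjoint code automatically:
ΔW, ΔY, Δx = gradient((W, Y, x) -> sum(foo(W, Y, x)), W, Y, x)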

20 of 35

Zygote.jl - AD is a compiler problem

(Same foo / ∇foo pair as on the previous slide.)

In the backwards pass

  • Inputs become Outputs
  • Outputs become Inputs

21 of 35

Zygote.jl - AD is a compiler problem

(Same foo / ∇foo pair as on the previous slides.)

22 of 35

Zygote.jl - AD is a compiler problem

struct 𝒥_foo
    W
    Y
    x
    Z
    𝒥tanh
end

(::𝒥_foo)(Δr) = ...

function ∇foo(W, Y, x)
    Z = W * Y
    a = Z * x
    b = Y * x
    c, 𝒥tanh = ∇tanh.(b)
    r = a + c
    (r, 𝒥_foo(W, Y, x, Z, 𝒥tanh))
end

(Closure version of ∇foo as on the previous slides.)

Closure conversion

The compiler builds the “tape” for us
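For concreteness, a sketch (our elaboration; the slide leaves the body elided) of what the call overload (::𝒥_foo)(Δr) would contain. It is exactly the closure body from the previous slides, reading its captured state from the struct fields:

function (𝒥::𝒥_foo)(Δr)
    # Unpack the captured state: this struct is the "tape"
    W, Y, x, Z, 𝒥tanh = 𝒥.W, 𝒥.Y, 𝒥.x, 𝒥.Z, 𝒥.𝒥tanh
    Δc = Δr; Δa = Δr
    (Δtanh, Δb) = 𝒥tanh(Δc)
    (ΔY, Δx) = (Δb * x', Y' * Δb)
    ΔZ = Δa * x';  Δx += Z' * Δa
    ΔW = ΔZ * Y';  ΔY += W' * ΔZ
    (nothing, ΔW, ΔY, Δx)
end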

23 of 35

AD as a compiler problem

  • Simple (Syntactic) - but requires an optimizing compiler for performance
  • Partial Specialization (/DCE) => Partial Gradients
  • Better Compiler Optimizations ⇔ Faster AD
  • Nested AD for free (see the sketch below)
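The last point can be made concrete (a sketch, assuming Zygote.jl): since the generated adjoint is ordinary Julia code, it can itself be differentiated.

using Zygote

# Second derivative of sin, by differentiating the derivative:
gradient(x -> gradient(sin, x)[1], 0.5)  # ≈ (-sin(0.5),)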

24 of 35

Even Control Flow!

function lstm(model, input, output)
    hidden = initial(model)
    total_loss = 0.
    for sample in take(zip(input, output), 50)
        (hidden, predicted) = model(sample, hidden)
        total_loss += loss(output, predicted)
    end
    total_loss
end

function ∇lstm(model, input, output)
    stack = Stack()
    hidden, 𝒥initial = ∇initial(model)
    total_loss = 0.
    for sample in take(zip(input, output), 50)
        (hidden, predicted), 𝒥model = ∇model(sample, hidden)
        loss_it, 𝒥loss = ∇loss(output, predicted)
        total_loss += loss_it
        push!(stack, (𝒥model, 𝒥loss))
    end
    total_loss, function (Δ)
        Δmodel_total = zero(typeof(model))
        Δhidden = nothing
        for (𝒥model, 𝒥loss) in reverse(stack)
            (Δloss, Δoutput, Δpredicted) = 𝒥loss(Δ)
            (Δmodel_it, Δsample, Δhidden) = 𝒥model(Δhidden, Δpredicted)
            Δmodel_total += Δmodel_it
            # (... same for input and output, but let's
            #  ignore those for simplicity)
        end
        Δmodel_total += 𝒥initial(Δhidden)
        (Δmodel_total, ...)
    end
end

25 of 35

Machine Learning

Compiler Backends

Fashionable Modelling with Flux (arXiv:1811.01457)

Building a Language and Compiler for Machine Learning (julialang:ml-language-compiler)

26 of 35

CUDAnative.jl - low level GPU Programming

function vadd(gpu, a, b, c)
    i = threadIdx().x + blockDim().x *
        ((blockIdx().x-1) + (gpu-1) * gridDim().x)
    @inbounds c[i] = a[i] + b[i]
    return
end

a, b, c = (CuArray(...) for _ in 1:3)
@cuda threads=length(a) vadd(1, a, b, c)

julia> @device_code_ptx @cuda vadd(1, a, a, a)
//
// Generated by LLVM NVPTX Back-End
//
.visible .entry ptxcall_vadd_23(
        .param .u64 ptxcall_vadd_23_param_0,
        .param .align 8 .b8 ptxcall_vadd_23_param_1[16],
        .param .align 8 .b8 ptxcall_vadd_23_param_2[16],
        .param .align 8 .b8 ptxcall_vadd_23_param_3[16]
)
{
        mov.u32 %r1, %tid.x;
        mov.u32 %r2, %ntid.x;
        mov.u32 %r3, %ctaid.x;
        ...
}

Provides:

  • CUDA intrinsics
  • SPMD programming model
  • GPU memory management
  • ...

Effective Extensible Programming: Unleashing Julia on GPUs (arXiv:1712.03112)

27 of 35

Performance

28 of 35

Julia all the way down

function vadd(a, b, c)
    i = threadIdx().x
    c[i] = a[i] + b[i]
    return
end

W = randn(2, 10)
b = randn(2)
f(x) = softmax(W * x .+ b)

model = Chain(
    Dense(10, 5, σ),
    Dense(5, 2),
    softmax)
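To run the last fragment end to end (a minimal sketch, assuming Flux.jl is loaded; the layer sizes are the ones from the slide):

using Flux

model = Chain(
    Dense(10, 5, σ),  # 10 inputs -> 5 hidden units, sigmoid activation
    Dense(5, 2),      # 5 hidden units -> 2 outputs
    softmax)          # normalize the outputs into probabilities

x = rand(10)  # one input sample
model(x)      # 2-element vector of probabilities summing to 1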

29 of 35

Scaling Up

[Figures comparing cluster scale in 2017 and 2019]

Scaling GANs to ExaScale16 on GPUs (Work in Progress)

30 of 35

XLA.jl

struct XRTArray{T, Dims, N} <: AbstractArray{T, N}
    # ...
end

XRTScalar{T} = XRTArray{T, (), 0}

function +(a::XRTScalar{T},
           b::XRTScalar{T}) where T
    GenericHloOp{:add}(T, ())(a, b)
end

+(a::XRTScalar, b::XRTScalar) =
    +(promote(a, b)...)

Shape information represented in the type system

Declare primitives using multiple dispatch

Re-use vast amount of existing julia code (e.g. promotion, array abstractions, broadcast machinery)

Automatic Full Compilation of Julia Programs and ML Models to Cloud TPUs (arXiv:1810.09868)

31 of 35

Generic Code makes retargeting easy

In the Julia standard library:

function mapreduce(f, op, A; dims=:)
    ...
end

add_sum(x, y) = x + y
sum(f, a) = mapreduce(f, add_sum, a)
sum(a) = sum(identity, a)

In XLA.jl:

function Base.mapreduce(f, op, A::XRTArray; dims=:)
    dt = dims_tuple(A, dims)
    res = HloReduce{Core.Typeof(op)}(dt)(op,
        HloMap{Core.Typeof(f)}()(f, A),
        XRTArray(zero(eltype(A))))
    if dims != (:)
        # Put back the dimensions that HloReduce dropped;
        # Julia semantics require this.
        res = HloReshape(
            reduced_dimensions_collapse(size(A), dims))(res)
    end
    return res
end

Coverage of a large number of functions from a couple of generic abstractions

Existing Julia array primitives map well to XLA primitives
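Because sum is defined generically in terms of mapreduce, the single XRTArray overload above routes the whole reduction family through HloReduce. The same dispatch chain is visible on ordinary arrays (plain Base Julia):

sum(abs2, [1.0, 2.0, 3.0])  # 14.0: sum(f, a) -> mapreduce(abs2, add_sum, a)
sum([1, 2, 3])              # 6:    sum(a) -> sum(identity, a) -> mapreduce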

32 of 35

Essentially: An embedding of HLO in Julia IR

dense(W, x, b) = W * x .+ b

[IR listings: the Julia IR for dense, before and after Julia-level optimization]

After Julia-level optimization, statements correspond 1:1 to HLO; trivial to convert to an HLO .pb from here.

33 of 35

XLA.jl

Compilation:

  • Full Julia semantics (over XLA primitives)
  • Control flow
  • Takes the place of LLVM
  • Re-uses inference / the Julia optimizer / AD
  • ~1000 LOC
  • No XLA-specific changes to core Julia
  • Template for Julia as a frontend to future/experimental IRs/backends

34 of 35

Runs Well on TPU

  • Performance on par with TF
  • Scales to pods (512 TPU cores, 4.3 PF16/s on ResNet50)

35 of 35

Growing a language (youtube:_ahvzDzKdB0)

This leads me to claim that, from now on, a main goal in designing a language should be to plan for growth. The language must start small, and the language must grow as the set of users grows.

Guy Steele, “Growing a language”, 1998

www.cs.virginia.edu/~evans/cs655/readings/steele.pdf