1 of 56

Relax Figures

2 of 56

def symbolic_shape_fn(x: Tensor(("n", 2, 2), "f32")):

n, m = sym_var(), sym_var()

lv0: Tensor((n, 4), "f32") = reshape(x, shape(n, 4))
lv1: Tensor((n * 4,), "f32") = flatten(lv0)

lv2: Tensor(ndim=1, dtype="f32") = unique(lv1)

lv3 = match_cast(lv2, Tensor((m,), "f32"))

lv4: Tensor((m,), "f32") = exp(lv3)

return lv4

def any_shape_fn(x: Tensor((?, 2, 2), "f32")):

n = get_shape_value(x, axis=0)
lv0: Tensor((?, 4), "f32") = reshape(x, (n, 4))
lv1: Tensor((?,), "f32") = flatten(lv0)

lv2: Tensor(?, "f32") = unique(lv1)

lv3: Tensor(?, "f32") = exp(lv2)

return lv3

Shape annotation with unknown ? dimensions

Proposed approach: first-class inter-operator symbolic shape

3 of 56

def symbolic_shape_fn(x: Tensor(("n", 2, 2), "f32")):

n, m = sym_var(), sym_var()

lv0: Tensor((n, 4), "f32") = reshape(x, shape(n, 4))
lv1: Tensor((n * 4,), "f32") = flatten(lv0)

lv2: Tensor(ndim=1, dtype="f32") = unique(lv1)

lv3 = match_cast(lv2, Tensor((m,), "f32"))

lv4: Tensor((m,), "f32") = exp(lv3)

return lv4

def any_shape_fn(x: Tensor((?, 2, 2), "f32")):

n = get_shape_value(x, axis=0)
lv0: Tensor((?, 4), "f32") = reshape(x, (n, 4))
lv1: Tensor((?,), "f32") = flatten(lv0)

lv2: Tensor(?, "f32") = unique(lv1)

lv3: Tensor(?, "f32") = exp(lv2)

return lv3

Shape annotation with unknown ? dimensions

First-class symbolic shape annotation

4 of 56

# Graph-level end-to-end dynamic ML model

def main(x: Tensor(("n", 128), "f32"), w: Tensor((128, 256), "f32")):

n = sym_var()

with dataflow():
lv0: Tensor((n, 256), "f32") = call_tir(

mm, [x, w], Tensor((n, 256),"f32")

)

lv1: Tensor((n, 256), "f32") = relu(lv0)

lv2: Tensor((n, 256), "f32") = call_dps_library(

"cutlass.rms_norm", [lv1], Tensor((n, 256), "f32")

)

# Loop-level tensor programs

@tensorir_function

def mm(X: Buffer(("n", 128), "f32"), W: Buffer((128, 256), "f32"),

Y: Buffer(("n", 256), "f32")):

n = sym_var()

for i, j, k in grid(n, 256, 128):

with block():

with init():

Y[i, j] = 0

Y[i, j] += X[i, k] * W[k, j]

call library functions from graph level

call loop-level tensor programs from graph level

5 of 56

call library functions from graph level

# Graph-level end-to-end dynamic ML model

def graph_mlp_func(

x: Tensor(("n", 4096), "f16"),

w1: Tensor((4096, 11008), "f16"),

w2: Tensor((11008, 4096), "f16"),

):

n = sym_var()

lv0 = call_tir(func=matmul0, args=[x, w1],

out_shape=(n, 11008), sym_args=[n])

lv1: Tensor((n, 11008), "f16") = silu(lv0)

lv2: Tensor((n, 11008), "f16") = add(lv0, lv1)

lv3 = call_tir(func=matmul1, args=[lv2, w2],

out_shape=(n, 4096), sym_args=[n])

lv4 = call_dps_library(

"cutlass.rms_norm", args=[lv3],

out_shape=(n, 4096))

return lv4

# Loop-level tensor programs

def matmul0(

A: Tensor(("n", 4096), "f16"),

B: Tensor((4096, 11008), "f16"),

out: Tensor(("n", 11008), "f16"),

n: int,

):

# Computes "matmul(A, B)" and writes results

# to "out"

...

def matmul1(...):

...

call loop-level tensor programs from graph level

6 of 56

# Graph-level end-to-end dynamic ML model

def main(x: Tensor(("n", 128), "f32"), w: Tensor((128, 256), "f32")):

n = sym_var()

with dataflow():
lv0: Tensor((n, 256), "f32") = call_tir(

mm, [x, w], Tensor((n, 256),"f32")

)

lv1: Tensor((n, 256), "f32") = relu(lv0)

lv2: Tensor((n, 256), "f32") = call_dps_library(

"cutlass.rms_norm", [lv1], Tensor((n, 256), "f32")

)

# Loop-level tensor programs

@tensorir_function

def mm(X: Buffer(("n", 128), "f32"), W: Buffer((128, 256), "f32"),

Y: Buffer(("n", 256), "f32")):

n = sym_var()

for i, j, k in grid(n, 256, 128):

with block():

with init():

Y[i, j] = 0

Y[i, j] += X[i, k] * W[k, j]

7 of 56

def call_tir(tir_func, args, annotation, sym_args):

# Allocate output tensor

output = alloc_tensor(annotation.shape, annotation.dtype)

# Call low-level function in destination-passing style

tir_func(*args, output, *sym_args)

return output
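
A minimal runnable sketch of the same destination-passing-style convention, using NumPy stand-ins for the TensorIR function and the allocator (the names mm, x, w, and n follow the earlier slides; this is illustrative, not the actual runtime):

import numpy as np

def mm(X, W, Y, n):
    # Stand-in for the TensorIR function `mm`: writes its result into Y (DPS);
    # the trailing n mirrors the extra symbolic shape argument.
    Y[:] = X @ W

def call_tir(tir_func, args, out_shape, out_dtype, sym_args=()):
    # Mirrors the helper above: allocate the output, then call the DPS kernel.
    output = np.empty(out_shape, dtype=out_dtype)
    tir_func(*args, output, *sym_args)
    return output

n = 4  # example value of the symbolic dimension
x = np.random.rand(n, 128).astype("float32")
w = np.random.rand(128, 256).astype("float32")
lv0 = call_tir(mm, [x, w], (n, 256), "float32", sym_args=[n])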

8 of 56

Before memory planning:

x: (2, n), "f32"
lv0 = exp(x)          # Allocate (2, n) for lv0
lv1 = transpose(lv0)  # Allocate (n, 2) for lv1
lv2 = relu(lv1)       # Allocate (n, 2) for lv2
lv3 = transpose(lv2)  # Allocate (2, n) for lv3

After memory planning:

x: (2, n), "f32"
Allocate 2*n*4 bytes for s0
Allocate 2*n*4 bytes for s1
lv0 = exp(x)          # Instantiate lv0: (2, n) from s0
lv1 = transpose(lv0)  # Instantiate lv1: (n, 2) from s1
lv2 = relu(lv1)       # Instantiate lv2: (n, 2) from s0
lv3 = transpose(lv2)  # Instantiate lv3: (2, n) from s1
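
The reuse decision above relies on comparing symbolic sizes: 2*n*4 bytes for lv0 and n*2*4 bytes for lv2 are provably equal for any n, so one storage can back both. A hedged sketch of that check, with sympy standing in for the compiler's symbolic arithmetic:

from sympy import symbols, simplify

n = symbols("n", positive=True, integer=True)

def nbytes(shape, itemsize=4):
    # Symbolic byte size of a tensor with a possibly symbolic shape.
    size = itemsize
    for dim in shape:
        size *= dim
    return size

# lv0 has shape (2, n) and lv2 has shape (n, 2); both are f32 (4 bytes/elem).
same_size = simplify(nbytes((2, n)) - nbytes((n, 2))) == 0
print(same_size)  # True -> storage s0 can be reused for both lv0 and lv2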

9 of 56

def main(x: Tensor(("n", 128), "f32"),

w: Tensor((128, 256), "f32")):

n = sym_var()

with dataflow():

lv1: Tensor((n, 256), "f32") = call_tir(

fused_mm_relu, [x, w], Tensor((n, 256),"f32")

)
return lv1

@tensorir_function

def fused_mm_relu(X: Buffer(("n", 128), "f32"),

W: Buffer((128, 256), "f32"),

Z: Buffer(("n", 256), "f32")):

n = sym_var()

Y = alloc_buffer((n, 256), "f32")

# matmul

for i, j, k in grid(n, 256, 128):

with block():

with init():

Y[i, j] = 0

Y[i, j] += X[i, k] * W[k, j]

# relu

for i, j in grid(n, 256):

with block():

Z[i, j] = max(Y[i, j], float32(0))

def main(x: Tensor(("n", 128), "f32"),

w: Tensor((128, 256), "f32")):

n = sym_var()

with dataflow():

lv1: Tensor((n, 256), "f32") = fused_mm_relu(x, w)
return lv1

def fused_mm_relu(x: Tensor(("n", 128), "f32"),

w: Tensor((128, 256), "f32")):

n = sym_var()

with dataflow():
lv0: Tensor((n, 256), "f32") = call_tir(

mm, [x, w], Tensor((n, 256),"f32")

)

lv1: Tensor((n, 256), "f32") = relu(lv0)

return lv1

@tensorir_function

def mm(X: Buffer(("n", 128), "f32"),

W: Buffer((128, 256), "f32"),

Y: Buffer(("n", 256), "f32")):

func_attr("compute_pattern", "OutputEwiseFusible")

...

@tensorir_function

def relu(X: Buffer(("n", 256), "f32"),

Y: Buffer(("n", 256), "f32")):

func_attr("compute_pattern", "ElementWise")

...

Compute pattern analysis

+ FuseOps

+ FuseTensorIR

Initial Program

Follow-up scheduling

of TensorIR

def main(x: Tensor(("n", 128), "f32"),

w: Tensor((128, 256), "f32")):

n = sym_var()

with dataflow():
lv0: Tensor((n, 256), "f32") = call_tir(

mm, [x, w], Tensor((n, 256),"f32")

)

lv1: Tensor((n, 256), "f32") = relu(lv0)

return lv1

@tensorir_function

def mm(X: Buffer(("n", 128), "f32"),

W: Buffer((128, 256), "f32"),

Y: Buffer(("n", 256), "f32")):

n = sym_var()

for i, j, k in grid(n, 256, 128):

with block():

with init():

Y[i, j] = 0

Y[i, j] += X[i, k] * W[k, j]

@tensorir_function

def relu(X: Buffer(("n", 256), "f32"),

Y: Buffer(("n", 256), "f32")):

n = sym_var()

for i, j in grid(n, 256):

with block():

Y[i, j] = max(X[i, j], float32(0))

10 of 56

def main(

x: Tensor(("n", 128), "f16")

Wdata: Tensor((128, 32), "u32"),

Wscale: Tensor((128, 8), "f16"),

):

n = sym_var()

with dataflow():

lv0: Tensor((n,256),"f16") = call_tir(

fused_decode_q4_mm,

[x, Wdata, Wscale],

Tensor((n, 256), "f16")

)
return lv0

@tensorir_function

def fused_decode_q4_mm(

X: Buffer(("n", 128), "f16"),

Wdata: Buffer((128, 32), "u32"),

Wscale: Buffer((128, 8), "f16"),

Y: Buffer(("n", 256), "f16"),

):

n = sym_var()

W = alloc_buffer((128, 256), "f16")

# decode_q4

for k, j in grid(128, 256):

W[k, j] = (
((Wdata[k, j // 8] >> (j % 8 * 4)) & 15) - 7
) * Wscale[k, j // 32]

# matmul

for i, j, k in grid(n, 256, 128):

if k == 0:

Y[i, j] = 0.0

Y[i, j] += X[i, k] * W[k, j]

Follow-up scheduling

of TensorIR

def main(x: Tensor(("n", 128), "f16")

Wdata: Tensor((128, 32), "u32"),

Wscale: Tensor((128, 8), "f16")):

n = sym_var()

with dataflow():
W: Tensor((128, 256), "f16") = call_tir(
decode_q4,

[Wdata, Wscale],

Tensor((128, 256), "f16")

)

lv0: Tensor((n, 256), "f16") = call_tir(

mm, [x, W], Tensor((n, 256), "f16")

)

return lv0

@tensorir_function

def decode_q4(

Wdata: Buffer((128, 32), "u32"),

Wscale: Buffer((128, 8), "f16"),

W: Buffer((128, 256), "f16"),

):

for k, j in grid(128, 256):

W[k, j] = (
((Wdata[k, j // 8] >> (j % 8 * 4)) & 15) - 7
) * Wscale[k, j // 32]

@tensorir_function

def mm(X: Buffer(("n", 128), "f16"),

W: Buffer((128, 256), "f16"),

Y: Buffer(("n", 256), "f16")):

n = sym_var()

for i, j, k in grid(n, 256, 128):

if k == 0:

Y[i, j] = 0.0

Y[i, j] += X[i, k] * W[k, j]

+ FuseTensorIR

Initial Program

Compute pattern analysis

+ FuseOps

def main(x: Tensor(("n", 128), "f16")

Wdata: Tensor((128, 32), "u32"),

Wscale: Tensor((128, 8), "f16")):

n = sym_var()

with dataflow():

lv0: Tensor((n,256),"f16") = fused_decode_q4_mm(

x, Wdata, Wscale

)
return lv0

def fused_decode_q4_mm(x:Tensor(("n",128),"f16")

Wdata:Tensor((128,32),"u32"),

Wscale:Tensor((128,8),"f16")):

n = sym_var()

with dataflow():
W: Tensor((128, 256), "f16") = call_tir(
decode_q4, [Wdata, Wscale], Tensor((128, 256), "f16")

)

lv0: Tensor((n, 256), "f16") = call_tir(

mm, [x, W], Tensor((n, 256),"f16")

)

return lv0

@tensorir_function

def decode_q4(Wdata: Buffer((128, 32), "u32"),

Wscale: Buffer((128, 8), "f16"),

W: Buffer((128, 256), "f16")):

func_attr("compute_pattern", "Injective")

...

@tensorir_function

def mm(X: Buffer(("n", 128), "f16"),

W: Buffer((128, 256), "f16"),

Y: Buffer(("n", 256), "f16")):

func_attr("compute_pattern", "OutputEwiseFusible")

...

11 of 56

@tensorir_function

def decode_q4(

Wdata: Buffer((128, 16), "u32"),

Wscale: Buffer((128, 4), "f16"),

W: Buffer((128, 128), "f16")

):

for x, y in grid(128, 128):

W[x, y] = (
((Wdata[x, y // 8] >> (y % 8 * 4)) & 15) - 7
) * Wscale[x, y // 32]

def main(

x: Tensor((1, 128), "f16"),

Wdata: Tensor((128, 16), "u32"),

Wscale: Tensor((128, 4), "f16")

):

with dataflow():

W0: Tensor((128, 128), "f16") = call_tir(
decode_q4, [Wdata, Wscale],
Tensor((128, 128), "f16")

)

lv0: Tensor((1, 128), "f16") = matmul(x, W0)

...

@tensorir_function

def fused_decode_q4_mm(

X: Buffer((1, 128), "f16"),

Wdata: Buffer((128, 16), "u32"),

Wscale: Buffer((128, 4), "f16"),

Y: Buffer((1, 128), "f16")

):

W = alloc_buffer((128, 128), "f16")

for x, y in grid(128, 128):

W[x, y] = (
((Wdata[x, y // 8] >> (y % 8 * 4)) & 15) - 7
) * Wscale[x, y // 32]

for i, j, k in grid(1, 128, 128):

with block():

with init():

Y[i, j] = 0

Y[i, j] += X[i, k] * W[k, j]

def main(

x: Tensor((1, 128), "f16"),

Wdata: Tensor((128, 16), "u32"),

Wscale: Tensor((128, 4), "f16")

):

with dataflow():

lv0: Tensor((1, 128), "f16") = call_tir(
fused_decode_q4_mm, [x, Wdata, Wscale],

Tensor((1, 128), "f16")

)

...

12 of 56

@tensorir_function

def mm(X: Buffer(("n", 512), "f32"), W: Buffer((512, 1536), "f32"),

Y: Buffer(("n", 1536), "f32")):

n = sym_var()

for i, j, k in grid(n, 1536, 512):

with block():

with init():

Y[i, j] = 0

Y[i, j] += X[i, k] * W[k, j]

def main(x: Tensor(("n", 512), "f32"), w: Tensor((512, 1536), "f32")):

n = sym_var()

with dataflow():
lv0: Tensor((n, 1536), "f32") = call_tir(mm, [x, w], Tensor((n, 1536),"f32"))

lv1: Tensor((n, 3, 8, 64), "f32") = reshape(lv0, shape(n, 3, 8, 64))

lv2: Tensor((n, 8, 64), "f32") = call_dps_library(

"cutlass.attention", [lv1], Tensor((n, 8, 64), "f32")

)

ML models

Target Code

ML models

Our Approach

First-class symbolic shape annotations to enable dynamic shape-aware analysis and optimizations

Cross-level abstraction that encapsulates high-level computation, tensor programs, libraries, and their interactions

Computational Graph IR

Tensor Program IR

Target Code

Multi-level ML Compilers

Multi-level abstractions with optimizations at each level

single-shot lowering

other optional layers of IRs

Libraries

Composable optimizations across dynamic shape-aware computational graph, tensor programs, and libraries

13 of 56

Code Generation

Cross-level abstraction that encapsulates high-level computation, tensor programs, libraries, and their interactions

@tensorir_function

def mm(X: Buffer(("n", 512), "f32"), W: Buffer((512, 1536), "f32"),

Y: Buffer(("n", 1536), "f32")):

n = sym_var()

for i, j, k in grid(n, 1536, 512):

with block():

with init():

Y[i, j] = 0

Y[i, j] += X[i, k] * W[k, j]

def main(x: Tensor(("n", 512), "f32"), w: Tensor((512, 1536), "f32")):

n = sym_var()

with dataflow():
lv0: Tensor((n, 1536), "f32") = call_tir(mm, [x, w], Tensor((n, 1536),"f32"))

lv1: Tensor((n, 3, 8, 64), "f32") = reshape(lv0, shape(n, 3, 8, 64))

lv2: Tensor((n, 8, 64), "f32") = call_dps_library(

"cutlass.attention", [lv1], Tensor((n, 8, 64), "f32")

)

ML models

Target Code and Deployable Module

ML models

Our Approach

First-class symbolic shape annotations to enable dynamic shape-aware analyses and optimizations

Computational Graph IR

Tensor Program IR

Target Code

Multi-level ML Compilers

Multi-level abstractions with optimizations at each level

single-shot lowering

other optional layers of IRs

Libraries

Composable optimizations across dynamic shape-aware computational graph, tensor programs, and libraries

14 of 56

Code Generation

Cross-level abstraction that encapsulates the high-level computation graph, tensor programs, libraries, and their interactions

@tensorir_function

def mm(X: Buffer(("n", 512), "f32"), W: Buffer((512, 1536), "f32"),

Y: Buffer(("n", 1536), "f32")):

n = sym_var()

for i, j, k in grid(n, 1536, 512):

with block():

with init():

Y[i, j] = 0

Y[i, j] += X[i, k] * W[k, j]

def main(x: Tensor(("n", 512), "f32"), w: Tensor((512, 1536), "f32")):

n = sym_var()

with dataflow():
lv0: Tensor((n, 1536), "f32") = call_tir(mm, [x, w], Tensor((n, 1536),"f32"))

lv1: Tensor((n, 3, 8, 64), "f32") = reshape(lv0, shape(n, 3, 8, 64))

lv2: Tensor((n, 8, 64), "f32") = call_dps_library(

"cutlass.attention", [lv1], Tensor((n, 8, 64), "f32")

)

ML models

Target Code and Deployable Module

ML models

Our Approach

First-class symbolic shape annotations to enable dynamic shape-aware analyses and optimizations

Computational Graph IR

Tensor Program IR

Target Code

Multi-level ML Compilers

Multi-level abstractions with optimizations at each level

single-shot lowering

other optional layers of IRs

Libraries

Cross-level optimizations across computational graph, tensor program, and libraries, including:

- cross-level dynamic shape-aware operator fusion
- dynamic shape-aware memory planning
- cross-level tensor program workspace lifting
- CUDA Graph offloading
- tensor program optimizations via partial lowering

15 of 56

@tensorir_function

def mm(X: Buffer(("n", 512), "f32"), W: Buffer((512, 1536), "f32"),

Y: Buffer(("n", 1536), "f32")):

n = sym_var()

for i, j, k in grid(n, 1536, 512):

with block():

with init():

Y[i, j] = 0

Y[i, j] += X[i, k] * W[k, j]

def main(x: Tensor(("n", 512), "f32"), w: Tensor((512, 1536), "f32")):

n = sym_var()

with dataflow():
lv0: Tensor((n, 1536), "f32") = call_tir(mm, [x, w], Tensor((n, 1536),"f32"))

lv1: Tensor((n, 3, 8, 64), "f32") = reshape(lv0, shape(n, 3, 8, 64))

lv2: Tensor((n, 8, 64), "f32") = call_dps_library(

"cutlass.attention", [lv1], Tensor((n, 8, 64), "f32")

)

Construct and import

Universal deployment

Composable optimizations

Graph-level optimizations

Tensor program optimizations

16 of 56

SLM

Synergy Ecosystem

LLM.compile

AutoLLM

Model def

Quantize

TVM Unity Abstraction
Compilation

Computational graph

Platform Dependent optimization

dlight

Nv kernel/library dispatch

Memory planning

fusion

distributed

Paged KV

MultiGPU/Disco

Runtime

AWQ

NCCL

CUTLASS/TRT

import

17 of 56

Annotation | Examples | Explanation
Object | Object | Any runtime value
Shape | Shape([n, 4]); Shape(ndim=2) | Symbolic shape value (n, 4); shape with two dimensions
Tensor | Tensor((n, 4), "f32"); Tensor(ndim=None, dtype="f32") | Tensor with symbolic shape (n, 4); tensor with unknown dimensions
Tuple | Tuple[Tensor((n, 4), "f32"), Object] | Tuple of a Tensor and an Object
Callable | Callable([Tensor(("n", 4), "f32")], Tensor(("n * 4",), "f32")) | Function that takes an (n, 4) Tensor and returns an (n * 4,) Tensor

18 of 56

Annotation

def main(x: Tensor(("n", 4), "f32")) -> Tensor(("n * 2",), "f32"):

n = sym_var()

with dataflow():

lv0: Tensor((n, 4), "f32") = call_tir(exp, [x], Tensor((n, 4), "f32"))

f0: Callable = subfunc

lv1: Tensor((n * 4,), "f32") = subfunc(lv0)

lv2: Tuple[

Tensor((n * 2,), "f32"), Tensor((n * 2,), "f32")

] = split(lv1, sections=2)

lv3: Tensor((n * 2,), "f32") = lv2[0]

return lv3

def subfunc(x: Tensor(("n", 4), "f32")) -> Tensor(("n * 4",), "f32"):

n = sym_var()

with dataflow():

lv0: Tensor((n * 4,), "f32") = flatten(x)

lv1: Tensor((n * 4,), "f32") = relu(lv0)

return lv1

@tensorir_function

def exp(X: Buffer(("n", 4), "f32"), Y: Buffer(("n", 4), "f32")):
...

Dataflow block

Cross-function call

Cross-level function call

19 of 56

Annotation

def main(x: Tensor(("n", 4), "f32")) -> Tensor(("n * 2",), "f32"):

n = sym_var()

with dataflow():

lv0: Tensor((n, 4), "f32") = call_tir(exp, [x], Tensor((n, 4), "f32"))

lv1: Tensor((n, 4), "f32") = call_dps_library(

"cutlass.rms_norm", [lv0], Tensor((n, 4), "f32")

)

f0: Callable = subfunc

lv2: Tensor((n * 4,), "f32") = subfunc(lv1)

lv3: Tuple[

Tensor((n * 2,), "f32"), Tensor((n * 2,), "f32")

] = split(lv2, sections=2)

lv4: Tensor((n * 2,), "f32") = lv3[0]

return lv4

def subfunc(x: Tensor(("n", 4), "f32")) -> Tensor(("n * 4",), "f32"):

n = sym_var()

with dataflow():

lv0: Tensor((n * 4,), "f32") = flatten(x)

lv1: Tensor((n * 4,), "f32") = relu(lv0)

return lv1

@tensorir_function

def exp(X: Buffer(("n", 4), "f32"), Y: Buffer(("n", 4), "f32")):
...

Dataflow block

Subgraph function call

Foreign function call

Foreign function (tensor program)

20 of 56

def subfn(s: Shape(["n", "m"])) -> Tensor(("n * m",), "f32"):

...

def subgraph_func_shape_deduce_example(
x: Tensor(("n",), "f32"),
y: Shape(ndim=2)
):

n = sym_var()

f0: Callable([Shape(["n", "m"])], Tensor(("n * m",), "f32")) = subfn
lv0: Tensor((n * 4,), "f32") = f0(shape(n, 4))

lv1: Tensor((12,), "f32") = subfn(shape(3, 4))

lv2: Tensor(((n + 1) * 4,), "f32") = subfn(shape(n + 1, 4))
lv3: Tensor(ndim=1, dtype="f32") = subfn(y)

...
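
A minimal sketch of how the return annotation "n * m" is specialized at each call site by binding the symbolic parameters to the caller's shape values (sympy stands in for the compiler's symbolic arithmetic; the helper name is illustrative):

from sympy import sympify, Symbol

def deduce_return_shape(param_vars, ret_expr, arg_shape):
    # Bind each symbolic parameter ("n", "m") to the caller's dimension,
    # then substitute into the annotated return expression, e.g. "n * m".
    binding = {Symbol(v): d for v, d in zip(param_vars, arg_shape)}
    return sympify(ret_expr).subs(binding)

n = Symbol("n")
print(deduce_return_shape(["n", "m"], "n * m", (3, 4)))      # 12       -> lv1
print(deduce_return_shape(["n", "m"], "n * m", (n + 1, 4)))  # 4*n + 4  -> lv2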

21 of 56

Extra symbolic shape parameter to pass in n

Parameter annotation can be an expression

Call into fused function

def main(x: Tensor(("n", 2), "f32")):

n = sym_var()

lv0: Tensor(("2 * n",), "f32") = flatten(x)�� lv1: Tensor(("2 * n",), "f32") = add(lv0, lv0)

lv2: Tensor(("2 * n",), "f32") = relu(lv1)� ...

Before fusion

After fusion

Regions to fuse

def fused_add_relu(

x: Tensor(("n * 2",), "f32"),

y: Tensor(("n * 2",), "f32"),

s: Shape(["n"])

) -> Tensor(("n * 2",), "f32"):

lv0 = add(x, y)

lv1 = relu(lv0)

return lv1

def main(x: Tensor(("n", 2), "f32")):

n = sym_var()

lv0: Tensor(("2 * n",), "f32") = flatten(x)� lv1: Tensor(("2 * n",), "f32") = fused_add_relu(lv0, lv0, shape(n))� ...

22 of 56

def subfn(s: Shape(["n", "m"])) -> Tensor(("n * m",), "f32"):

lv0 = empty(s)
lv1 = flatten(lv0)

return lv1

def fallback_example(x: Tensor(ndim=1, dtype="f32")):

n = sym_var()

f0: Callable([Shape(["n", "m"])], Tensor(("n * m",), "f32")) = subfn
lv0: Tensor((n * 4,), "f32") = f0(shape(n, 4))

lv1: Tensor((12,), "f32") = subfn(shape(3, 4))

lv2: Tensor(((n + 1) * 4,), "f32") = subfn(shape(n + 1, 4))

...

23 of 56


24 of 56

25 of 56

26 of 56

27 of 56

def main(x: Tensor(("n", 128), "f32"), w: Tensor((128, 256), "f32")):

n = sym_var()

with dataflow():
lv0: Tensor((n, 256), "f32") = call_tir(

mm, [x, w], Tensor((n, 256),"f32")

)

lv1: Tensor((n, 256), "f32") = rms_norm(lv0)

lv1: Tensor((n, 256), "f32") = call_dps_library(

"cutlass.rms_norm", [lv0], Tensor((n, 256), "f32")

)

...

@tensorir_function

def mm(X: Buffer(("n", 128), "f32"),

W: Buffer((128, 256), "f32"),

Y: Buffer(("n", 256), "f32")):

n = sym_var()

for i, j, k in grid(n, 256, 128):

with block():

with init():

Y[i, j] = 0

Y[i, j] += X[i, k] * W[k, j]

I. Replacement from composable partial lowering passes

TensorIR function body

Program analysis and transformations

Optimized function body

II. Loop-level TensorIR optimization

update

28 of 56

def main(x: Tensor(("n", 128), "f32"), w: Tensor((128, 256), "f32")):

n = sym_var()

with dataflow():
lv0: Tensor((n, 256), "f32") = call_tir(

mm, [x, w], Tensor((n, 256),"f32")

)

lv1: Tensor((n, 256), "f32") = rms_norm(lv0)

lv1: Tensor((n, 256), "f32") = call_dps_library(

"cutlass.rms_norm", [lv0], Tensor((n, 256), "f32")

)

...

@tensorir_function

def mm(X: Buffer(("n", 128), "f32"),

W: Buffer((128, 256), "f32"),

Y: Buffer(("n", 256), "f32")):

n = sym_var()

for i, j, k in grid(n, 256, 128):

with block():

with init():

Y[i, j] = 0

Y[i, j] += X[i, k] * W[k, j]

I. Replacement from composable partial lowering passes

TensorIR function body

Program analysis and transformations

Optimized function body

II. Loop-level TensorIR optimization

update

29 of 56

Before cross-level workspace lifting

Lift allocation

to graph level

def main(x: Tensor(("n", 2048), "f32"),

w: Tensor((2048, 4096), "f32")):

n = sym_var()
lv0: Tensor((n, 4096), "f32") = call_tir(

mm_split_k, [x, w], Tensor((n, 4096),"f32")

)

return lv0

@tensorir_function

def mm_split_k(X: Buffer(("n", 2048), "f32"),

W: Buffer((2048, 4096), "f32"),

Y: Buffer(("n", 4096), "f32")):

n = sym_var()

workspace = alloc_buffer((8*1024*1024,), "f32", "global")

for i, j, k0, k1 in grid(n, 4096, ..., ...):

# Write partial accumulation of X*W into workspace

for i, j, k0 in grid(n, 4096, ...):

# Accumulate values in workspace and write to Y

After cross-level workspace lifting

def main(x: Tensor(("n", 2048), "f32"),

w: Tensor((2048, 4096), "f32")):

n = sym_var()

workspace = alloc_tensor((8*1024*1024,), "f32")
lv0: Tensor((n, 4096), "f32") = call_tir(

mm_split_k, [x, w, workspace],

Tensor((n, 4096),"f32")

)

return lv0

@tensorir_function

def mm_split_k(X: Buffer(("n", 2048), "f32"),

W: Buffer((2048, 4096), "f32"),

workspace: Buffer((8*1024*1024,), "f32"),

Y: Buffer(("n", 4096), "f32")):

n = sym_var()

for i, j, k0, k1 in grid(n, 4096, ..., ...):

# Write partial accumulation of X*W into workspace

for i, j, k0 in grid(n, 4096, ...):

# Accumulate values in workspace and write to Y

30 of 56

def optim_and_lowering_pipeline(mod, target):

passes = SequentialPass(

PartialLibraryLowering(target), # §4.5

LowerOperatorToTensorProgram(), # §4.6

CrossLevelDynShapeAwareOperatorFusion(), # §4.1

TensorProgramOptim(), # §4.5

CrossLevelTensorProgramWorkspaceLift(), # §4.3

DynShapeAwareMemoryPlanning(), # §4.2

CUDAGraphOffloading(target), # §4.4

BuildRunnableModule(target), # §4.6

)

return passes(mod)

31 of 56

Partial Library Lowering (§4.5)

Operator to Tensor Program Lowering (§4.6)

Dynamic Shape–Aware Operator Fusion (§4.1)

Tensor Program Optimizations (§4.5)

Tensor Program Workspace Lift (§4.3)

Dynamic Shape–Aware Memory Planning (§4.2)

CUDA Graph Offloading (§4.4)

Build to Runnable Module (§4.6)

Cross-level optims

Graph-level optims

Tensor program optims

Lowering passes

32 of 56

@tvm.script.ir_module

class MyIRModule:

# This is a TIR PrimFunc which calls the TIR intrinsic T.exp.

@T.prim_func

def tir_exp_func(x: T.handle, y: T.handle): ## <= D2

X = T.match_buffer(x, (n,), "float32")

Y = T.match_buffer(y, (n,), "float32")

with T.grid(n) as i:

Y[i] = T.exp(X[i])

# This is a Relax function which contains a dataflow block

# representing a computational graph, as well as a call to an

# opaque packed function which performs an in-place update to the

# data in the variable gv0.

@R.function

def relax_func(x: R.Tensor[(n, k), "float32"], w: R.Tensor[_, "float32"]):

# n, k above are implicitly defined within the function signature

# so we will be able to refer to n, k within all of relax_func

with R.dataflow(): ## <= D2

lv0 = R.match_shape(w, (k, m)) ## <= D1

lv1: R.Tensor[(n, m), "float32"] = R.dot(x, lv0)

lv2: R.Tensor[(n * m,), "float32"] = R.flatten(lv1) ## <= D1

lv3: R.Shape = (n * m,) ## <= D1

gv0 = R.call_tir(tir_exp_func, [lv2], lv3, dtype="float32") ## <= D0

R.outputs(gv0)

33 of 56

Before | In-scope variables | Scope | After

Annotation | Examples | Explanation
Object | Object | Any runtime value
Shape | Shape([n, 4]); Shape(ndim=2) | Symbolic shape value (n, 4); shape with two elements
Tensor | Tensor((n, 4), "f32"); Tensor(ndim=None, dtype="f32") | Tensor with symbolic shape (n, 4); tensor with unknown dimensions
Tuple | Tuple[Tensor((n, 4), "f32"), Object] | Tuple of a Tensor and an Object
Callable | Callable([Tensor(("n", 4), "f32")], Tensor(("n * 4",), "f32")) | Function that takes an (n, 4) Tensor and returns an (n * 4,) Tensor

34 of 56

def main(a: Tensor(("n", 128), "f32"),

b: Tensor((128, 256), "f32")):

n = sym_var(upper_bound=64)

...

storage_a = alloc_storage((64, 128), "f32")

storage_b = alloc_storage((128, 256), "f32")

storage_c = alloc_storage((64, 256), "f32")

c: Tensor((n, 256), "f32") = builtin.cuda_graph_run_or_capture(

subgraph_matmul_gelu, a, b, storage_c

)

return c

def main(input: Tensor(("n",), "f32")):

n = sym_var(upper_bound=64)

...

storage_a = alloc_storage((64, 128), "f32")

storage_b = alloc_storage((128, 256), "f32")

storage_c = alloc_storage((64, 256), "f32")

a: Tensor((n, 128), "f32") = alloc_tensor(storage_a, (n, 128))

b: Tensor((128, 256), "f32") = alloc_tensor(storage_b, (128, 256))

c: Tensor((n, 256), "f32") = subgraph_matmul_gelu(a, b, storage_c)

return c

# computation subgraph with statically allocated memory

def subgraph_matmul_gelu(a: Tensor(("n", 128), "f32"),

b: Tensor((128, 256), "f32"),

storage_c: Object):

n = sym_var(upper_bound=64)

c: Tensor((n, 256), "f32") = alloc_tensor(storage_c, (n, 256))

func_mm_gelu(a, b, c) # invoking low-level DPS tensor program

return c

Before CUDA Graph Offloading

After CUDA Graph Offloading

rewrite static subgraph calls to builtin CUDA Graph functions

35 of 56

def main(a: Tensor(("n", 64), "f32"),

b: Tensor((64, 64), "f32")):

n = sym_var(upper_bound=32)

...

storage_a = alloc_storage((32, 64), "f32")

storage_b = alloc_storage((64, 64), "f32")

a: Tensor((n, 64), "f32") = alloc_tensor(storage_a, (n, 64))

b: Tensor((64, 64), "f32") = alloc_tensor(storage_b, (64, 64))

...

storage_c = alloc_storage((32, 64), "f32")

storage_d = alloc_storage((32, 64), "f32")

storage_e = alloc_storage((32, 64), "f32")

e: Tensor((n, 64), "f32") = builtin.cuda_graph_run_or_capture(

subgraph, a, b, storage_c, storage_d, storage_e

)

return e

# computation subgraph with statically allocated memory

def subgraph(a: Tensor(("n", 64), "f32"), b: Tensor((64, 64), "f32"),

storage_c: Object, storage_d: Object, storage_e: Object):

n = sym_var(upper_bound=32)

c: Tensor((n, 64), "f32") = alloc_tensor(storage_c, (n, 64))

func_matmul(a, b, c) # low-level DPS matmul tensor program

d: Tensor((n, 64), "f32") = alloc_tensor(storage_d, (n, 64))

func_gelu(c, d) # low-level DPS GeLU tensor program

e: Tensor((n, 64), "f32") = alloc_tensor(storage_e, (n, 64))

func_add(a, d, e) # low-level DPS add tensor program

return e

def main(input: Tensor(("n",), "f32")):

n = sym_var(upper_bound=32)

...

a: Tensor((n, 64), "f32") = …

b: Tensor((64, 64), "f32") = …

c: Tensor((n, 64), "f32") = matmul(a, b)

d: Tensor((n, 64), "f32") = gelu(c)

e: Tensor((n, 64), "f32") = add(a, d)

return e

Before CUDA Graph Offloading

After CUDA Graph Offloading

36 of 56

AI workloads have dynamic patterns

t = 0
t = 1
t = 2

Transformer Model

37 of 56

Figure: three kinds of cross-level interaction over a dataflow chain lv0 -> lv1 -> lv2.

Partial lowering: high-level operators are selectively rewritten into call_tir and call_dps_library nodes.

Analysis feedback: inspecting the tensor program
for i:
B[i] = max(A[i], 0)
shows that lv1 is an element-wise operator and is invariant to scaling; this reduces the cost of annotating such properties per op and generalizes support to more custom ops.

Cross-level transform: a temp workspace allocation in a tensor program (ws = alloc() ...) is lifted to a graph-level allocation and passed back in as a parameter (ws = param[2] ...).

38 of 56

@tensorir_function

def mm(X: Buffer(("n", 512), "f32"), W: Buffer((512, 1536), "f32"),

Y: Buffer(("n", 1536), "f32")):

n = sym_var()

for i, j, k in grid(n, 1536, 512):

with block():

with init():

Y[i, j] = 0

Y[i, j] += X[i, k] * W[k, j]

def main(x: Tensor(("n", 512), "f32"), w: Tensor((512, 1536), "f32")):

n = sym_var()

with dataflow():
lv0: Tensor((n, 1536), "f32") = call_tir(mm, [x, w], Tensor((n, 1536),"f32"))

lv1: Tensor((n, 3, 8, 64), "f32") = reshape(lv0, shape(n, 3, 8, 64))

lv2: Tensor((n, 8, 64), "f32") = call_dps_library(

"cutlass.attention", [lv1], Tensor((n, 8, 64), "f32")

)

ML models

Target Code

ML models

Our Approach

First-class symbolic shape annotations to enable dynamic shape-aware analysis and optimizations

Cross-level abstraction that encapsulates high-level computation, tensor programs, libraries, and their interactions

Computational Graph IR

Tensor Program IR

Target Code

Multi-level ML Compilers

Multi-level abstractions with optimizations at each level

single-shot lowering

other optional layers of IRs

Libraries

Composable optimizations across dynamic shape-aware computational graph, tensor programs, and libraries

39 of 56

nn.Module

Relax passes

Dlight, FlashInfer, TIR

Runtime

Disco

LLM

image

audio

...

Serve

import/compose

transform

tensor-level

runtime

40 of 56

def mm(

A: Buffer((64, 64), "f32"),

B: Buffer((64, 64), "f32"),

C: Buffer((64, 64), "f32")

):

for x, y, k in grid(64, 64, 64):

with block():

with init():

C[x, y] = 0

C[x, y] += A[x, k] * B[k, y]

Tensor Program

Libraries

cublasGemm(DLTensor* A,
DLTensor* B,

DLTensor* C,

float alpha,

float beta)

cutlass_attention(DLTensor* in,
DLTensor* out)

matmul

reshape

relu

attention

w

x

Computational Graph

Specialized

Hardware Primitives

41 of 56

Cross-level computational graph abstraction

  • Combines compute graph and tensor programs
  • Tensor programs can be further refined, and we will discuss in T2
  • Cross-level optimizations
    • Feedback
    • Automatic annotation, fusion
    • Partial lowering (see the sketch after this list)
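
A minimal sketch of the partial-lowering idea on a toy list-of-dicts graph (the real pass rewrites Relax IR; the operator and library names follow the earlier slides, everything else is illustrative):

def partial_library_lowering(ops, offload=None):
    # Replace selected high-level operators with explicit library calls and
    # leave everything else for later lowering passes.
    offload = offload or {"rms_norm": "cutlass.rms_norm"}
    lowered = []
    for op in ops:
        if op["op"] in offload:
            lowered.append({"op": "call_dps_library", "func": offload[op["op"]],
                            "args": op["args"], "out": op["out"]})
        else:
            lowered.append(op)
    return lowered

graph = [{"op": "matmul", "args": ["x", "w"], "out": "lv0"},
         {"op": "rms_norm", "args": ["lv0"], "out": "lv1"}]
print(partial_library_lowering(graph))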

42 of 56

@tensorir_function

def mm(X: Buffer((1, 512), "f32"), W: Buffer((512, 1536), "f32"),

Y: Buffer((1, 1536), "f32")):

for i, j, k in grid(1, 1536, 512):

with block():

with init():

Y[i, j] = 0

Y[i, j] += X[i, k] * W[k, j]

def main(x: Tensor((1, 512), "f32"), w: Tensor((512, 1536), "f32")):

with dataflow():
lv0: Tensor((1, 1536), "f32") = call_tir(mm, [x, w], Tensor((1, 1536),"f32"))

lv1: Tensor((1, 3, 8, 64), "f32") = reshape(lv0, shape(1, 3, 8, 64))

lv2: Tensor((1, 3, 8, 64), "f32") = relu(lv1)

lv3: Tensor((1, 8, 64), "f32") = call_dps_library(

"cutlass.attention", [lv2], Tensor((1, 8, 64), "f32")

) ...

Tensor program

High-level computational graph

call_tir

reshape

relu

call_dps_library

w

x

Code

View

Compute Graph View

Tensor program as a compute graph node

Library as a compute graph node

43 of 56

def call_tir(tir_func, args, out_ann):

# Allocate output tensor

output = alloc_tensor(out_ann.shape, out_ann.dtype)

# Call low-level function in destination-passing style

tir_func(*args, output)

return output

44 of 56

def main(x: Tensor((8, 128), "f32"),

w: Tensor((128, 256), "f32")):

with dataflow():

lv1: Tensor((8, 256), "f32") = call_tir(

fused_mm_relu, [x, w], Tensor((8, 256),"f32")

)
return lv1

@tensorir_function

def fused_mm_relu(X: Buffer((8, 128), "f32"),

W: Buffer((128, 256), "f32"),

Z: Buffer((8, 256), "f32")):

Y = alloc_buffer((8, 256), "f32")

# matmul

for i, j, k in grid(8, 256, 128):

with block():

with init():

Y[i, j] = 0

Y[i, j] += X[i, k] * W[k, j]

# relu

for i, j in grid(8, 256):

with block():

Z[i, j] = max(Y[i, j], float32(0))

def main(x: Tensor((8, 128), "f32"),

w: Tensor((128, 256), "f32")):

with dataflow():

lv1: Tensor((8, 256), "f32") = fused_mm_relu(x, w)
return lv1

def fused_mm_relu(x: Tensor((8, 128), "f32"),

w: Tensor((128, 256), "f32")):

with dataflow():
lv0: Tensor((8, 256), "f32") = call_tir(

mm, [x, w], Tensor((8, 256),"f32")

)

lv1: Tensor((8, 256), "f32") = relu(lv0)

return lv1

@tensorir_function

def mm(X: Buffer((8, 128), "f32"),

W: Buffer((128, 256), "f32"),

Y: Buffer((8, 256), "f32")):

func_attr("compute_pattern", "OutputEwiseFusible")

...

@tensorir_function

def relu(X: Buffer((8, 256), "f32"),

Y: Buffer((8, 256), "f32")):

func_attr("compute_pattern", "ElementWise")

...

Compute pattern analysis

+ FuseOps

+ FuseTensorIR

Initial Program

Follow-up scheduling

of TensorIR

def main(x: Tensor((8, 128), "f32"),

w: Tensor((128, 256), "f32")):

with dataflow():
lv0: Tensor((8, 256), "f32") = call_tir(

mm, [x, w], Tensor((8, 256),"f32")

)

lv1: Tensor((8, 256), "f32") = relu(lv0)

return lv1

@tensorir_function

def mm(X: Buffer((8, 128), "f32"),

W: Buffer((128, 256), "f32"),

Y: Buffer((8, 256), "f32")):

for i, j, k in grid(8, 256, 128):

with block():

with init():

Y[i, j] = 0

Y[i, j] += X[i, k] * W[k, j]

@tensorir_function

def relu(X: Buffer((8, 256), "f32"),

Y: Buffer((8, 256), "f32")):

for i, j in grid(8, 256):

with block():

Y[i, j] = max(X[i, j], float32(0))

45 of 56

Figure: three kinds of cross-level interaction over a dataflow chain lv0 -> lv1 -> lv2.

Partial lowering: high-level operators are selectively rewritten into call_tir and call_dps_library nodes.

Analysis feedback: inspecting the tensor program
for i:
B[i] = max(A[i], 0)
shows that lv1 is an element-wise operator and is invariant to scaling; this reduces the cost of annotating such properties per op and generalizes support to more custom ops.

Cross-level transform: a temp workspace allocation in a tensor program (ws = alloc() ...) is lifted to a graph-level allocation and passed back in as a parameter (ws = param[2] ...).

46 of 56

Composable Tensor Data Encoding

  • Atom mixed precision
  • Leverage cross-level abstraction to represent decoding logic
  • LoRA
  • Ensure consistency of encoding across layers

47 of 56

@tensorir_function

def decode_q4(

Wdata: Buffer((128, 16), "u32"),

Wscale: Buffer((128, 4), "f16"),

W: Buffer((128, 128), "f16")

):

for x, y in grid(128, 128):

W[x, y] = (
((Wdata[x, y // 8] >> (y % 8 * 4)) & 15) - 7
) * Wscale[x, y // 32]

def main(

x: Tensor((1, 128), "f16"),

Wdata: Tensor((128, 16), "u32"),

Wscale: Tensor((128, 4), "f16")

):

with dataflow():

W0: Tensor((128, 128), "f16") = call_tir(
decode_q4, [Wdata, Wscale],
Tensor((128, 128), "f16")

)

lv0: Tensor((1, 128), "f16") = matmul(x, W0)

...

@tensorir_function

def fused_decode_q4_mm(

X: Buffer((1, 128), "f16"),

Wdata: Buffer((128, 16), "u32"),

Wscale: Buffer((128, 4), "f16"),

Y: Buffer((1, 128), "f16")

):

W = alloc_buffer((128, 128), "f16")

for x, y in grid(128, 128):

W[x, y] = (
((Wdata[x, y // 8] >> (y % 8 * 4)) & 15) - 7
) * Wscale[x, y // 32]

for i, j, k in grid(1, 128, 128):

with block():

with init():

Y[i, j] = 0

Y[i, j] += X[i, k] * W[k, j]

def main(

x: Tensor((1, 128), "f16"),

Wdata: Tensor((128, 16), "u32"),

Wscale: Tensor((128, 4), "f16")

):

with dataflow():

lv0: Tensor((1, 128), "f16") = call_tir(
fused_decode_q4_mm, [x, Wdata, Wscale],

Tensor((1, 128), "f16")

)

...

Intermediate lowering step: explicitly unpack the decode, then further fuse matmul and dequantize

Enable customization

48 of 56

49 of 56

Distributed Annotation

  • Each tensor annotation carries its distributed data sharding
  • Communication can be represented within both TensorIR and the compute graph

50 of 56

@tensorir_function

def mm(A: Buffer((128, 128), "f32"), B: Buffer((128, 128), "f32"),

C: Buffer((128, 128), "f32")):

C_partial = alloc_buffer((2, 2, 64, 128), "f32")

for mesh_x in shard(2, "mesh[0]", mesh_dim=0):

for mesh_y in shard(2, "mesh[0]", mesh_dim=1):

for i, j, k in grid(64, 128, 64):

with block():

with init():

C_partial[mesh_x, mesh_y, i, j] = 0

C_partial[mesh_x, mesh_y, i, j] +=

A[mesh_x * 64 + i, mesh_y * 64 + k] * B[mesh_y * 64 + k, j]

for i, j in grid(64, 128):

with block("allreduce"):

with init():

C[mesh_x * 64 + i, j] = 0

C[mesh_x * 64 + i, j] += C_partial[mesh_x, mesh_y, i, j]

def main(

x: DTensor((128, 128), "f32", "mesh[0]", "S[0], S[1]"),

w: DTensor((128, 128), "f32", "mesh[0]", "S[1], S[0]"),

):

y: DTensor((128, 128), "f32", "mesh[0]", "R, S[0]") = redistribute(

w, "mesh[0]", "R, S[1]")

z: DTensor((128, 128), "f32", "mesh[0]", "S[0], R") = call_tir(

mm, [x, y], DTensor((128, 128), "f32", "mesh[0]", "S[0], R"))

shard data axis 0 along mesh axis 0; replicate over mesh axis 1

Same color indicates data stored on the same device

Code View

Distributed Compute Graph View: x (placement "S[0], S[1]") and w (placement "S[1], S[0]") feed a redistribute node (communication) that produces y with placement "R, S[0]"; a call_tir(mm) node (computation) then produces the "S[0], R" result, with the tensor program performing the matmul followed by an allreduce.

51 of 56

Tensor Program Abstraction for Heterogeneous Hardware

  • Multi-level programming of GPU and accelerators
  • Automated scheduling through probabilistic transformations (see the sketch after this list)
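
A minimal sketch of one probabilistic scheduling choice, sampling a tile size and splitting a loop into an outer/inner pair; the function and candidate set are illustrative, not the actual MetaSchedule API:

import random

def sample_and_split(loop_extent, candidates=(8, 16, 32, 64)):
    # Randomly pick a valid tile size, then split the loop extent.
    tile = random.choice([c for c in candidates if loop_extent % c == 0])
    return loop_extent // tile, tile  # (outer extent, inner extent)

print(sample_and_split(256))  # e.g. (8, 32): loop i -> i_outer, i_inner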

52 of 56

Multi-level programming of GPU and accelerators

53 of 56

Tensor Block Abstraction

54 of 56

End to end system integration

55 of 56

Timeline

56 of 56

2025/1

2026/1

2027/1

2028/1

2029/1

2030/1

Education and Outreach Activities

TensorIR Improvements and Initial AutoScheduling Library

Final System Integration

AutoScheduling Library for Composable Tensor Data Encoding

Composable Tensor Data Encoding

Attention Optimization using Tensor Programs

Fully Distributed and Heterogeneous Optimizations

End to End Integration as Validation Point of Proposed Research

Joint Optimization of Computational Graph and Tensor Program