Relax Figures
def symbolic_shape_fn(x: Tensor(("n", 2, 2), "f32")):
n, m = sym_var(), sym_var()
lv0: Tensor((n, 4), "f32") = reshape(x, shape(n, 4))
lv1: Tensor((n * 4,), "f32") = flatten(lv0)
lv2: Tensor(ndim=1, dtype="f32") = unique(lv1)
lv3 = match_cast(lv2, Tensor((m,), "f32"))
lv4: Tensor((m,), "f32") = exp(lv3)
return lv4
def any_shape_fn(x: Tensor((?, 2, 2), "f32")):
n = get_shape_value(x, axis=0)
lv0: Tensor((?, 4), "f32") = reshape(x, (n, 4))
lv1: Tensor((?,), "f32") = flatten(lv0)
lv2: Tensor(?, "f32") = unique(lv1)
lv3: Tensor(?, "f32") = exp(lv2)
return lv3
Shape annotation with unknown ? dimensions
Proposed approach: first-class inter-operator symbolic shapes
def symbolic_shape_fn(x: Tensor(("n", 2, 2), "f32")):
n, m = sym_var(), sym_var()
lv0: Tensor((n, 4), "f32") = reshape(x, shape(n, 4))
lv1: Tensor((n * 4,), "f32") = flatten(lv0)
lv2: Tensor(ndim=1, dtype="f32") = unique(lv1)
lv3 = match_cast(lv2, Tensor((m,), "f32"))
lv4: Tensor((m,), "f32") = exp(lv3)
return lv4
def any_shape_fn(x: Tensor((?, 2, 2), "f32")):
n = get_shape_value(x, axis=0)
lv0: Tensor((?, 4), "f32") = reshape(x, (n, 4))
lv1: Tensor((?,), "f32") = flatten(lv0)
lv2: Tensor(?, "f32") = unique(lv1)
lv3: Tensor(?, "f32") = exp(lv2)
return lv3
Shape annotation with unknown ? dimensions
First-class symbolic shape annotation
# Graph-level end-to-end dynamic ML model
def main(x: Tensor(("n", 128), "f32"), w: Tensor((128, 256), "f32")):
n = sym_var()
with dataflow():
lv0: Tensor((n, 256), "f32") = call_tir(
mm, [x, w], Tensor((n, 256),"f32")
)
lv1: Tensor((n, 256), "f32") = relu(lv0)
lv2: Tensor((n, 256), "f32") = call_dps_library(
"cutlass.rms_norm", [lv1], Tensor((n, 256), "f32")
)
…
# Loop-level tensor programs
@tensorir_function
def mm(X: Buffer(("n", 128), "f32"), W: Buffer((128, 256), "f32"),
Y: Buffer(("n", 256), "f32")):
n = sym_var()
for i, j, k in grid(n, 256, 128):
with block():
with init():
Y[i, j] = 0
Y[i, j] += X[i, k] * W[k, j]
call library functions from graph level
call loop-level tensor programs from graph level
call library functions from graph level
# Graph-level end-to-end dynamic ML model
def graph_mlp_func(
x: Tensor(("n", 4096), "f16"),
w1: Tensor((4096, 11008), "f16"),
w2: Tensor((11008, 4096), "f16"),
):
n = sym_var()
lv0 = call_tir(func=matmul0, args=[x, w1],
out_shape=(n, 11008), sym_args=[n])
lv1: Tensor((n, 11008), "f16") = silu(lv0)
lv2: Tensor((n, 11008), "f16") = add(lv0, lv1)
lv3 = call_tir(func=matmul1, args=[lv2, w2],
out_shape=(n, 4096), sym_args=[n])
lv4 = call_dps_library(
"cutlass.rms_norm", args=[lv3],
out_shape=(n, 4096))
return lv4
# Loop-level tensor programs
def matmul0(
A: Tensor(("n", 4096), "f16"),
B: Tensor((4096, 11008), "f16"),
out: Tensor(("n", 11008), "f16"),
n: int,
):
# Computes "matmul(A, B)" and writes results
# to "out"
...
def matmul1(...):
...
call loop-level tensor programs from graph level
# Graph-level end-to-end dynamic ML model
def main(x: Tensor(("n", 128), "f32"), w: Tensor((128, 256), "f32")):
n = sym_var()
with dataflow():
lv0: Tensor((n, 256), "f32") = call_tir(
mm, [x, w], Tensor((n, 256),"f32")
)
lv1: Tensor((n, 256), "f32") = relu(lv0)
lv2: Tensor((n, 256), "f32") = call_dps_library(
"cutlass.rms_norm", [lv1], Tensor((n, 256), "f32")
)
…
# Loop-level tensor programs
@tensorir_function
def mm(X: Buffer(("n", 128), "f32"), W: Buffer((128, 256), "f32"),
Y: Buffer(("n", 256), "f32")):
n = sym_var()
for i, j, k in grid(n, 256, 128):
with block():
with init():
Y[i, j] = 0
Y[i, j] += X[i, k] * W[k, j]
def call_tir(tir_func, args, annotation, sym_args):
# Allocate output tensor
output = alloc_tensor(annotation.shape, annotation.dtype)
# Call low-level function in destination-passing style
tir_func(*args, output, *sym_args)
return output
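For illustration, here is a hedged sketch of how a call_tir site such as the one in graph_mlp_func above desugars under this definition (a conceptual expansion, not the exact compiler lowering):

# lv0 = call_tir(func=matmul0, args=[x, w1], out_shape=(n, 11008), sym_args=[n])
# conceptually becomes:
lv0 = alloc_tensor((n, 11008), "f16")   # allocate the destination tensor
matmul0(x, w1, lv0, n)                  # destination-passing call: inputs, output, symbolic args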
Before memory planning:
x: Tensor((2, n), "f32")
lv0 = exp(x)          # allocate (2, n) for lv0
lv1 = transpose(lv0)  # allocate (n, 2) for lv1
lv2 = relu(lv1)       # allocate (n, 2) for lv2
lv3 = transpose(lv2)  # allocate (2, n) for lv3

After memory planning:
x: Tensor((2, n), "f32")
# allocate 2*n*4 bytes for s0 and 2*n*4 bytes for s1
lv0 = exp(x)          # instantiate lv0: (2, n) from s0
lv1 = transpose(lv0)  # instantiate lv1: (n, 2) from s1
lv2 = relu(lv1)       # instantiate lv2: (n, 2) from s0
lv3 = transpose(lv2)  # instantiate lv3: (2, n) from s1
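A hedged sketch of the resulting program in the figures' pseudocode, reusing the alloc_storage / alloc_tensor builtins that appear in the CUDA Graph figures later in this document (the exact storage-allocation signatures here are assumptions):

def main_after_planning(x: Tensor((2, "n"), "f32")):
    n = sym_var()
    # two shared storages, each large enough for one (2, n) f32 tensor (2*n*4 bytes)
    s0 = alloc_storage((2, n), "f32")
    s1 = alloc_storage((2, n), "f32")
    lv0: Tensor((2, n), "f32") = alloc_tensor(s0, (2, n))   # lv0 instantiated from s0
    exp(x, lv0)                                             # low-level DPS calls
    lv1: Tensor((n, 2), "f32") = alloc_tensor(s1, (n, 2))   # lv1 instantiated from s1
    transpose(lv0, lv1)
    lv2: Tensor((n, 2), "f32") = alloc_tensor(s0, (n, 2))   # lv2 reuses s0 (lv0 is dead)
    relu(lv1, lv2)
    lv3: Tensor((2, n), "f32") = alloc_tensor(s1, (2, n))   # lv3 reuses s1 (lv1 is dead)
    transpose(lv2, lv3)
    ...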
def main(x: Tensor(("n", 128), "f32"),
w: Tensor((128, 256), "f32")):
n = sym_var()
with dataflow():
lv1: Tensor((n, 256), "f32") = call_tir(
fused_mm_relu, [x, w], Tensor((n, 256),"f32")
)
return lv1
@tensorir_function
def fused_mm_relu(X: Buffer(("n", 128), "f32"),
W: Buffer((128, 256), "f32"),
Z: Buffer(("n", 256), "f32")):
n = sym_var()
Y = alloc_buffer((n, 256), "f32")
# matmul
for i, j, k in grid(n, 256, 128):
with block():
with init():
Y[i, j] = 0
Y[i, j] += X[i, k] * W[k, j]
# relu
for i, j in grid(n, 256):
with block():
Z[i, j] = max(Y[i, j], float32(0))
def main(x: Tensor(("n", 128), "f32"),
w: Tensor((128, 256), "f32")):
n = sym_var()
with dataflow():
lv1: Tensor((n, 256), "f32") = fused_mm_relu(x, w)
return lv1
def fused_mm_relu(x: Tensor(("n", 128), "f32"),
w: Tensor((128, 256), "f32")):
n = sym_var()
with dataflow():
lv0: Tensor((n, 256), "f32") = call_tir(
mm, [x, w], Tensor((n, 256),"f32")
)
lv1: Tensor((n, 256), "f32") = relu(lv0)
return lv1
@tensorir_function
def mm(X: Buffer(("n", 128), "f32"),
W: Buffer((128, 256), "f32"),
Y: Buffer(("n", 256), "f32")):
func_attr("compute_pattern", "OutputEwiseFusible")
...
@tensorir_function
def relu(X: Buffer(("n", 256), "f32"),
Y: Buffer(("n", 256), "f32")):
func_attr("compute_pattern", "ElementWise")
...
Compute pattern analysis
+ FuseOps
+ FuseTensorIR
Initial Program
Follow-up scheduling
of TensorIR
def main(x: Tensor(("n", 128), "f32"),
w: Tensor((128, 256), "f32")):
n = sym_var()
with dataflow():
lv0: Tensor((n, 256), "f32") = call_tir(
mm, [x, w], Tensor((n, 256),"f32")
)
lv1: Tensor((n, 256), "f32") = relu(lv0)
return lv1
@tensorir_function
def mm(X: Buffer(("n", 128), "f32"),
W: Buffer((128, 256), "f32"),
Y: Buffer(("n", 256), "f32")):
n = sym_var()
for i, j, k in grid(n, 256, 128):
with block():
with init():
Y[i, j] = 0
Y[i, j] += X[i, k] * W[k, j]
@tensorir_function
def relu(X: Buffer(("n", 256), "f32"),
Y: Buffer(("n", 256), "f32")):
n = sym_var()
for i, j in grid(n, 256):
with block():
Y[i, j] = max(X[i, j], float32(0))
def main(
x: Tensor(("n", 128), "f16"),
Wdata: Tensor((128, 32), "u32"),
Wscale: Tensor((128, 8), "f16"),
):
n = sym_var()
with dataflow():
lv0: Tensor((n,256),"f16") = call_tir(
fused_decode_q4_mm,
[x, Wdata, Wscale],
Tensor((n, 256), "f16")
)
return lv0
@tensorir_function
def fused_decode_q4_mm(
X: Buffer(("n", 128), "f16"),
Wdata: Buffer((128, 32), "u32"),
Wscale: Buffer((128, 8), "f16"),
Y: Buffer(("n", 256), "f16"),
):
n = sym_var()
W = alloc_buffer((128, 256), "f16")
# decode_q4
for k, j in grid(128, 256):
W[k, j] = (
((Wdata[k, j // 8] >> (j % 8 * 4)) & 15) - 7
) * Wscale[k, j // 32]
# matmul
for i, j, k in grid(n, 256, 128):
if k == 0:
Y[i, j] = 0.0
Y[i, j] += X[i, k] * W[k, j]
Follow-up scheduling
of TensorIR
def main(x: Tensor(("n", 128), "f16"),
Wdata: Tensor((128, 32), "u32"),
Wscale: Tensor((128, 8), "f16")):
n = sym_var()
with dataflow():
W: Tensor((128, 256), "f16") = call_tir(
decode_q4,
[Wdata, Wscale],
Tensor((128, 256), "f16")
)
lv0: Tensor((n, 256), "f16") = call_tir(
mm, [x, W], Tensor((n, 256), "f16")
)
return lv0
@tensorir_function
def decode_q4(
Wdata: Buffer((128, 32), "u32"),
Wscale: Buffer((128, 8), "f16"),
W: Buffer((128, 256), "f16"),
):
for k, j in grid(128, 256):
W[k, j] = (
((Wdata[k, j // 8] >> (j % 8 * 4)) & 15) - 7
) * Wscale[k, j // 32]
@tensorir_function
def mm(X: Buffer(("n", 128), "f16"),
W: Buffer((128, 256), "f16"),
Y: Buffer(("n", 256), "f16")):
n = sym_var()
for i, j, k in grid(n, 256, 128):
if k == 0:
Y[i, j] = 0.0
Y[i, j] += X[i, k] * W[k, j]
+ FuseTensorIR
Initial Program
Compute pattern analysis
+ FuseOps
def main(x: Tensor(("n", 128), "f16"),
Wdata: Tensor((128, 32), "u32"),
Wscale: Tensor((128, 8), "f16")):
n = sym_var()
with dataflow():
lv0: Tensor((n,256),"f16") = fused_decode_q4_mm(
x, Wdata, Wscale
)
return lv0
def fused_decode_q4_mm(x: Tensor(("n", 128), "f16"),
Wdata:Tensor((128,32),"u32"),
Wscale:Tensor((128,8),"f16")):
n = sym_var()
with dataflow():
W: Tensor((128, 256), "f16") = call_tir(
decode_q4, [Wdata, Wscale], Tensor((128, 256), "f16")
)
lv0: Tensor((n, 256), "f16") = call_tir(
mm, [x, W], Tensor((n, 256),"f16")
)
return lv0
@tensorir_function
def decode_q4(Wdata: Buffer((128, 32), "u32"),
Wscale: Buffer((128, 8), "f16"),
W: Buffer((128, 256), "f16")):
func_attr("compute_pattern", "Injective")
...
@tensorir_function
def mm(X: Buffer(("n", 128), "f16"),
W: Buffer((128, 256), "f16"),
Y: Buffer(("n", 256), "f16")):
func_attr("compute_pattern", "OutputEwiseFusible")
...
@tensorir_function
def decode_q4(
Wdata: Buffer((128, 16), "u32"),
Wscale: Buffer((128, 4), "f16"),
W: Buffer((128, 128), "f16")
):
for x, y in grid(128, 128):
W[x, y] = (
((Wdata[x, y // 8] >> (y % 8 * 4)) & 15) - 7
) * Wscale[x, y // 32]
def main(
x: Tensor((1, 128), "f16"),
Wdata: Tensor((128, 16), "u32"),
Wscale: Tensor((128, 4), "f16")
):
with dataflow():
W0: Tensor((128, 128), "f16") = call_tir(
decode_q4, [Wdata, Wscale],
Tensor((128, 128), "f16")
)
lv0: Tensor((1, 128), "f16") = matmul(x, W0)
...
@tensorir_function
def fused_decode_q4_mm(
X: Buffer((1, 128), "f16"),
Wdata: Buffer((128, 16), "u32"),
Wscale: Buffer((128, 4), "f16"),
Y: Buffer((1, 128), "f16")
):
W = alloc_buffer((128, 128), "f16")
for x, y in grid(128, 128):
W[x, y] = (
((Wdata[x, y // 8] >> (y % 8 * 4)) & 15) - 7
) * Wscale[x, y // 32]
for i, j, k in grid(1, 128, 128):
with block():
with init():
Y[i, j] = 0
Y[i, j] += X[i, k] * W[k, j]
def main(
x: Tensor((1, 128), "f16"),
Wdata: Tensor((128, 16), "u32"),
Wscale: Tensor((128, 4), "f16")
):
with dataflow():
lv0: Tensor((1, 128), "f16") = call_tir(
fused_decode_q4_mm, [x, Wdata, Wscale],
Tensor((1, 128), "f16")
)
...
@tensorir_function
def mm(X: Buffer(("n", 512), "f32"), W: Buffer((512, 1536), "f32"),
Y: Buffer(("n", 1536), "f32")):
n = sym_var()
for i, j, k in grid(n, 1536, 512):
with block():
with init():
Y[i, j] = 0
Y[i, j] += X[i, k] * W[k, j]
def main(x: Tensor(("n", 512), "f32"), w: Tensor((512, 1536), "f32")):
n = sym_var()
with dataflow():
lv0: Tensor((n, 1536), "f32") = call_tir(mm, [x, w], Tensor((n, 1536), "f32"))
lv1: Tensor((n, 3, 8, 64), "f32") = reshape(lv0, shape(n, 3, 8, 64))
lv2: Tensor((n, 8, 64), "f32") = call_dps_library(
"cutlass.attention", [lv1], Tensor((n, 8, 64), "f32")
)
…
ML models
Target Code
ML models
Our Approach
First-class symbolic shape annotations to enable dynamic shape-aware analysis and optimizations
Cross-level abstraction that encapsulates high-level computations, tensor programs, libraries, and their interactions
Computational Graph IR
Tensor Program IR
Target Code
Multi-level ML Compilers
Multi-level abstractions with optimizations at each level
single-shot
lowering
other optional layers of IRs
Libraries
Composable optimizations across dynamic shape-aware computational graphs, tensor programs, and libraries
Code Generation
Cross-level abstraction that encapsulates high-level computations, tensor programs, libraries, and their interactions
@tensorir_function
def mm(X: Buffer(("n", 512), "f32"), W: Buffer((512, 1536), "f32"),
Y: Buffer(("n", 1536), "f32")):
n = sym_var()
for i, j, k in grid(n, 1536, 512):
with block():
with init():
Y[i, j] = 0
Y[i, j] += X[i, k] * W[k, j]
def main(x: Tensor(("n", 512), "f32"), w: Tensor((512, 1536), "f32")):
n = sym_var()
with dataflow():
lv0: Tensor((n, 1536), "f32") = call_tir(mm, [x, w], Tensor((n, 1536), "f32"))
lv1: Tensor((n, 3, 8, 64), "f32") = reshape(lv0, shape(n, 3, 8, 64))
lv2: Tensor((n, 8, 64), "f32") = call_dps_library(
"cutlass.attention", [lv1], Tensor((n, 8, 64), "f32")
)
…
ML models
Target Code and Deployable Module
ML models
Our Approach
First-class symbolic shape annotations to enable dynamic shape-aware analyses and optimizations
Computational Graph IR
Tensor Program IR
Target Code
Multi-level ML Compilers
Multi-level abstractions with optimizations at each level
single-shot
lowering
other optional layers of IRs
Libraries
Composable optimizations across dynamic shape-aware computational graphs, tensor programs, and libraries
…
Code Generation
Cross-level abstraction that encapsulates high-level computation graphs, tensor programs, libraries, and their interactions
@tensorir_function
def mm(X: Buffer(("n", 512), "f32"), W: Buffer((512, 1536), "f32"),
Y: Buffer(("n", 1536), "f32")):
n = sym_var()
for i, j, k in grid(n, 1536, 512):
with block():
with init():
Y[i, j] = 0
Y[i, j] += X[i, k] * W[k, j]
def main(x: Tensor(("n", 512), "f32"), w: Tensor((512, 1536), "f32")):
n = sym_var()
with dataflow():
lv0: Tensor((n, 1536), "f32") = call_tir(mm, [x, w], Tensor((n, 1536), "f32"))
lv1: Tensor((n, 3, 8, 64), "f32") = reshape(lv0, shape(n, 3, 8, 64))
lv2: Tensor((n, 8, 64), "f32") = call_dps_library(
"cutlass.attention", [lv1], Tensor((n, 8, 64), "f32")
)
…
ML models
Target Code and Deployable Module
ML models
Our Approach
First-class symbolic shape annotations to enable dynamic shape-aware analyses and optimizations
Computational Graph IR
Tensor Program IR
Target Code
Multi-level ML Compilers
Multi-level abstractions with optimizations at each level
single-shot
lowering
other optional layers of IRs
Libraries
Cross-level optimizations across computational graph, tensor program, and libraries, including:
- cross-level dynamic shape-aware operator fusion
- dynamic shape-aware memory planning
- cross-level tensor program workspace lifting
- CUDA Graph offloading
- tensor program optimizations via partial lowering
…
@tensorir_function
def mm(X: Buffer(("n", 512), "f32"), W: Buffer((512, 1536), "f32"),
Y: Buffer(("n", 1536), "f32")):
n = sym_var()
for i, j, k in grid(n, 1536, 512):
with block():
with init():
Y[i, j] = 0
Y[i, j] += X[i, k] * W[k, j]
def main(x: Tensor(("n", 512), "f32"), w: Tensor((512, 1536), "f32")):
n = sym_var()
with dataflow():
lv0: Tensor((n, 1536), "f32") = call_tir(mm, [x, w], Tensor((n, 1536), "f32"))
lv1: Tensor((n, 3, 8, 64), "f32") = reshape(lv0, shape(n, 3, 8, 64))
lv2: Tensor((n, 8, 64), "f32") = call_dps_library(
"cutlass.attention", [lv1], Tensor((n, 8, 64), "f32")
)
…
Construct and import
Universal deployment
Composable
optimizations
Graph-level optimizations
Tensor program optimizations
SLM
Synergy Ecosystem
LLM.compile
AutoLLM
Model def
Quantize
TVM Unity Abstraction Compilation
Computational graph
Platform Dependent optimization
dlight
Nv kernel/
library dispatch
Memory planning
fusion
distributed
Paged KV
MultiGPU/
Disco
Runtime
AWQ
NCCL
CUTLASS/TRT
import
Annotation | Examples | Explanation
Object | Object | Any runtime value
Shape | Shape([n, 4]); Shape(ndim=2) | Symbolic shape value (n, 4); shape with two dimensions
Tensor | Tensor((n, 4), "f32"); Tensor(ndim=None, dtype="f32") | Tensor with symbolic shape (n, 4); tensor with an unknown number of dimensions
Tuple | Tuple[Tensor((n, 4), "f32"), Object] | Tuple of a Tensor and an Object
Callable | Callable([Tensor(("n", 4), "f32")], Tensor(("n * 4",), "f32")) | Function that takes an (n, 4) Tensor and returns an (n * 4,) Tensor
Annotation
def main(x: Tensor(("n", 4), "f32")) -> Tensor(("n * 2",), "f32"):
n = sym_var()
with dataflow():
lv0: Tensor((n, 4), "f32") = call_tir(exp, [x], Tensor((n, 4), "f32"))
f0: Callable = subfunc
lv1: Tensor((n * 4,), "f32") = subfunc(lv0)
lv2: Tuple[
Tensor((n * 2,), "f32"), Tensor((n * 2,), "f32")
] = split(lv1, sections=2)
lv3: Tensor((n * 2,), "f32") = lv2[0]
return lv3
def subfunc(x: Tensor(("n", 4), "f32")) -> Tensor(("n * 4",), "f32"):
n = sym_var()
with dataflow():
lv0: Tensor((n * 4,), "f32") = flatten(x)
lv1: Tensor((n * 4,), "f32") = relu(lv0)
return lv1
@tensorir_function
def exp(X: Buffer(("n", 4), "f32"), Y: Buffer(("n", 4), "f32")):
...
Dataflow block
Cross-function call
Cross-level function call
Annotation
def main(x: Tensor(("n", 4), "f32")) -> Tensor(("n * 2",), "f32"):
n = sym_var()
with dataflow():
lv0: Tensor((n, 4), "f32") = call_tir(exp, [x], Tensor((n, 4), "f32"))
lv1: Tensor((n, 4), "f32") = call_dps_library(
"cutlass.rms_norm", [lv0], Tensor((n, 4), "f32")
)
f0: Callable = subfunc
lv2: Tensor((n * 4,), "f32") = subfunc(lv1)
lv3: Tuple[
Tensor((n * 2,), "f32"), Tensor((n * 2,), "f32")
] = split(lv2, sections=2)
lv4: Tensor((n * 2,), "f32") = lv3[0]
return lv4
def subfunc(x: Tensor(("n", 4), "f32")) -> Tensor(("n * 4",), "f32"):
n = sym_var()
with dataflow():
lv0: Tensor((n * 4,), "f32") = flatten(x)
lv1: Tensor((n * 4,), "f32") = relu(lv0)
return lv1
@tensorir_function
def exp(X: Buffer(("n", 4), "f32"), Y: Buffer(("n", 4), "f32")):
...
Dataflow block
Subgraph function call
Foreign function call
Foreign function (tensor program)
def subfn(s: Shape(["n", "m"])) -> Tensor(("n * m",), "f32"):
...
def subgraph_func_shape_deduce_example(
x: Tensor(("n",), "f32"),
y: Shape(ndim=2)
):
n = sym_var()
f0: Callable([Shape(["n", "m"])], Tensor(("n * m",), "f32")) = subfn
lv0: Tensor((n * 4,), "f32") = f0(shape(n, 4))
lv1: Tensor((12,), "f32") = subfn(shape(3, 4))
lv2: Tensor(((n + 1) * 4,), "f32") = subfn(shape(n + 1, 4))
lv3: Tensor(ndim=1, dtype="f32") = subfn(y)
...
Extra symbolic shape parameter to pass in n
Parameter annotation can
be an expression
Call into fused function
def main(x: Tensor(("n", 2), "f32")):
n = sym_var()
lv0: Tensor(("2 * n",), "f32") = flatten(x)
lv1: Tensor(("2 * n",), "f32") = add(lv0, lv0)
lv2: Tensor(("2 * n",), "f32") = relu(lv1)
...
Before fusion
After fusion
Regions to fuse
def fused_add_relu(
x: Tensor(("n * 2",), "f32"),
y: Tensor(("n * 2",), "f32"),
s: Shape(["n"])
) -> Tensor(("n * 2",), "f32"):
lv0 = add(x, y)
lv1 = relu(lv0)
return lv1
def main(x: Tensor(("n", 2), "f32")):
n = sym_var()
lv0: Tensor(("2 * n",), "f32") = flatten(x)
lv1: Tensor(("2 * n",), "f32") = fused_add_relu(lv0, lv0, shape(n))
...
def subfn(s: Shape(["n", "m"])) -> Tensor(("n * m",), "f32"):
lv0 = empty(s)
lv1 = flatten(lv0)
return lv1
def fallback_example(x: Tensor(ndim=1, dtype="f32")):
n = sym_var()
f0: Callable([Shape(["n", "m"])], Tensor(("n * m",), "f32")) = subfn
lv0: Tensor((n * 4,), "f32") = f0(shape(n, 4))
lv1: Tensor((12,), "f32") = subfn(shape(3, 4))
lv2: Tensor(((n + 1) * 4,), "f32") = subfn(shape(n + 1, 4))
...
program
program
def main(x: Tensor(("n", 128), "f32"), w: Tensor((128, 256), "f32")):
n = sym_var()
with dataflow():
lv0: Tensor((n, 256), "f32") = call_tir(
mm, [x, w], Tensor((n, 256),"f32")
)
lv1: Tensor((n, 256), "f32") = rms_norm(lv0)
lv1: Tensor((n, 256), "f32") = call_dps_library(
"cutlass.rms_norm", [lv0], Tensor((n, 256), "f32")
)
...
@tensorir_function
def mm(X: Buffer(("n", 128), "f32"),
W: Buffer((128, 256), "f32"),
Y: Buffer(("n", 256), "f32")):
n = sym_var()
for i, j, k in grid(n, 256, 128):
with block():
with init():
Y[i, j] = 0
Y[i, j] += X[i, k] * W[k, j]
I. Replacement from composable partial lowering passes
TensorIR function body
Program analysis and transformations
Optimized function body
II. Loop-level TensorIR optimization
update
def main(x: Tensor(("n", 128), "f32"), w: Tensor((128, 256), "f32")):
n = sym_var()
with dataflow():
lv0: Tensor((n, 256), "f32") = call_tir(
mm, [x, w], Tensor((n, 256),"f32")
)
lv1: Tensor((n, 256), "f32") = rms_norm(lv0)
lv1: Tensor((n, 256), "f32") = call_dps_library(
"cutlass.rms_norm", [lv0], Tensor((n, 256), "f32")
)
...
@tensorir_function
def mm(X: Buffer(("n", 128), "f32"),
W: Buffer((128, 256), "f32"),
Y: Buffer(("n", 256), "f32")):
n = sym_var()
for i, j, k in grid(n, 256, 128):
with block():
with init():
Y[i, j] = 0
Y[i, j] += X[i, k] * W[k, j]
I. Replacement from composable partial lowering passes
TensorIR function body
Program analysis and transformations
Optimized function body
II. Loop-level TensorIR optimization
update
Before cross-level workspace lifting
Lift allocation
to graph level
def main(x: Tensor(("n", 2048), "f32"),
w: Tensor((2048, 4096), "f32")):
n = sym_var()
lv0: Tensor((n, 4096), "f32") = call_tir(
mm_split_k, [x, w], Tensor((n, 4096),"f32")
)
return lv0
@tensorir_function
def mm_split_k(X: Buffer(("n", 2048), "f32"),
W: Buffer((2048, 4096), "f32"),
Y: Buffer(("n", 4096), "f32")):
n = sym_var()
workspace = alloc_buffer(8*1024*1024,"f32","global")
for i, j, k0, k1 in grid(n, 4096, ..., ...):
# Write partial accumulation of X*W into workspace
for i, j, k0 in grid(n, 4096, ...):
# Accumulate values in workspace and write to Y
After cross-level workspace lifting
def main(x: Tensor(("n", 2048), "f32"),
w: Tensor((2048, 4096), "f32")):
n = sym_var()
workspace = alloc_tensor((8*1024*1024,), "f32")
lv0: Tensor((n, 4096), "f32") = call_tir(
mm_split_k, [x, w, workspace],
Tensor((n, 4096),"f32")
)
return lv0
@tensorir_function
def mm_split_k(X: Buffer(("n", 2048), "f32"),
W: Buffer((2048, 4096), "f32"),
workspace: Buffer((8*1024*1024,), "f32"),
Y: Buffer(("n", 4096), "f32")):
n = sym_var()
for i, j, k0, k1 in grid(n, 4096, ..., ...):
# Write partial accumulation of X*W into workspace
for i, j, k0 in grid(n, 4096, ...):
# Accumulate values in workspace and write to Y
def optim_and_lowering_pipeline(mod, target):
passes = SequentialPass(
PartialLibraryLowering(target), # §4.5
LowerOperatorToTensorProgram(), # §4.6
CrossLevelDynShapeAwareOperatorFusion(), # §4.1
TensorProgramOptim(), # §4.5
CrossLevelTensorProgramWorkspaceLift(), # §4.3
DynShapeAwareMemoryPlanning(), # §4.2
CUDAGraphOffloading(target), # §4.4
BuildRunnableModule(target), # §4.6
)
return passes(mod)
Partial Library Lowering (§4.5)
Operator to Tensor Program Lowering (§4.6)
Dynamic Shape–Aware Operator Fusion (§4.1)
Tensor Program Optimizations (§4.5)
Tensor Program Workspace Lift (§4.3)
Dynamic Shape–Aware Memory Planning (§4.2)
CUDA Graph Offloading (§4.4)
Build to Runnable Module (§4.6)
Cross-level optims
Graph-level optims
Tensor program optims
Lowering passes
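A hedged usage sketch of the pipeline above; the load_model import step and the way the returned module is invoked are placeholders for illustration, not actual APIs:

mod = load_model(...)                                # import an ML model into a cross-level IRModule (hypothetical importer)
runnable = optim_and_lowering_pipeline(mod, target)  # apply the composable passes listed above
y = runnable["main"](x, w)                           # symbolic dims such as n are bound at run time (hypothetical calling convention)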
@tvm.script.ir_module
class MyIRModule:
# This is a TIR PrimFunc which calls the TIR intrinsic T.exp.
@T.prim_func
def tir_exp_func(x: T.handle, y: T.handle): ## <= D2
X = T.match_buffer(x, (n,), "float32")
Y = T.match_buffer(y, (n,), "float32")
with T.grid(n) as i:
Y[i] = T.exp(X[i])
# This is a Relax function which contains a dataflow block
# representing a computational graph, as well as a call to an
# opaque packed function which performs an in-place update to the
# data in the variable gv0.
@R.function
def relax_func(x: R.Tensor[(n, k), "float32"], w: R.Tensor[_, "float32"]):
# n, k above are implicitly defined within the function signature
# so we will be able to refer to n, k within all of relax_func
with R.dataflow(): ## <= D2
lv0 = R.match_shape(w, (k, m)) ## <= D1
lv1: R.Tensor[(n, m), "float32"] = R.dot(x, lv0)
lv2: R.Tensor[(n * m,), "float32"] = R.flatten(lv1) ## <= D1
lv3: R.Shape = (n * m,) ## <= D1
gv0 = R.call_tir(tir_exp_func, [lv2], lv3, dtype="float32") ## <= D0
R.outputs(gv0)
Annotation | Examples | Explanation
Object | Object | Any runtime value
Shape | Shape([n, 4]); Shape(ndim=2) | Symbolic shape value (n, 4); shape with two elements
Tensor | Tensor((n, 4), "f32"); Tensor(ndim=None, dtype="f32") | Tensor with symbolic shape (n, 4); tensor with unknown dimensions
Tuple | Tuple[Tensor((n, 4), "f32"), Object] | Tuple of a Tensor and an Object
Callable | Callable([Tensor(("n", 4), "f32")], Tensor(("n * 4",), "f32")) | Function that takes an (n, 4) Tensor and returns an (n * 4,) Tensor
def main(a: Tensor(("n", 128), "f32"),
b: Tensor((128, 256), "f32")):
n = sym_var(upper_bound=64)
...
storage_a = alloc_storage((64, 128), "f32")
storage_b = alloc_storage((128, 256), "f32")
storage_c = alloc_storage((64, 256), "f32")
c: Tensor((n, 256), "f32") = builtin.cuda_graph_run_or_capture(
subgraph_matmul_gelu, a, b, storage_c
)
return c
def main(input: Tensor(("n",), "f32")):
n = sym_var(upper_bound=64)
...
storage_a = alloc_storage((64, 128), "f32")
storage_b = alloc_storage((128, 256), "f32")
storage_c = alloc_storage((64, 256), "f32")
a: Tensor((n, 128), "f32") = alloc_tensor(storage_a, (n, 128))
b: Tensor((128, 256), "f32") = alloc_tensor(storage_b, (128, 256))
c: Tensor((n, 256), "f32") = subgraph_matmul_gelu(a, b, storage_c)
return c
# computation subgraph with statically allocated memory
def subgraph_matmul_gelu(a: Tensor(("n", 128), "f32"),
b: Tensor((128, 256), "f32"),
storage_c: Object):
n = sym_var(upper_bound=64)
c: Tensor((n, 256), "f32") = alloc_tensor(storage_c, (n, 256))
func_mm_gelu(a, b, c) # invoking low-level DPS tensor program
return c
Before CUDA Graph Offloading
After CUDA Graph Offloading
rewrite static subgraph calls to builtin CUDA Graph functions
def main(a: Tensor(("n", 64), "f32"),
b: Tensor((64, 64), "f32")):
n = sym_var(upper_bound=32)
...
storage_a = alloc_storage((32, 64), "f32")
storage_b = alloc_storage((64, 64), "f32")
a: Tensor((n, 64), "f32") = alloc_tensor(storage_a, (n, 64))
b: Tensor((64, 64), "f32") = alloc_tensor(storage_b, (64, 64))
...
storage_c = alloc_storage((32, 64), "f32")
storage_d = alloc_storage((32, 64), "f32")
storage_e = alloc_storage((32, 64), "f32")
e: Tensor((n, 64), "f32") = builtin.cuda_graph_run_or_capture(
subgraph, a, b, storage_c, storage_d, storage_e
)
return e
# computation subgraph with statically allocated memory
def subgraph(a: Tensor(("n", 64), "f32"), b: Tensor((64, 64), "f32"),
storage_c: Object, storage_d: Object, storage_e: Object):
n = sym_var(upper_bound=32)
c: Tensor((n, 64), "f32") = alloc_tensor(storage_c, (n, 64))
func_matmul(a, b, c) # low-level DPS matmul tensor program
d: Tensor((n, 64), "f32") = alloc_tensor(storage_d, (n, 64))
func_gelu(c, d) # low-level DPS GeLU tensor program
e: Tensor((n, 64), "f32") = alloc_tensor(storage_e, (n, 64))
func_add(a, d, e) # low-level DPS add tensor program
return e
def main(input: Tensor(("n",), "f32")):
n = sym_var(upper_bound=32)
...
a: Tensor((n, 64), "f32") = …
b: Tensor((64, 64), "f32") = …
c: Tensor((n, 64), "f32") = matmul(a, b)
d: Tensor((n, 64), "f32") = gelu(c)
e: Tensor((n, 64), "f32") = add(a, d)
return e
Before CUDA Graph Offloading
After CUDA Graph Offloading
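A purely illustrative sketch of what builtin.cuda_graph_run_or_capture does; the begin_capture / end_capture / replay helpers are assumptions, not the actual runtime API. Because every tensor in the subgraph is instantiated from statically allocated storage (bounded by upper_bound), replaying the recorded kernel launches on later calls is valid.

_graph_cache = {}

def cuda_graph_run_or_capture(subgraph, *args):
    # first call: run the subgraph once while recording its kernel launches
    if subgraph not in _graph_cache:
        begin_capture()                       # hypothetical: start CUDA Graph capture
        result = subgraph(*args)              # launches are recorded into the graph
        _graph_cache[subgraph] = (end_capture(), result)
        return result
    # later calls: inputs already live in the same static storages, so just replay
    graph, result = _graph_cache[subgraph]
    replay(graph)                             # hypothetical: re-issue the recorded launches
    return result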
Autoregressive decoding with a Transformer model: at each step (t = 0, 1, 2, …) the model appends one token to the sequence ("AI | workloads | have" → "… | dynamic" → "… | patterns"), so tensor shapes grow dynamically over time.
lv0
Partial lowering
Analysis feedback
Cross-level transform
call_dps_library
call_tir
high-level
operators
temp workspace allocation in
tensor program
lv1
lv2
tensor
program
for i:
B[i] = max(A[i], 0)
lv1 is an element-wise operator and is invariant to scaling. Deriving such properties by analyzing the tensor program reduces the cost of annotating them per operator and generalizes support to more custom operators.
ws = alloc()
...
ws = param[2]
...
alloc in graph-level
passed as param to tensor program
lv0
lv1
lv2
lv0
lv1
lv2
lv0
lv1
lv2
lv0
lv1
lv2
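A minimal sketch of how a graph-level pass could consume such analysis feedback, assuming a hypothetical helper get_func_attr that reads the compute_pattern attribute attached by tensor-program analysis (as in the fusion figures):

def can_fuse(producer_func, consumer_func):
    # e.g. mm ("OutputEwiseFusible") followed by relu ("ElementWise") is fusible
    return (get_func_attr(producer_func, "compute_pattern") == "OutputEwiseFusible"
            and get_func_attr(consumer_func, "compute_pattern") == "ElementWise")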
@tensorir_function
def mm(X: Buffer(("n", 512), "f32"), W: Buffer((512, 1536), "f32"),
Y: Buffer(("n", 1536), "f32")):
n = sym_var()
for i, j, k in grid(n, 1536, 512):
with block():
with init():
Y[i, j] = 0
Y[i, j] += X[i, k] * W[k, j]
def main(x: Tensor(("n", 512), "f32"), w: Tensor((512, 1536), "f32")):
n = sym_var()
with dataflow():
lv0: Tensor((n, 1536), "f32") = call_tir(mm, [x, w], Tensor((n, 1536), "f32"))
lv1: Tensor((n, 3, 8, 64), "f32") = reshape(lv0, shape(n, 3, 8, 64))
lv2: Tensor((n, 8, 64), "f32") = call_dps_library(
"cutlass.attention", [lv1], Tensor((n, 8, 64), "f32")
)
…
ML models
Target Code
ML models
Our Approach
First-class symbolic shape annotations to enable dynamic shape-aware analysis and optimizations
Cross-level abstraction that encapsulates high-level computations, tensor programs, libraries, and their interactions
Computational Graph IR
Tensor Program IR
Target Code
Multi-level ML Compilers
Multi-level abstractions with optimizations at each level
single-shot
lowering
other optional layers of IRs
Libraries
Composable optimizations across dynamic shape-aware computational graphs, tensor programs, and libraries
nn.Module
Relax passes
Dlight, FlashInfer, TIR
Runtime
Disco
LLM
image
audio
...
Serve
import/compose
transform
tensor-level
runtime
def mm(
A: Buffer((64, 64), "f32"),
B: Buffer((64, 64), "f32"),
C: Buffer((64, 64), "f32")
):
for x, y, k in grid(64, 64, 64):
with block():
with init():
C[x, y] = 0
C[x, y] += A[x, k] * B[k, y]
Tensor Program
Libraries
cublasGemm(
DLTensor* A,
DLTensor* B,
DLTensor* C,
float alpha,
float beta
)
cutlass_attention(
DLTensor* in,
DLTensor* out
)
matmul
reshape
relu
attention
w
x
Computational Graph
Specialized
Hardware Primitives
Cross-level computational graph abstraction
@tensorir_function
def mm(X: Buffer((1, 512), "f32"), W: Buffer((512, 1536), "f32"),
Y: Buffer((1, 1536), "f32")):
for i, j, k in grid(1, 1536, 512):
with block():
with init():
Y[i, j] = 0
Y[i, j] += X[i, k] * W[k, j]
def main(x: Tensor((1, 512), "f32"), w: Tensor((512, 1536), "f32")):
with dataflow():
lv0: Tensor((1, 1536), "f32") = call_tir(mm, [x, w], Tensor((1, 1536), "f32"))
lv1: Tensor((1, 3, 8, 64), "f32") = reshape(lv0, shape(1, 3, 8, 64))
lv2: Tensor((1, 8, 64), "f32") = relu(lv1)
lv3: Tensor((1, 8, 64), "f32") = call_dps_library(
"cutlass.attention", [lv2], Tensor((1, 8, 64), "f32")
) ...
Tensor program
High-level computational graph
call_tir
reshape
relu
call_dps_library
w
x
Code
View
Compute Graph View
Tensor program as a compute graph node
Library as a compute graph node
def call_tir(tir_func, args, out_ann):
# Allocate output tensor
output = alloc_tensor(out_ann.shape, out_ann.dtype)
# Call low-level function in destination-passing style
tir_func(*args, output)
return output
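call_dps_library follows the same destination-passing convention; a hedged sketch by analogy with call_tir above (the library-lookup helper is an assumption, not the actual runtime API):

def call_dps_library(func_name, args, out_ann):
    # Allocate output tensor
    output = alloc_tensor(out_ann.shape, out_ann.dtype)
    # Look up the registered library kernel (e.g. "cutlass.attention")
    lib_func = get_library_func(func_name)    # hypothetical registry lookup
    # Call it in destination-passing style
    lib_func(*args, output)
    return output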
def main(x: Tensor((8, 128), "f32"),
w: Tensor((128, 256), "f32")):
with dataflow():
lv1: Tensor((8, 256), "f32") = call_tir(
fused_mm_relu, [x, w], Tensor((8, 256),"f32")
)
return lv1
@tensorir_function
def fused_mm_relu(X: Buffer((8, 128), "f32"),
W: Buffer((128, 256), "f32"),
Z: Buffer((8, 256), "f32")):
Y = alloc_buffer((8, 256), "f32")
# matmul
for i, j, k in grid(8, 256, 128):
with block():
with init():
Y[i, j] = 0
Y[i, j] += X[i, k] * W[k, j]
# relu
for i, j in grid(8, 256):
with block():
Z[i, j] = max(Y[i, j], float32(0))
def main(x: Tensor((8, 128), "f32"),
w: Tensor((128, 256), "f32")):
with dataflow():
lv1: Tensor((8, 256), "f32") = fused_mm_relu(x, w)
return lv1
def fused_mm_relu(x: Tensor((8, 128), "f32"),
w: Tensor((128, 256), "f32")):
with dataflow():
lv0: Tensor((8, 256), "f32") = call_tir(
mm, [x, w], Tensor((8, 256),"f32")
)
lv1: Tensor((8, 256), "f32") = relu(lv0)
return lv1
@tensorir_function
def mm(X: Buffer((8, 128), "f32"),
W: Buffer((128, 256), "f32"),
Y: Buffer((8, 256), "f32")):
func_attr("compute_pattern", "OutputEwiseFusible")
...
@tensorir_function
def relu(X: Buffer((8, 256), "f32"),
Y: Buffer((8, 256), "f32")):
func_attr("compute_pattern", "ElementWise")
...
Compute pattern analysis
+ FuseOps
+ FuseTensorIR
Initial Program
Follow-up scheduling
of TensorIR
def main(x: Tensor((8, 128), "f32"),
w: Tensor((128, 256), "f32")):
with dataflow():
lv0: Tensor((8, 256), "f32") = call_tir(
mm, [x, w], Tensor((8, 256), "f32")
)
lv1: Tensor((8, 256), "f32") = relu(lv0)
return lv1
@tensorir_function
def mm(X: Buffer((8, 128), "f32"),
W: Buffer((128, 256), "f32"),
Y: Buffer((8, 256), "f32")):
for i, j, k in grid(8, 256, 128):
with block():
with init():
Y[i, j] = 0
Y[i, j] += X[i, k] * W[k, j]
@tensorir_function
def relu(X: Buffer((8, 256), "f32"),
Y: Buffer((8, 256), "f32")):
for i, j in grid(8, 256):
with block():
Y[i, j] = max(X[i, j], float32(0))
lv0
Partial lowering
Analysis feedback
Cross-level transform
call_dps_library
call_tir
high-level
operators
temp workspace allocation in
tensor program
lv1
lv2
tensor
program
for i:
B[i] = max(A[i], 0)
lv1 is an element-wise operator and is invariant to scaling. Deriving such properties by analyzing the tensor program reduces the cost of annotating them per operator and generalizes support to more custom operators.
ws = alloc()
...
ws = param[2]
...
alloc in graph-level
passed as param to tensor program
lv0
lv1
lv2
lv0
lv1
lv2
lv0
lv1
lv2
lv0
lv1
lv2
Composable Tensor Data Encoding
@tensorir_function
def decode_q4(
Wdata: Buffer((128, 16), "u32"),
Wscale: Buffer((128, 4), "f16"),
W: Buffer((128, 128), "f16")
):
for x, y in grid(128, 128):
W[x, y] = (
((Wdata[x, y // 8] >> (y % 8 * 4)) & 15) - 7
) * Wscale[x, y // 32]
def main(
x: Tensor((1, 128), "f16"),
Wdata: Tensor((128, 16), "u32"),
Wscale: Tensor((128, 4), "f16")
):
with dataflow():
W0: Tensor((128, 128), "f16") = call_tir(
decode_q4, [Wdata, Wscale],
Tensor((128, 128), "f16")
)
lv0: Tensor((1, 128), "f16") = matmul(x, W0)
...
@tensorir_function
def fused_decode_q4_mm(
X: Buffer((1, 128), "f16"),
Wdata: Buffer((128, 16), "u32"),
Wscale: Buffer((128, 4), "f16"),
Y: Buffer((1, 128), "f16")
):
W = alloc_buffer((128, 128), "f16")
for x, y in grid(128, 128):
W[x, y] = (
((Wdata[x, y // 8] >> (y % 8 * 4)) & 15) - 7
) * Wscale[x, y // 32]
for i, j, k in grid(1, 128, 128):
with block():
with init():
Y[i, j] = 0
Y[i, j] += X[i, k] * W[k, j]
def main(
x: Tensor((1, 128), "f16"),
Wdata: Tensor((128, 16), "u32"),
Wscale: Tensor((128, 4), "f16")
):
with dataflow():
lv0: Tensor((1, 128), "f16") = call_tir(
fused_decode_q4_mm, [x, Wdata, Wscale],
Tensor((1, 128), "f16")
)
...
Intermediate lowering step: explicitly unpack the decode, then further fuse the matmul and dequantization. This enables customization.
Distributed Annotation
@tensorir_function
def mm(A: Buffer((128, 128), "f32"), B: Buffer((128, 128), "f32"),
C: Buffer((128, 128), "f32")):
C_partial = alloc_buffer((2, 2, 64, 128), "f32")
for mesh_x in shard(2, "mesh[0]", mesh_dim=0):
for mesh_y in shard(2, "mesh[0]", mesh_dim=1):
for i, j, k in grid(64, 128, 64):
with block():
with init():
C_partial[mesh_x, mesh_y, i, j] = 0
C_partial[mesh_x, mesh_y, i, j] +=
A[mesh_x * 64 + i, mesh_y * 64 + k] * B[mesh_y * 64 + k, j]
for i, j in grid(64, 128):
with block("allreduce"):
with init():
C[mesh_x * 64 + i, j] = 0
C[mesh_x * 64 + i, j] += C_partial[mesh_x, mesh_y, i, j]
def main(
x: DTensor((128, 128), "f32", "mesh[0]", "S[0], S[1]"),
w: DTensor((128, 128), "f32", "mesh[0]", "S[1], S[0]"),
):
y: DTensor((128, 128), "f32", "mesh[0]", "R, S[0]") = redistribute(
w, "mesh[0]", "R, S[0]")
z: DTensor((128, 128), "f32", "mesh[0]", "S[0], R") = call_tir(
mm, [x, y], DTensor((128, 128), "f32", "mesh[0]", "S[0], R"))
shard data axis 0 along mesh axis 0; replicate over mesh axis 1
Same color indicates data stored on the same device
Code View
Distributed Compute Graph View: x (placement "S[0], S[1]") and w (placement "S[1], S[0]") are the inputs; redistribute (communication) turns w into y with placement "R, S[0]"; call_tir(mm) (computation, a matmul + allreduce tensor program) consumes x and y and produces the result with placement "S[0], R".
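A minimal sketch (hypothetical helper, not a library API) of how a DTensor placement such as "S[0], S[1]" on a 2x2 device mesh maps a tensor element to a mesh coordinate; the placement has one entry per mesh axis, where S[d] shards data axis d along that mesh axis and R replicates over it, as described in the caption above:

def mesh_coord(index, shape, placement, mesh=(2, 2)):
    # placement[m] is either the sharded data axis d (for "S[d]") or "R"
    coord = []
    for m, p in enumerate(placement):
        if p == "R":
            coord.append(None)                 # replicated over this mesh axis
        else:
            shard_size = shape[p] // mesh[m]   # even sharding of data axis p
            coord.append(index[p] // shard_size)
    return tuple(coord)

# x with placement "S[0], S[1]": element (100, 30) of a (128, 128) tensor
# lives on mesh coordinate (1, 0).
print(mesh_coord((100, 30), (128, 128), [0, 1]))      # (1, 0)
# z with placement "S[0], R": the same element is replicated along mesh axis 1.
print(mesh_coord((100, 30), (128, 128), [0, "R"]))    # (1, None)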
Tensor Program Abstraction for Heterogeneous Hardware
Multi-level programming of GPU and accelerators
Tensor Block Abstraction
End-to-end system integration
Timeline
2025/1
2026/1
2027/1
2028/1
2029/1
2030/1
Education and Outreach Activities
TensorIR Improvements and Initial AutoScheduling Library
Final System Integration
AutoScheduling Library for
Composable Tensor Data Encoding
Composable Tensor Data Encoding
Attention Optimization using Tensor Programs
Fully Distributed and Heterogeneous Optimizations
End-to-End Integration as a Validation Point of the Proposed Research
Joint Optimization of Computational Graph
and Tensor Program