IREE CodeGen
Mahesh Ravishankar, Lei Zhang, Hanhan Wang, Ahmed Taei, Thomas Raoux, Nicolas Vasilache
Cerebra Babelfish Device Inference (BDI) team
MLIR Open Design Meeting, 2020-08-20
pg. 1
Goals
Code-generation backends
pg. 2
Background
pg. 3
Documentation Resources
https://google.github.io/iree/design-docs/codegen-passes
https://google.github.io/iree/ir-examples
https://google.github.io/iree/xla-op-coverage
https://google.github.io/iree/tf-e2e-coverage
pg. 4
Compilation flow
MHLO ops → Linalg ops on tensors → Linalg ops on memrefs → Linalg tiling + promotion → distribute to workgroups + workitems → SPIR-V dialect / LLVM codegen
A detailed description of the IREE code-generation pipeline is available at https://google.github.io/iree/design-docs/codegen-passes.
pg. 5
Running examples
func @gemm() {
  %0 = ... : tensor<16x32xf32>
  %1 = ... : tensor<32x8xf32>
  %2 = "mhlo.dot"(%0, %1) : (tensor<16x32xf32>, tensor<32x8xf32>) -> tensor<16x8xf32>
  ...
}

func @elementwise() {
  %0 = ... : tensor<10x15xf32>
  %1 = ... : tensor<10x15xf32>
  %2 = ... : tensor<15xf32>
  %3 = "mhlo.add"(%0, %1) : (tensor<10x15xf32>, tensor<10x15xf32>) -> tensor<10x15xf32>
  %4 = "mhlo.broadcast"(%2) : (tensor<15xf32>) -> tensor<10x15xf32>
  %5 = "mhlo.mul"(%3, %4) : (tensor<10x15xf32>, tensor<10x15xf32>) -> tensor<10x15xf32>
  ...
}
pg. 6
MHLO To Linalg on Tensors
(Pipeline stage: MHLO ops → Linalg ops on tensors)
pg. 7
Elementwise operation to Linalg on tensors
#map0 = affine_map<(d0, d1) -> (d0, d1)>
#map1 = affine_map<(d0, d1) -> (d1)>
func @elementwise() {
  ...
  // mhlo.add becomes an elementwise linalg.generic.
  %3 = linalg.generic %0, %1 {.. indexing_maps = [#map0, #map0, #map0] ..} {
    // add operation
  } : (tensor<10x15xf32>, tensor<10x15xf32>) -> tensor<10x15xf32>
  // mhlo.broadcast becomes a linalg.generic whose input map drops d0.
  %4 = linalg.generic %2 {.. indexing_maps = [#map1, #map0] ..} {
  ^bb0(%arg0 : f32, %arg1 : f32):
    linalg.yield %arg0 : f32
  } : (tensor<15xf32>) -> tensor<10x15xf32>
  // mhlo.mul becomes an elementwise linalg.generic.
  %5 = linalg.generic %3, %4 {.. indexing_maps = [#map0, #map0, #map0] ..} {
    // mul operation
  } : (tensor<10x15xf32>, tensor<10x15xf32>) -> tensor<10x15xf32>
}
pg. 8
Linalg fusion on tensors
(Pipeline stage: fusion of Linalg ops on tensors)
pg. 10
Elementwise operation after fusion
#map0 = affine_map<(d0, d1) -> (d0, d1)>
#map1 = affine_map<(d0, d1) -> (d1)>
func @elementwise() {
  ...
  // add and mul fused into a single linalg.generic; the broadcast is folded
  // into the indexing map of the third operand.
  %3 = linalg.generic %0, %1, %2 {.. indexing_maps = [#map0, #map0, #map1, #map0] ..} {
  ^bb0(%arg0 : f32, %arg1 : f32, %arg2 : f32, %arg3 : f32):
    %sum = addf %arg0, %arg1 : f32
    %prod = mulf %sum, %arg2 : f32
    linalg.yield %prod : f32
  } : (tensor<10x15xf32>, tensor<10x15xf32>, tensor<15xf32>) -> tensor<10x15xf32>
  ...
}
pg. 11
Tensor to Buffer Conversion
(Pipeline stage: Linalg ops on tensors → Linalg ops on memrefs)
pg. 12
Running examples
func @gemm() {
  %0 = ... : memref<16x32xf32>
  %1 = ... : memref<32x8xf32>
  %2 = ... : memref<16x8xf32>
  linalg.matmul(%0, %1, %2) : (memref<16x32xf32>, memref<32x8xf32>, memref<16x8xf32>)
}

func @elementwise() {
  %0 = ... : memref<10x15xf32>
  %1 = ... : memref<10x15xf32>
  %2 = ... : memref<15xf32>
  %3 = ... : memref<10x15xf32>
  linalg.generic %0, %1, %2, %3 {..} {
    ...
  } : memref<10x15xf32>, memref<10x15xf32>, memref<15xf32>, memref<10x15xf32>
}
pg. 13
Linalg to SPIR-V in IREE
pg. 14
Linalg Tiling and Fusion
(Pipeline stage: Linalg tiling + promotion)
pg. 15
Example tiling of linalg.matmul operation
func @gemm() {
  %0 = ... : memref<16x32xf32>
  %1 = ... : memref<32x8xf32>
  %2 = ... : memref<16x8xf32>
  scf.parallel (%iv0, %iv1) = ... {
    scf.for %iv2 = ... {
      ...
      %sv1 = subview %0[%iv0, %iv2] ...
      %sv2 = subview %1[%iv2, %iv1] ...
      %sv3 = subview %2[%iv0, %iv1] ...
      linalg.matmul %sv1, %sv2, %sv3 : (memref<8x4xf32>, memref<4x8xf32>, memref<8x8xf32>)
    }
  }
}
Distributed to workgroups (2D)
Promotion to workgroup memory (see the sketch below)
Workgroup-level linalg.matmul
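To make the promotion step concrete, here is a minimal sketch (hand-written for illustration, not IREE's verbatim output; the buffer names and the memory-space number are assumptions) of the loop body after the A and B tiles are promoted: they are copied into workgroup-memory allocations, and the workgroup-level linalg.matmul then reads from those copies.

  %buf_a = alloc() : memref<8x4xf32, 3>   // memory space 3 = workgroup memory (assumed)
  %buf_b = alloc() : memref<4x8xf32, 3>
  linalg.copy(%sv1, %buf_a) : memref<8x4xf32, ...>, memref<8x4xf32, 3>
  linalg.copy(%sv2, %buf_b) : memref<4x8xf32, ...>, memref<4x8xf32, 3>
  linalg.matmul %buf_a, %buf_b, %sv3 : (memref<8x4xf32, 3>, memref<4x8xf32, 3>, memref<8x8xf32, ...>)
  dealloc %buf_a : memref<8x4xf32, 3>
  dealloc %buf_b : memref<4x8xf32, 3>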
pg. 16
Distributing to workgroup/workitems
(Pipeline stage: distribute to workgroups + workitems)
pg. 17
Matrix-Matrix multiply after distribution
func @gemm() {
  ...
  %3 = "gpu.block_id"()  {dimension = "x"} : () -> index
  %4 = "gpu.grid_dim"()  {dimension = "x"} : () -> index
  %5 = "gpu.block_id"()  {dimension = "y"} : () -> index
  %6 = "gpu.grid_dim"()  {dimension = "y"} : () -> index
  ...
  scf.for %iv0 = ... {             // inter-tile k-loop
    ...
    %9  = "gpu.thread_id"() {dimension = "x"} : () -> index
    %10 = "gpu.thread_id"() {dimension = "y"} : () -> index
    scf.for %iv1 = ... {           // intra-tile k-loop
      ...
    }
  }
}
pg. 18
Conversion to SPIR-V dialect
(Pipeline stage: lowering to the SPIR-V dialect)
pg. 19
Linalg to SPIR-V in MLIR
(Alex Zinenko, Stephan Herhut, Alexander Belyaev, Denis Khalikov)
pg. 20
Linalg on Buffers to SPIR-V
Linalg tiling → distribute to workgroups (gpu.module + gpu.launch_func) → SPIR-V dialect → mlir-vulkan-runner
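For reference, a minimal sketch (an assumed example, not taken from the talk) of the outlined form this path produces before SPIR-V conversion: the device code lives in a gpu.module as a gpu.func marked kernel, and the host side dispatches it with gpu.launch_func.

  gpu.module @kernels {
    // Device-side kernel produced by outlining the distributed loop nest.
    gpu.func @matmul_tile(%A : memref<16x32xf32>, %B : memref<32x8xf32>, %C : memref<16x8xf32>) kernel {
      %bid = "gpu.block_id"()  {dimension = "x"} : () -> index
      %tid = "gpu.thread_id"() {dimension = "x"} : () -> index
      ...
      gpu.return
    }
  }
  // The host function launches @kernels::@matmul_tile via gpu.launch_func with the
  // workgroup counts; mlir-vulkan-runner then executes the SPIR-V it lowers to.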
pg. 21
Linalg to LLVMIR in IREE
pg. 22
Progressive lowering of Linalg to LLVMIR
Linalg on buffers → vectorized SCF loops / scalar SCF loops → std ops → LLVMIR
pg. 23
IREE CPU Codegen Compilation / Runtime
Compiler translation: LLVMIR → LLVM bitcode → shared library
Runtime HAL drivers: dylib, llvm_jit
pg. 24
Example: Lowering linalg.matmul to LLVMIR
func @gemm() {
  %0 = ... : memref<512x512xf32>
  %1 = ... : memref<512x512xf32>
  %2 = ... : memref<512x512xf32>
  linalg.matmul(%0, %1, %2) : (memref<512x512xf32>, memref<512x512xf32>, memref<512x512xf32>)
}
pg. 25
Lowering linalg.matmul to SCF
scf.for %iv0 = %c0 to %c512 step %c1 {
  scf.for %iv1 = %c0 to %c512 step %c1 {
    scf.for %iv2 = %c0 to %c512 step %c1 {
      %3 = load %0[%iv0, %iv2] : memref<512x512xf32>
      %4 = load %1[%iv2, %iv1] : memref<512x512xf32>
      %5 = load %2[%iv0, %iv1] : memref<512x512xf32>
      %6 = mulf %3, %4 : f32
      %7 = addf %5, %6 : f32
      store %7, %2[%iv0, %iv1] : memref<512x512xf32>
    }
  }
}
pg. 26
MatMul Tiling and Vectorization Strategy
scf.for %arg0 = %c0 to %c512 step %c64 {            // first level of memory tiling
  scf.for %arg1 = %c0 to %c512 step %c64 {
    scf.for %arg2 = %c0 to %c512 step %c64 {
      %3 = subview %1 ... : memref<512x512xf32> to memref<64x64xf32 ...
      ...
      scf.for %arg3 = %c0 to %c64 step %c32 {       // second level of memory tiling
        scf.for %arg4 = %c0 to %c64 step %c32 {
          scf.for %arg5 = %c0 to %c64 step %c32 {
            %6 = subview %3 ... : memref<64x64xf32, ... to memref<32x32xf32 ...
            ...
            // register-level tiling and vector ops
            ... = vector.transfer_read ... : memref<32x32xf32, ...>, vector<4xf32>
            ... = vector.transpose ... : vector<4x4xf32> to vector<4x4xf32>
            ... = vector.outerproduct ... : vector<4xf32>, vector<4xf32>
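To make the register-level step concrete, here is a small hand-written sketch (illustrative only; the value names, loop indices, and padding value %f0 are assumptions, not IREE's verbatim output) of how one 4x4 register tile of C can be computed from 4x4 tiles of A and B with vector.transfer_read, vector.transpose, and vector.outerproduct:

  // Read 4x4 tiles of A, B, and the C accumulator out of the 32x32 subviews.
  %vA = vector.transfer_read %6[%i, %k], %f0 : memref<32x32xf32, ...>, vector<4x4xf32>
  %vB = vector.transfer_read %7[%k, %j], %f0 : memref<32x32xf32, ...>, vector<4x4xf32>
  %vC = vector.transfer_read %8[%i, %j], %f0 : memref<32x32xf32, ...>, vector<4x4xf32>
  // Transpose A so its columns become rows that can feed outer products.
  %vAt = vector.transpose %vA, [1, 0] : vector<4x4xf32> to vector<4x4xf32>
  // One outer product per k element, accumulating into the register tile.
  %a0 = vector.extract %vAt[0] : vector<4x4xf32>
  %b0 = vector.extract %vB[0] : vector<4x4xf32>
  %acc0 = vector.outerproduct %a0, %b0, %vC : vector<4xf32>, vector<4xf32>
  ...                                        // three more outer products along k give %acc3
  vector.transfer_write %acc3, %8[%i, %j] : vector<4x4xf32>, memref<32x32xf32, ...>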
pg. 27
Lowering to LLVMIR Dialect
llvm.func @dispatch_op_name(%arg0: !llvm.ptr<ptr<i8>>, %arg1: !llvm.ptr<i32>) {
  // %arg0: packed buffer arguments (fixed ABI).
  %0 = llvm.bitcast %arg0 : !llvm.ptr<ptr<i8>> to
       !llvm.ptr<struct<(ptr<float>, ptr<float>, ptr<float>)>>
  %1 = llvm.load %0 : !llvm.ptr<struct<(ptr<float>, ptr<float>...
  %2 = llvm.extractvalue %1[0] : !llvm.struct<(ptr<float>, ptr<float>...
  %3 = llvm.mlir.undef : !llvm.struct<(ptr<float>, ptr<float>...
  ...
  // Static shapes are materialized as constants.
  ... = llvm.mlir.constant(512 : index) : !llvm.i64
  ... = llvm.insertvalue ... [3, 0] : !llvm.struct<(ptr<float>, ptr<float>...
Fixed ABI: buffer arguments are packed into the first argument
Static shape information is recovered from the IR
Dynamic shape information is passed as arguments
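As an illustration of how the static shape information is materialized (a hand-written sketch using the standard MemRef descriptor layout, not IREE's verbatim output; the value names are assumptions), the extracted base pointer is packed into a descriptor struct whose offset, sizes, and strides for memref<512x512xf32> are LLVM constants:

  // Descriptor layout: (allocated ptr, aligned ptr, offset, sizes[2], strides[2]).
  %d0 = llvm.mlir.undef : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
  %d1 = llvm.insertvalue %2, %d0[0] : !llvm.struct<(ptr<float>, ptr<float>...    // allocated pointer
  %d2 = llvm.insertvalue %2, %d1[1] : !llvm.struct<(ptr<float>, ptr<float>...    // aligned pointer
  %off = llvm.mlir.constant(0 : index) : !llvm.i64
  %d3 = llvm.insertvalue %off, %d2[2] : !llvm.struct<(ptr<float>, ptr<float>...  // offset
  %c512 = llvm.mlir.constant(512 : index) : !llvm.i64
  %d4 = llvm.insertvalue %c512, %d3[3, 0] : !llvm.struct<(ptr<float>, ptr<float>... // size of dim 0
  ...                                                                               // remaining size and strides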
pg. 28
Wrapping up
pg. 29
Current status
pg. 30
Next steps
MLIR infrastructure makes this all possible!
pg. 31