Tutorial: Matrix chain product
We have the tensors \(Q \in \mathbb{R}^{56\times 9\times E}\), \(K \in \mathbb{R}^{56\times 56}\), \(P \in \mathbb{R}^{56\times 9\times E}\), and \(A_e \in \mathbb{R}^{9\times 9}\), \(e\in[0,E)\). For all \(e\in[0,E)\) we want to compute the matrix chain multiplication
\[Q(:,:,e) \leftarrow Q(:,:,e) + K P(:,:,e) A_e,\]
where \(Q(:,:,e)\) selects a \(56\times 9\) submatrix from the tensor \(Q\), and likewise for \(P(:,:,e)\).
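For reference, the computation can be written as a naive host-side loop nest. The following C++ sketch is an illustration only; the column-major storage and the function signature are assumptions made here and are not part of the tutorial:

#include <cstddef>

// Naive reference: Q(:,:,e) <- Q(:,:,e) + K P(:,:,e) A_e for all e in [0,E).
// Assumes column-major storage; K is 56x56, P and Q are 56x9xE, A[e] is 9x9.
void matrix_chain_reference(float const *K, float const *P, float const *const *A,
                            float *Q, std::size_t E) {
    constexpr std::size_t M = 56, N = 9;
    for (std::size_t e = 0; e < E; ++e) {
        float const *Pe = P + e * M * N; // P(:,:,e)
        float *Qe = Q + e * M * N;       // Q(:,:,e)
        float const *Ae = A[e];          // A_e
        float tmp[M * N];                // 56x9 intermediate tmp = K P(:,:,e)
        for (std::size_t j = 0; j < N; ++j) {
            for (std::size_t i = 0; i < M; ++i) {
                float acc = 0.0f;
                for (std::size_t k = 0; k < M; ++k) {
                    acc += K[i + k * M] * Pe[k + j * M];
                }
                tmp[i + j * M] = acc;
            }
        }
        // Q(:,:,e) += tmp A_e
        for (std::size_t j = 0; j < N; ++j) {
            for (std::size_t i = 0; i < M; ++i) {
                float acc = 0.0f;
                for (std::size_t k = 0; k < N; ++k) {
                    acc += tmp[i + k * M] * Ae[k + j * N];
                }
                Qe[i + j * M] += acc;
            }
        }
    }
}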
In the tensor language we can implement the kernel as follows:
func @fused_kernel(%K: memref<f32x56x56>,
                   %P: memref<f32x56x9x?>,
                   %A: group<memref<f32x9x9>x?>,
                   %Q: memref<f32x56x9x?>) {
    %gid = group_id.x : index                          ; Get our index e
    %p = subview %P[0:56,0:9,%gid] : memref<f32x56x9>  ; Get view on submatrix
    %a = load %A[%gid] : memref<f32x9x9>               ; Load matrix from group
    %q = subview %Q[0:56,0:9,%gid] : memref<f32x56x9>  ; Get view on submatrix
    %tmp = alloca : memref<f32x56x9,local>             ; Reserve temporary memory
                                                       ; in the Shared Local Memory
    %c0 = constant 0.0 : f32
    %c1 = constant 1.0 : f32
    gemm.n.n %c1, %K, %p, %c0, %tmp                    ; Compute tmp <- K P(:,:,e)
    gemm.n.n %c1, %tmp, %a, %c1, %q                    ; Update Q(:,:,e) <- Q(:,:,e) + tmp A_e
}
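The gemm instructions follow the usual BLAS convention \(C \leftarrow \alpha\,\mathrm{op}(A)\,\mathrm{op}(B) + \beta C\); the .n.n suffix indicates that neither operand is transposed. With \(\alpha = 1\) and \(\beta = 0\) the first call overwrites the temporary, \(\mathrm{tmp} \leftarrow K P(:,:,e)\), and with \(\alpha = 1\) and \(\beta = 1\) the second call accumulates into the output, \(Q(:,:,e) \leftarrow Q(:,:,e) + \mathrm{tmp}\,A_e\), which together realize the matrix chain product stated above.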
Using the tinytc-opt tool we can run compiler passes on the code to gain insight into what is happening under the hood. For example, running the insert-lifetime-stop, insert-barrier, and work-group-size passes,
tinytc-opt -pinsert-lifetime-stop -pinsert-barrier -pwork-group-size test.ir
we get
func @fused_kernel(%K: memref<f32x56x56>,
                   %P: memref<f32x56x9x?>,
                   %A: group<memref<f32x9x9>x?>,
                   %Q: memref<f32x56x9x?>) attributes{subgroup_size=32, work_group_size=[64,1]} {
    %gid = group_id.x : index
    %p = subview %P[0:56,0:9,%gid] : memref<f32x56x9>
    %a = load %A[%gid] : memref<f32x9x9>
    %q = subview %Q[0:56,0:9,%gid] : memref<f32x56x9>
    %tmp = alloca : memref<f32x56x9,local>
    %c0 = constant 0x0p+0 : f32
    %c1 = constant 0x1p+0 : f32
    gemm %c1, %K, %p, %c0, %tmp
    barrier.local
    gemm %c1, %tmp, %a, %c1, %q
    lifetime_stop %tmp
}
We observe that
the kernel is executed concurrently by 64 work-items,
the temporary memory is only needed until the lifetime_stop instruction that follows the second GEMM (if multiple allocas are present whose lifetimes do not overlap, that is, the lifetime_stop of alloca #1 appears before alloca #2, then Shared Local Memory is reused, reducing the total amount needed; see the sketch below),
and that a barrier has been introduced between the GEMM calls to avoid data races.
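To put the Shared Local Memory remark in concrete numbers, the following small C++ sketch (an illustration only, not part of the tutorial) computes the footprint of the 56x9 temporary and contrasts the budget with and without reuse of a hypothetical second, non-overlapping allocation:

#include <algorithm>
#include <cstddef>
#include <cstdio>

int main() {
    // The temporary "alloca : memref<f32x56x9,local>" occupies 56 * 9 floats
    // of Shared Local Memory.
    std::size_t const tmp_bytes = 56 * 9 * sizeof(float); // 2016 bytes

    // Hypothetical second allocation of the same size whose lifetime does not
    // overlap with the first (its alloca appears after the first lifetime_stop):
    // the memory can be reused, so the requirement is the maximum of the two
    // sizes rather than their sum.
    std::size_t const second_bytes = 56 * 9 * sizeof(float);
    std::printf("with reuse: %zu bytes, without reuse: %zu bytes\n",
                std::max(tmp_bytes, second_bytes), tmp_bytes + second_bytes);
}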
When using SYCL, we can run the kernel using the following pseudo-code:
#include <sycl/sycl.hpp>
#include <tinytc/tinytc.hpp>
#include <tinytc/tinytc_sycl.hpp>
#include <cstdint>
#include <iostream>
auto ctx = tinytc::make_compiler_context();
set_error_reporter(
    ctx.get(),
    [](char const *what, const tinytc_location_t *, void *) { std::cerr << what << std::endl; },
    nullptr);
try {
    // Parse tensor program
    auto prog = tinytc::parse_file("fused_kernel.ir", ctx.get());

    // Initialize tensors
    float *K = ...;
    float *P = ...;
    float **A = ...;
    float *Q = ...;
    std::int64_t howmany = ...; // Number of matrices E (group size)

    // JIT compile program
    auto q = sycl::queue{};
    auto bundle = tinytc::create_kernel_bundle(q.get_context(), q.get_device(), prog.get());
    auto kernel = tinytc::create_kernel(bundle, "fused_kernel");
    auto exe_range = tinytc::get_execution_range(kernel, sycl::range<3u>{1, 1, howmany});

    // Launch the kernel; the compiled kernel is reused in every iteration
    for (int timestep = 0; timestep < num_timesteps; ++timestep) {
        q.submit([&](sycl::handler &h) {
            h.set_args(K, P, howmany, A, howmany, Q, howmany);
            h.parallel_for(exe_range, kernel);
        }).wait();
    }
} catch (tinytc::status const &st) {
    std::cerr << "Error (" << static_cast<int>(st) << "): " << tinytc::to_string(st) << std::endl;
} catch (std::exception const &e) {
    std::cerr << e.what() << std::endl;
}
Note that a fictional time loop was introduced around q.submit. As a general rule, JIT compilation is expensive compared to kernel execution; hence, a compiled program should be reused many times.
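One way to see this cost relation is to time the JIT compilation against a single kernel launch. The snippet below is a rough sketch that could replace the compile-and-launch part inside the try block above; the timing helper is an assumption added for illustration, while the tinytc and SYCL calls are the same ones used in the pseudo-code:

#include <chrono>

// Assumes q, prog, the tensor pointers, and howmany are set up as above.
auto const to_ms = [](auto d) {
    return std::chrono::duration_cast<std::chrono::duration<double, std::milli>>(d).count();
};

auto const t0 = std::chrono::steady_clock::now();
auto bundle = tinytc::create_kernel_bundle(q.get_context(), q.get_device(), prog.get());
auto kernel = tinytc::create_kernel(bundle, "fused_kernel");
auto exe_range = tinytc::get_execution_range(kernel, sycl::range<3u>{1, 1, howmany});
auto const t1 = std::chrono::steady_clock::now();

// One kernel launch, synchronized so that the measurement includes execution
q.submit([&](sycl::handler &h) {
    h.set_args(K, P, howmany, A, howmany, Q, howmany);
    h.parallel_for(exe_range, kernel);
}).wait();
auto const t2 = std::chrono::steady_clock::now();

std::cout << "JIT compilation: " << to_ms(t1 - t0) << " ms, "
          << "one kernel launch: " << to_ms(t2 - t1) << " ms" << std::endl;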