在矩阵的运算中,由于现在LLM Scaling Law,现在模型的矩阵相当的巨大。而计算单元的访存和算力有限,故此通常采用分治的思想进行并行计算,即采用分块的方式进行分块运算,这称之为tile,这里涉及到大量的程序编写的架构和编译器优化,后续我们会展开来这个说。
矩阵转置,举个例子
原矩阵 A (3行4列):
[ 1 2 3 4 ]
[ 5 6 7 8 ]
[ 9 10 11 12 ]转置为:
[ 1 5 9 ]
[ 2 6 10 ]
[ 3 7 11 ]
[ 4 8 12 ]leetGPU中,已经给出了代码框架,加上一个BLOCK_SIZE:
import torch
import triton
import triton.language as tl
@triton.jit
def matrix_transpose_kernel(
input, output,
rows, cols,
stride_ir, stride_ic,
stride_or, stride_oc,
BLOCK_SIZE: tl.constexpr,
):
pass
# input, output are tensors on the GPU
def solve(input: torch.Tensor, output: torch.Tensor, rows: int, cols: int):
stride_ir, stride_ic = cols, 1
stride_or, stride_oc = rows, 1
BLOCK_SIZE = 32
grid = (triton.cdiv(rows, BLOCK_SIZE), triton.cdiv(cols, BLOCK_SIZE))
matrix_transpose_kernel[grid](
input, output,
rows, cols,
stride_ir, stride_ic,
stride_or, stride_oc,
BLOCK_SIZE
) 在这里不难发现已经是tile的样子了,根据tile的思想,我们是要把:
[1, 2] -> [ 1 5 ]
[5, 6] -> [ 2 6 ]在triton中,我们第一考虑的是用pytorch的操作对矩阵进行变换,因为pytorch的numpy对矩阵的操作非常的完整。要想达到矩阵的访问,在numpy下常用的就是加法,过程如下:
[[0], [4]] + [[0, 1]]
=
[[0, 1],
[4, 5]]根据这个原理,我们只需要知道行从哪到哪,列从哪到哪,就可以得到对应的矩阵了。故此我们对一个大小为BLOCK_SIZE的tile进行索引。
由于有行列,记得索引出两个方向的pid,这个kernel是运行在二维的grid。
由于是进行矩阵转置,故此前行后列进行索引:
pid1 = tl.program_id(axis = 0)
pid2 = tl.program_id(axis = 1)
rows_ = pid1 * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
cols_ = pid2 * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
# offs_row[:, None] = [[0], [1]] (2x1的形状)
# offs_col[None, :] = [[0, 1]] (1x2的形状)
# 行偏移: [[0], [1]] * 4 = [[0], [4]]
# 列偏移: [[0, 1]] * 1 = [[0, 1]]
input_offsets = input + rows_[:, None] * stride_ir + cols_[None, :] * stride_ic
output_offsets = output + cols_[:, None] * stride_or + rows_[None, :] * stride_oc打上mask,然后读取,写回就行:
input_mask = (rows_[:, None] < rows) & (cols_[None, :] < cols)
output_mask = (cols_[:, None] < cols) & (rows_[None, :] < rows)
input_offsets = input + rows_[:, None] * stride_ir + cols_[None, :] * stride_ic
output_offsets = output + cols_[:, None] * stride_or + rows_[None, :] * stride_oc
matrix = tl.load(input_offsets, mask = input_mask)
tl.store(output_offsets, tl.trans(matrix), mask=output_mask)对于这个mask,我们可以分步来说明其中的矩阵操作:
首先是rows_[:, None] 进行的操作是:
原始: rows_ = [0, 1, 2, 3] # 形状: (4,)
结果: rows_[:, None] = [[0], # 形状: (4, 1)
[1],
[2],
[3]]rows_[:, None] < 3 进行的操作是:
[[0], [[0 < 3], [[True], [[T],
[1], <3 [1 < 3], = [True], = [T],
[2], → [2 < 3], → [True], → [T],
[3]] [3 < 3]] [False]] [F]]对应的,cols_[None, :] < 5 进行的操作是
[[4, 5, 6, 7]] < 5 → [[4<5, 5<5, 6<5, 7<5]] → [[T, F, F, F]]最后进行的是广播机制(Broadcasting),广播机制指的是当两个数组形状不同时,Python会自动复制数据使它们形状匹配:
[[T], [[T, F, F, F]]
[T], &
[T],
[F]]
# 其中,第一项进行复制
[[T], [[T, T, T, T], # 第一行复制4次
[T], → [T, T, T, T], # 第二行复制4次
[T], [T, T, T, T], # 第三行复制4次
[F]] [F, F, F, F]] # 第四行复制4次
# 第二项进行复制
[[T, F, F, F]] → [[T, F, F, F], # 整行复制4次
[T, F, F, F],
[T, F, F, F],
[T, F, F, F]]最后再进行运算:
[[T, T, T, T], [[T, F, F, F], [[T&T, T&F, T&F, T&F], [[T, F, F, F],
[T, T, T, T], & [T, F, F, F], = [T&T, T&F, T&F, T&F], = [T, F, F, F],
[T, T, T, T], [T, F, F, F], [T&T, T&F, T&F, T&F], [T, F, F, F],
[F, F, F, F]] [T, F, F, F]] [F&T, F&F, F&F, F&F]] [F, F, F, F]]矩阵复制比矩阵转置简单,不需要去行N×M转成列的M×N,用的是同一套offset:
import torch
import triton
import triton.language as tl
@triton.jit
def matcopy(A_ptr, B_ptr,
N,
stride_row, stride_col,
BLOCK_SIZE: tl.constexpr):
pid0 = tl.program_id(0)
pid1 = tl.program_id(1)
offset0 = pid0 * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
offset1 = pid1 * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = (offset0[:, None] < N) & (offset1[None, :] < N)
offset = offset0[:, None] * stride_row + offset1[None, :] * stride_col
a_data = tl.load(A_ptr + offset, mask=mask)
tl.store(B_ptr + offset, a_data, mask=mask)
# a, b are tensors on the GPU
def solve(a: torch.Tensor, b: torch.Tensor, N: int):
BLOCK_SIZE = 64
stride_row, stride_col = a.stride(0), a.stride(1)
grid = (triton.cdiv(N, BLOCK_SIZE), triton.cdiv(N, BLOCK_SIZE))
matcopy[grid](a, b,
N,
stride_row, stride_col,
BLOCK_SIZE)对于矩阵乘法,在并行计算中,往往是以结果为导向的,也就是A×B=C,将C tile化,然后每个block单独计算C tile的一块,从C tile的XY反推需要的AB的XY,load就行。
聪明的你肯定发现了,由于有shared memory的存在,且矩阵的访存行列是有区别的,但是triton都优化好了,我们只需要用就行,这就很省心,但是对应的cuda代码就没那么友好了,具体可以看cuda的矩阵乘法实现,这里不赘述了。核心代码如下:
# 矩阵维度说明:
# A 是 M×N 矩阵
# B 是 N×K 矩阵
# C 是 M×K 矩阵
# 计算当前线程块负责的行和列索引范围
offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) # 当前块处理的M维度索引 [BLOCK_M]
offs_k = pid_k * BLOCK_K + tl.arange(0, BLOCK_K) # 当前块处理的K维度索引 [BLOCK_K]
# 预先计算输出矩阵C的指针位置(形状:BLOCK_M × BLOCK_K)
c_ptrs = c + (offs_m[:, None] * stride_cm + offs_k[None, :] * stride_ck)
# 沿着N维度(规约维度)循环,将矩阵乘法分解为多个小块
for n0 in range(0, N, BLOCK_N):
# 计算当前迭代处理的N维度索引范围
offs_n = n0 + tl.arange(0, BLOCK_N) # [BLOCK_N]
# 计算A矩阵块的指针(形状:BLOCK_M × BLOCK_N)
# A[offs_m, offs_n] - 取A矩阵的部分行和部分列
a_ptrs = a + (offs_m[:, None]*stride_am + offs_n[None, :]*stride_an)
# 计算B矩阵块的指针(形状:BLOCK_N × BLOCK_K)
# B[offs_n, offs_k] - 取B矩阵的部分行和部分列
b_ptrs = b + (offs_n[:, None]*stride_bn + offs_k[None, :]*stride_bk)
# 从全局内存加载数据块到寄存器/共享内存
a_block = tl.load(a_ptrs, mask=a_mask, other=0.) # 加载A块,越界位置填充0
b_block = tl.load(b_ptrs, mask=b_mask, other=0.) # 加载B块,越界位置填充0
# 执行块矩阵乘法并累加结果
# acc[BLOCK_M, BLOCK_K] += a_block[BLOCK_M, BLOCK_N] @ b_block[BLOCK_N, BLOCK_K]
acc += tl.dot(a_block, b_block, input_precision="ieee")
# 将累加结果写入全局内存中的C矩阵
tl.store(c_ptrs, acc, mask=c_mask)这里for循环就是循环N个BLOCK_N,然后对于每个BLOCK_N,这里的矩阵乘法是A tile的M×N 乘 B tile的N×K,产出C tile的M×K。