福生无量摸鱼天尊

Triton is all you need —— Matrix Multiplication & Matrix Transpose & Matrix Copy

2025/10/24
5
0

在矩阵的运算中,由于现在LLM Scaling Law,现在模型的矩阵相当的巨大。而计算单元的访存和算力有限,故此通常采用分治的思想进行并行计算,即采用分块的方式进行分块运算,这称之为tile,这里涉及到大量的程序编写的架构和编译器优化,后续我们会展开来这个说。

Matrix Transpose

矩阵转置,举个例子

原矩阵 A (3行4列):
[ 1  2  3  4 ]
[ 5  6  7  8 ]
[ 9 10 11 12 ]

转置为:

[ 1  5  9 ]
[ 2  6 10 ]
[ 3  7 11 ]
[ 4  8 12 ]

leetGPU中,已经给出了代码框架,加上一个BLOCK_SIZE

import torch
import triton
import triton.language as tl

@triton.jit
def matrix_transpose_kernel(
    input, output,
    rows, cols,
    stride_ir, stride_ic,  
    stride_or, stride_oc,
    BLOCK_SIZE: tl.constexpr,
):
    pass

# input, output are tensors on the GPU
def solve(input: torch.Tensor, output: torch.Tensor, rows: int, cols: int):
    stride_ir, stride_ic = cols, 1  
    stride_or, stride_oc = rows, 1
    
    BLOCK_SIZE = 32

    grid = (triton.cdiv(rows, BLOCK_SIZE), triton.cdiv(cols, BLOCK_SIZE))

    matrix_transpose_kernel[grid](
        input, output,
        rows, cols,
        stride_ir, stride_ic,
        stride_or, stride_oc,
        BLOCK_SIZE
    ) 

在这里不难发现已经是tile的样子了,根据tile的思想,我们是要把:

[1, 2]   ->  [ 1  5 ]
[5, 6]   ->  [ 2  6 ]

在triton中,我们第一考虑的是用pytorch的操作对矩阵进行变换,因为pytorch的numpy对矩阵的操作非常的完整。要想达到矩阵的访问,在numpy下常用的就是加法,过程如下:

[[0], [4]] + [[0, 1]]
=
[[0, 1], 
 [4, 5]]

根据这个原理,我们只需要知道行从哪到哪,列从哪到哪,就可以得到对应的矩阵了。故此我们对一个大小为BLOCK_SIZE的tile进行索引。

由于有行列,记得索引出两个方向的pid,这个kernel是运行在二维的grid。

由于是进行矩阵转置,故此前行后列进行索引:

pid1 = tl.program_id(axis = 0)
pid2 = tl.program_id(axis = 1)

rows_ = pid1 * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
cols_ = pid2 * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)


# offs_row[:, None] = [[0], [1]]  (2x1的形状)
# offs_col[None, :] = [[0, 1]]    (1x2的形状)
# 行偏移: [[0], [1]] * 4 = [[0], [4]]
# 列偏移: [[0, 1]] * 1 = [[0, 1]]
input_offsets = input + rows_[:, None] * stride_ir + cols_[None, :] * stride_ic
output_offsets = output + cols_[:, None] * stride_or + rows_[None, :] * stride_oc

打上mask,然后读取,写回就行:

input_mask = (rows_[:, None] < rows) & (cols_[None, :] < cols)
output_mask = (cols_[:, None] < cols) & (rows_[None, :] < rows)

input_offsets = input + rows_[:, None] * stride_ir + cols_[None, :] * stride_ic
output_offsets = output + cols_[:, None] * stride_or + rows_[None, :] * stride_oc

matrix = tl.load(input_offsets, mask = input_mask)
tl.store(output_offsets, tl.trans(matrix), mask=output_mask)

对于这个mask,我们可以分步来说明其中的矩阵操作:

首先是rows_[:, None] 进行的操作是:

原始: rows_ = [0, 1, 2, 3]          # 形状: (4,)
结果: rows_[:, None] = [[0],        # 形状: (4, 1)
                        [1],
                        [2],
                        [3]]

rows_[:, None] < 3 进行的操作是:

[[0],     [[0 < 3],     [[True],      [[T],
 [1],  <3  [1 < 3],  =   [True],   =   [T],
 [2],  →   [2 < 3],  →   [True],   →   [T],
 [3]]      [3 < 3]]      [False]]      [F]]

对应的,cols_[None, :] < 5 进行的操作是

[[4, 5, 6, 7]] < 5  →  [[4<5, 5<5, 6<5, 7<5]]  →  [[T, F, F, F]]

最后进行的是广播机制(Broadcasting),广播机制指的是当两个数组形状不同时,Python会自动复制数据使它们形状匹配:

[[T],      [[T, F, F, F]]
 [T],  &
 [T],
 [F]]

# 其中,第一项进行复制
[[T],       [[T, T, T, T],     # 第一行复制4次
 [T],   →    [T, T, T, T],     # 第二行复制4次
 [T],        [T, T, T, T],     # 第三行复制4次
 [F]]        [F, F, F, F]]     # 第四行复制4次

# 第二项进行复制
[[T, F, F, F]]  →  [[T, F, F, F],    # 整行复制4次
                    [T, F, F, F],
                    [T, F, F, F],
                    [T, F, F, F]]

最后再进行运算:

[[T, T, T, T],     [[T, F, F, F],     [[T&T, T&F, T&F, T&F],     [[T, F, F, F],
 [T, T, T, T],  &   [T, F, F, F],  =   [T&T, T&F, T&F, T&F],  =   [T, F, F, F],
 [T, T, T, T],      [T, F, F, F],      [T&T, T&F, T&F, T&F],      [T, F, F, F],
 [F, F, F, F]]      [T, F, F, F]]      [F&T, F&F, F&F, F&F]]      [F, F, F, F]]

Matrix Copy

矩阵复制比矩阵转置简单,不需要去行N×M转成列的M×N,用的是同一套offset:

import torch
import triton
import triton.language as tl


@triton.jit
def matcopy(A_ptr, B_ptr,
            N,
            stride_row, stride_col,
            BLOCK_SIZE: tl.constexpr):
    pid0 = tl.program_id(0)
    pid1 = tl.program_id(1)

    offset0 = pid0 * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    offset1 = pid1 * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)

    mask = (offset0[:, None] < N) & (offset1[None, :] < N)
    offset = offset0[:, None] * stride_row + offset1[None, :] * stride_col
    
    a_data = tl.load(A_ptr + offset, mask=mask)
    tl.store(B_ptr + offset, a_data, mask=mask)
 

# a, b are tensors on the GPU
def solve(a: torch.Tensor, b: torch.Tensor, N: int):
    BLOCK_SIZE = 64
    stride_row, stride_col = a.stride(0), a.stride(1)

    grid = (triton.cdiv(N, BLOCK_SIZE), triton.cdiv(N, BLOCK_SIZE))
    matcopy[grid](a, b,
                  N,
                  stride_row, stride_col,
                  BLOCK_SIZE)

Matrix Multiplication 

对于矩阵乘法,在并行计算中,往往是以结果为导向的,也就是A×B=C,将C tile化,然后每个block单独计算C tile的一块,从C tileXY反推需要的ABXYload就行。

聪明的你肯定发现了,由于有shared memory的存在,且矩阵的访存行列是有区别的,但是triton都优化好了,我们只需要用就行,这就很省心,但是对应的cuda代码就没那么友好了,具体可以看cuda的矩阵乘法实现,这里不赘述了。核心代码如下:

# 矩阵维度说明:
# A 是 M×N 矩阵
# B 是 N×K 矩阵
# C 是 M×K 矩阵

# 计算当前线程块负责的行和列索引范围
offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)  # 当前块处理的M维度索引 [BLOCK_M]
offs_k = pid_k * BLOCK_K + tl.arange(0, BLOCK_K)  # 当前块处理的K维度索引 [BLOCK_K]

# 预先计算输出矩阵C的指针位置(形状:BLOCK_M × BLOCK_K)
c_ptrs = c + (offs_m[:, None] * stride_cm + offs_k[None, :] * stride_ck)

# 沿着N维度(规约维度)循环,将矩阵乘法分解为多个小块
for n0 in range(0, N, BLOCK_N):
    # 计算当前迭代处理的N维度索引范围
    offs_n = n0 + tl.arange(0, BLOCK_N)  # [BLOCK_N]
    
    # 计算A矩阵块的指针(形状:BLOCK_M × BLOCK_N)
    # A[offs_m, offs_n] - 取A矩阵的部分行和部分列
    a_ptrs = a + (offs_m[:, None]*stride_am + offs_n[None, :]*stride_an)
    
    # 计算B矩阵块的指针(形状:BLOCK_N × BLOCK_K)
    # B[offs_n, offs_k] - 取B矩阵的部分行和部分列
    b_ptrs = b + (offs_n[:, None]*stride_bn + offs_k[None, :]*stride_bk)

    # 从全局内存加载数据块到寄存器/共享内存
    a_block = tl.load(a_ptrs, mask=a_mask, other=0.)  # 加载A块,越界位置填充0
    b_block = tl.load(b_ptrs, mask=b_mask, other=0.)  # 加载B块,越界位置填充0
    
    # 执行块矩阵乘法并累加结果
    # acc[BLOCK_M, BLOCK_K] += a_block[BLOCK_M, BLOCK_N] @ b_block[BLOCK_N, BLOCK_K]
    acc += tl.dot(a_block, b_block, input_precision="ieee")

# 将累加结果写入全局内存中的C矩阵
tl.store(c_ptrs, acc, mask=c_mask)

这里for循环就是循环NBLOCK_N,然后对于每个BLOCK_N,这里的矩阵乘法是A tileM×N B tileN×K,产出C tileM×K