本人求职中!大模型 Post-Training 方向 · 长三角、南方地区优先 · 如果您有兴趣,欢迎联系 zhimi64@foxmail.com 👀 了解更多

对 torch 中 einsum 函数的一些分析

本文瀏覽次數

背景

PyTorch 和 NumPy 等库都提供了 einsum 函数。einsum,全称 Einstein summation convention(爱因斯坦求和约定),是一个方便而强大的工具。

我们熟悉的矩阵乘法、转置、求矩阵的迹等操作都可视为 einsum 的特例。除此之外,它还可以表示更复杂的运算,例如将矩阵乘法、转置等操作复合在一起表示为一个操作。

先从矩阵乘法说起。我们可以这样用 einsum 表示矩阵乘法:

import torch
a = torch.rand(2, 2)
b = torch.rand(2, 2)
c = torch.einsum('ij,jk->ik', a, b)

ij,jk->ik 这个式子中,j-> 左侧重复使用,且没有出现在 -> 右侧,因此我们需要对 j 求和。由此可知 einsum('ij,jk->ik', a, b) 的实际含义为 \(c_{ik}=\sum_j a_{ij} b_{jk}\),等同于矩阵乘法。

下面我们简单验证一下计算的正确性。

eps = 1e-5
assert (c - a @ b).abs().mean() < eps
print('检查通过')
检查通过

上面的程式码中,a @ b 是 PyTorch 中计算 ab 矩阵乘法时的一般写法。可以看到我们用 einsum 得到的计算结果与标准写法的计算结果是一致的。

程式码 torch.einsum('ij,jk->ik', a, b) 使用了 -> 符号分隔输入和输出。这被称为显式的 einsum。你也可以不使用 ->,只给出输入的索引,这被称为隐式的 einsum。例如 torch.einsum('ij,jk', a, b)。注意因为隐式的 einsum 不指定输出的索引,因此输出的索引将会按字母表顺序排列。

为了简便起见,这篇文章里我只讨论显式的 einsum。基于以上介绍的基础规则,我想在这篇文章中分享一些有趣的发现和思考:

  1. 顺序无关性:einsum 的输入参数的顺序是可交换的;
  2. 手写 einsum 函数:如何基于 Python 实现一个基础的 einsum 函数;
  3. einsum 的梯度:einsum 函数的梯度仍然可以用 einsum 表示;

顺序无关性

einsum 函数的要点是:1. 根据索引从输入参数中取数,计算累积的乘积;2. 对重名的,在输出中没有出现的索引求和。

可以看到 einsum 的规则和输入参数的顺序没有关系。 因此,einsum 函数的输入参数是可以交换顺序的。

例如,虽然矩阵乘法没有交换律,但是基于 einsum 计算矩阵乘法时,你可以交换参数的顺序。

for _ in range(100):
    a = torch.rand(2, 2)
    b = torch.rand(2, 2)
    c = torch.einsum('ij,jk->ik', a, b)  # 正序
    d = torch.einsum('jk,ij->ik', b, a)  # 反序
    assert (c - d).abs().mean() < eps
print('检查通过')
检查通过

手写 einsum

einsum 函数的计算是这样一个过程:一个遍历所有索引,计算输入元素的乘积,然后求和的过程。如果我们有一个函数能够遍历所有的索引,那会很方便。

我们先来实现这样的一个函数 iter_elements

import torch

def iter_elements(sizes):
    if len(sizes) == 0:
        yield []
        return
    for i in range(sizes[0]):
        for j in iter_elements(sizes[1:]):
            yield (i, *j)

for idx in iter_elements([2, 3]):
    print(idx)
(0, 0)
(0, 1)
(0, 2)
(1, 0)
(1, 1)
(1, 2)

上面的程式码实现了一个遍历所有可能索引的函数 iter_elements。对于尺寸为 \(2\times 3\) 的矩阵,这个函数输出了所有可能的六个索引。这是一个非常方便的函数。接下来我们就基于它在 einsum 函数中遍历每一个元素。

def einsum_forward(equation, *operands):
    input_dims, output_dim = equation.split('->')
    input_dims = input_dims.split(',')

    # 收集所有的索引名
    index_names = list(set(''.join(input_dims)))
    # 收集每个索引名对应的维度大小
    sizes = [-1] * len(index_names)
    for shape, tensor in zip(input_dims, operands):
        for index_name, size in zip(shape, tensor.shape):
            assert sizes[index_names.index(index_name)] == -1 or \
                sizes[index_names.index(index_name)] == size
            sizes[index_names.index(index_name)] = size

    # 计算输出矩阵的尺寸
    output_size = [sizes[index_names.index(name)] for name in output_dim]
    # 将输出矩阵用 0 初始化
    output = operands[0].new_zeros(output_size)
    # 遍历所有的索引
    for idx in iter_elements(sizes):
        # 映射到输出矩阵的索引
        idx_output = tuple(idx[index_names.index(name)] for name in output_dim)
        prod = 1
        for input_dim, tensor in zip(input_dims, operands):
            # 对于每一个输入 tensor,取得对应的索引
            idx_input = tuple(idx[index_names.index(name)] for name in input_dim)
            # 计算累积的乘积
            prod = prod * tensor[idx_input]
        # 求和
        output[idx_output] += prod
    return output

接下来我们编写一些测试用例,检查 einsum_forwardtorch.einsum 的计算结果是否一致。

# 矩阵乘法
a = torch.rand((5, 4))
b = torch.rand((4, 5))
c = einsum_forward('ij,jk->ik', a, b)
c2 = torch.einsum('ij,jk->ik', a, b)
c3 = a @ b
assert (c - c2).abs().mean() < eps
assert (c - c3).abs().mean() < eps
# 矩阵转置
a = torch.rand((5, 4))
b = einsum_forward('ij->ji', a)
b2 = torch.einsum('ij->ji', a)
b3 = a.T
assert (b - b2).abs().mean() < eps
assert (b - b3).abs().mean() < eps
# 矩阵的迹
a = torch.rand((5, 5))
b = einsum_forward('ii->', a)
b2 = torch.einsum('ii->', a)
b3 = torch.trace(a)
assert (b - b2).abs().mean() < eps
assert (b - b3).abs().mean() < eps
print('全部检查通过。')
全部检查通过。

einsum 函数的梯度

举一个简单的例子。假设 \(c_{ik} = \sum_{j} a_{ij} b_{jk}\)。对于任意的 \(i,j\),导数 \(\frac{\partial L}{\partial c_{ik}}\) 都已知,问如何计算 \(\frac{\partial L}{\partial a_{ij}}\)\(\frac{\partial L}{\partial b_{jk}}\)

显然 \[ \frac{\partial L}{\partial a_{ij}} = \sum_k \frac{\partial L}{\partial c_{ik}} b_{jk}, \] \[ \frac{\partial L}{\partial b_{jk}} = \sum_i \frac{\partial L}{\partial c_{ik}} a_{ij}. \]

可以看到,在这个例子中,不论是前向的计算还是梯度的计算都可以表示为一系列乘积的和。

如果你仔细地观察和理解了 einsum 函数的计算方式,不难得出更一般的结论,一个由 einsum 函数定义的表达式,其对任意一个输入参数的梯度同样可以表示为 einsum 函数的形式。

现在假设 a, b, c 是某个 einsum 操作的输入,经过操作我们得到输出 d,如下所示:

a = torch.rand((2, 2), requires_grad=True)
b = torch.rand((2, 2), requires_grad=True)
c = torch.rand((2, 2), requires_grad=True)
d = torch.einsum('ij,jk,kl->il', a, b, c)

假设我们已经知道了 d 的梯度 grad_d

grad_d = torch.rand(d.shape)  # 随机初始化 grad_d,假设它就是梯度

PyTorch 能很方便地帮我们分别算出 a, b, c 的梯度。

d.backward(grad_d)
print(a.grad.shape, b.grad.shape, c.grad.shape)
torch.Size([2, 2]) torch.Size([2, 2]) torch.Size([2, 2])

但是如前所述,einsum 函数的梯度仍然可以组织成 einsum 的形式。让我们来验证这一点。

首先是 a 的梯度 \(\frac{\partial L}{\partial a}\)

grad_a = torch.einsum('il,jk,kl->ij', grad_d, b, c)
assert torch.mean(torch.abs(a.grad - grad_a)) < eps

接著是 b 的梯度:

grad_b = torch.einsum('il,ij,kl->jk', grad_d, a, c)
assert torch.mean(torch.abs(b.grad - grad_b)) < eps

c 的梯度:

grad_c = torch.einsum('il,ij,jk->kl', grad_d, a, b)
assert torch.mean(torch.abs(c.grad - grad_c)) < eps

至此,我们验证了 a, b, c 三个参数的梯度计算方式,它们都可以用 einsum 函数表示,而且计算结果和 PyTorch 的计算结果一致!

总结

最近在实现自己的深度学习模型时,我设计了一个简单的模块,其中一部分用到了 einsum。为了加速这个模块在梯度反传阶段的计算,我作了详细的推导,发现 einsum 有一些有趣的性质,于是就记录下来,形成了这篇文章。

仔细分析可以发现,其实 einsum 本质上是计算元素间乘积的和的过程。尽管 einsum 的使用方式多种多样,但只要把握这个本质,其性质也就不难理解了。

By @執迷 in
Tags : #torch, #einsum,