本人求职中!大模型 Post-Training 方向 · 长三角、南方地区优先 · 如果您有兴趣,欢迎联系 zhimi64@foxmail.com 👀 了解更多

對 torch 中 einsum 函數的一些分析

背景

PyTorch 和 NumPy 等庫都提供了 einsum 函數。einsum,全稱 Einstein summation convention(愛因斯坦求和約定),是一個方便而強大的工具。

我們熟悉的矩陣乘法、轉置、求矩陣的跡等操作都可視為 einsum 的特例。除此之外,它還可以表示更複雜的運算,例如將矩陣乘法、轉置等操作複合在一起表示為一個操作。

先從矩陣乘法說起。我們可以這樣用 einsum 表示矩陣乘法:

import torch
a = torch.rand(2, 2)
b = torch.rand(2, 2)
c = torch.einsum('ij,jk->ik', a, b)

ij,jk->ik 這個式子中,j-> 左側重複使用,且沒有出現在 -> 右側,因此我們需要對 j 求和。由此可知 einsum('ij,jk->ik', a, b) 的實際含義為 \(c_{ik}=\sum_j a_{ij} b_{jk}\),等同於矩陣乘法。

下面我們簡單驗證一下計算的正確性。

eps = 1e-5
assert (c - a @ b).abs().mean() < eps
print('檢查通過')
檢查通過

上面的程式碼中,a @ b 是 PyTorch 中計算 ab 矩陣乘法時的一般寫法。可以看到我們用 einsum 得到的計算結果與標準寫法的計算結果是一致的。

程式碼 torch.einsum('ij,jk->ik', a, b) 使用了 -> 符號分隔輸入和輸出。這被稱為顯式的 einsum。你也可以不使用 ->,只給出輸入的索引,這被稱為隱式的 einsum。例如 torch.einsum('ij,jk', a, b)。注意因為隱式的 einsum 不指定輸出的索引,因此輸出的索引將會按字母表順序排列。

為了簡便起見,這篇文章裡我只討論顯式的 einsum。基於以上介紹的基礎規則,我想在這篇文章中分享一些有趣的發現和思考:

  1. 順序無關性:einsum 的輸入參數的順序是可交換的;
  2. 手寫 einsum 函數:如何基於 Python 實現一個基礎的 einsum 函數;
  3. einsum 的梯度:einsum 函數的梯度仍然可以用 einsum 表示;

順序無關性

einsum 函數的要點是:1. 根據索引從輸入參數中取數,計算累積的乘積;2. 對重名的,在輸出中沒有出現的索引求和。

可以看到 einsum 的規則和輸入參數的順序沒有關係。 因此,einsum 函數的輸入參數是可以交換順序的。

例如,雖然矩陣乘法沒有交換律,但是基於 einsum 計算矩陣乘法時,你可以交換參數的順序。

for _ in range(100):
    a = torch.rand(2, 2)
    b = torch.rand(2, 2)
    c = torch.einsum('ij,jk->ik', a, b)  # 正序
    d = torch.einsum('jk,ij->ik', b, a)  # 反序
    assert (c - d).abs().mean() < eps
print('檢查通過')
檢查通過

手寫 einsum

einsum 函數的計算是這樣一個過程:一個遍歷所有索引,計算輸入元素的乘積,然後求和的過程。如果我們有一個函數能夠遍歷所有的索引,那會很方便。

我們先來實現這樣的一個函數 iter_elements

import torch

def iter_elements(sizes):
    if len(sizes) == 0:
        yield []
        return
    for i in range(sizes[0]):
        for j in iter_elements(sizes[1:]):
            yield (i, *j)

for idx in iter_elements([2, 3]):
    print(idx)
(0, 0)
(0, 1)
(0, 2)
(1, 0)
(1, 1)
(1, 2)

上面的程式碼實現了一個遍歷所有可能索引的函數 iter_elements。對於尺寸為 \(2\times 3\) 的矩陣,這個函數輸出了所有可能的六個索引。這是一個非常方便的函數。接下來我們就基於它在 einsum 函數中遍歷每一個元素。

def einsum_forward(equation, *operands):
    input_dims, output_dim = equation.split('->')
    input_dims = input_dims.split(',')

    # 收集所有的索引名
    index_names = list(set(''.join(input_dims)))
    # 收集每個索引名對應的維度大小
    sizes = [-1] * len(index_names)
    for shape, tensor in zip(input_dims, operands):
        for index_name, size in zip(shape, tensor.shape):
            assert sizes[index_names.index(index_name)] == -1 or \
                sizes[index_names.index(index_name)] == size
            sizes[index_names.index(index_name)] = size

    # 計算輸出矩陣的尺寸
    output_size = [sizes[index_names.index(name)] for name in output_dim]
    # 將輸出矩陣用 0 初始化
    output = operands[0].new_zeros(output_size)
    # 遍歷所有的索引
    for idx in iter_elements(sizes):
        # 映射到輸出矩陣的索引
        idx_output = tuple(idx[index_names.index(name)] for name in output_dim)
        prod = 1
        for input_dim, tensor in zip(input_dims, operands):
            # 對於每一個輸入 tensor,取得對應的索引
            idx_input = tuple(idx[index_names.index(name)] for name in input_dim)
            # 計算累積的乘積
            prod = prod * tensor[idx_input]
        # 求和
        output[idx_output] += prod
    return output

接下來我們編寫一些測試用例,檢查 einsum_forwardtorch.einsum 的計算結果是否一致。

# 矩陣乘法
a = torch.rand((5, 4))
b = torch.rand((4, 5))
c = einsum_forward('ij,jk->ik', a, b)
c2 = torch.einsum('ij,jk->ik', a, b)
c3 = a @ b
assert (c - c2).abs().mean() < eps
assert (c - c3).abs().mean() < eps
# 矩陣轉置
a = torch.rand((5, 4))
b = einsum_forward('ij->ji', a)
b2 = torch.einsum('ij->ji', a)
b3 = a.T
assert (b - b2).abs().mean() < eps
assert (b - b3).abs().mean() < eps
# 矩陣的跡
a = torch.rand((5, 5))
b = einsum_forward('ii->', a)
b2 = torch.einsum('ii->', a)
b3 = torch.trace(a)
assert (b - b2).abs().mean() < eps
assert (b - b3).abs().mean() < eps
print('全部檢查通過。')
全部檢查通過。

einsum 函數的梯度

舉一個簡單的例子。假設 \(c_{ik} = \sum_{j} a_{ij} b_{jk}\)。對於任意的 \(i,j\),導數 \(\frac{\partial L}{\partial c_{ik}}\) 都已知,問如何計算 \(\frac{\partial L}{\partial a_{ij}}\)\(\frac{\partial L}{\partial b_{jk}}\)

顯然 \[ \frac{\partial L}{\partial a_{ij}} = \sum_k \frac{\partial L}{\partial c_{ik}} b_{jk}, \] \[ \frac{\partial L}{\partial b_{jk}} = \sum_i \frac{\partial L}{\partial c_{ik}} a_{ij}. \]

可以看到,在這個例子中,不論是前向的計算還是梯度的計算都可以表示為一系列乘積的和。

如果你仔細地觀察和理解了 einsum 函數的計算方式,不難得出更一般的結論,一個由 einsum 函數定義的表達式,其對任意一個輸入參數的梯度同樣可以表示為 einsum 函數的形式。

現在假設 a, b, c 是某個 einsum 操作的輸入,經過操作我們得到輸出 d,如下所示:

a = torch.rand((2, 2), requires_grad=True)
b = torch.rand((2, 2), requires_grad=True)
c = torch.rand((2, 2), requires_grad=True)
d = torch.einsum('ij,jk,kl->il', a, b, c)

假設我們已經知道了 d 的梯度 grad_d

grad_d = torch.rand(d.shape)  # 隨機初始化 grad_d,假設它就是梯度

PyTorch 能很方便地幫我們分別算出 a, b, c 的梯度。

d.backward(grad_d)
print(a.grad.shape, b.grad.shape, c.grad.shape)
torch.Size([2, 2]) torch.Size([2, 2]) torch.Size([2, 2])

但是如前所述,einsum 函數的梯度仍然可以組織成 einsum 的形式。讓我們來驗證這一點。

首先是 a 的梯度 \(\frac{\partial L}{\partial a}\)

grad_a = torch.einsum('il,jk,kl->ij', grad_d, b, c)
assert torch.mean(torch.abs(a.grad - grad_a)) < eps

接著是 b 的梯度:

grad_b = torch.einsum('il,ij,kl->jk', grad_d, a, c)
assert torch.mean(torch.abs(b.grad - grad_b)) < eps

c 的梯度:

grad_c = torch.einsum('il,ij,jk->kl', grad_d, a, b)
assert torch.mean(torch.abs(c.grad - grad_c)) < eps

至此,我們驗證了 a, b, c 三個參數的梯度計算方式,它們都可以用 einsum 函數表示,而且計算結果和 PyTorch 的計算結果一致!

總結

最近在實現自己的深度學習模型時,我設計了一個簡單的模塊,其中一部分用到了 einsum。為了加速這個模塊在梯度反傳階段的計算,我作了詳細的推導,發現 einsum 有一些有趣的性質,於是就記錄下來,形成了這篇文章。

仔細分析可以發現,其實 einsum 本質上是計算元素間乘積的和的過程。儘管 einsum 的使用方式多種多樣,但只要把握這個本質,其性質也就不難理解了。

By @執迷 in
Tags : #torch, #einsum,