注意
跳转到底部以下载完整的示例代码。
低内存 Dropout¶
在本教程中,您将编写一个内存高效的 Dropout 实现,其状态将由一个 int32 种子构成。这与更传统的 Dropout 实现不同,后者的状态通常由与输入形状相同的位掩码张量构成。
通过本教程,您将了解
使用 PyTorch 实现朴素 Dropout 的局限性。
Triton 中的并行伪随机数生成。
基线¶
Dropout 算子首次在 [SRIVASTAVA2014] 中引入,作为一种在低数据量场景(即正则化)下提高深度神经网络性能的方法。
它将一个向量作为输入,并产生一个形状相同的向量作为输出。输出中的每个标量都有 \(p\) 的概率被改为零,否则则从输入复制。这迫使网络即使仅使用输入中 \(1 - p\) 的标量也能表现良好。
在评估时,我们希望使用网络的全部能力,因此设置 \(p=0\)。朴素地这样做会增加输出的范数(这可能是件坏事,例如可能导致输出 softmax 温度人为降低)。为了防止这种情况,我们将输出乘以 \(\frac{1}{1 - p}\),这样无论 Dropout 概率如何,范数都保持一致。
我们首先来看一下基线实现。
import tabulate
import torch
import triton
import triton.language as tl
DEVICE = triton.runtime.driver.active.get_active_torch_device()
@triton.jit
def _dropout(
x_ptr, # pointer to the input
x_keep_ptr, # pointer to a mask of 0s and 1s
output_ptr, # pointer to the output
n_elements, # number of elements in the `x` tensor
p, # probability that an element of `x` is changed to zero
BLOCK_SIZE: tl.constexpr,
):
pid = tl.program_id(axis=0)
block_start = pid * BLOCK_SIZE
offsets = block_start + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
# Load data
x = tl.load(x_ptr + offsets, mask=mask)
x_keep = tl.load(x_keep_ptr + offsets, mask=mask)
# The line below is the crucial part, described in the paragraph above!
output = tl.where(x_keep, x / (1 - p), 0.0)
# Write-back output
tl.store(output_ptr + offsets, output, mask=mask)
def dropout(x, x_keep, p):
output = torch.empty_like(x)
assert x.is_contiguous()
n_elements = x.numel()
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), )
_dropout[grid](x, x_keep, output, n_elements, p, BLOCK_SIZE=1024)
return output
# Input tensor
x = torch.randn(size=(10, ), device=DEVICE)
# Dropout mask
p = 0.5
x_keep = (torch.rand(size=(10, ), device=DEVICE) > p).to(torch.int32)
#
output = dropout(x, x_keep=x_keep, p=p)
print(tabulate.tabulate([
["input"] + x.tolist(),
["keep mask"] + x_keep.tolist(),
["output"] + output.tolist(),
]))
/home/runner/_work/triton/triton/python/triton/language/semantic.py:1696: UserWarning: tl.where with a non-boolean condition is deprecated and will error out in a future triton release. Got int32
warnings.warn(
--------- --------- ------- -------- ------- -------- ------- --------- -------- -------- -------
input -0.940469 0.17792 0.529538 0.13197 0.135063 1.64092 -0.309264 0.618883 -1.53066 0.46037
keep mask 0 0 0 0 0 1 0 0 1 1
output 0 0 0 0 0 3.28183 0 0 -3.06132 0.92074
--------- --------- ------- -------- ------- -------- ------- --------- -------- -------- -------
带种子的 Dropout¶
上面的 Dropout 实现工作正常,但处理起来可能有点麻烦。首先,我们需要存储 Dropout 掩码用于反向传播。其次,在使用重计算/检查点时,Dropout 状态管理会变得非常棘手(例如,请参阅 https://pytorch.ac.cn/docs/stable/checkpoint.html 中关于 preserve_rng_state 的所有注意事项)。在本教程中,我们将介绍一种替代实现,它 (1) 内存占用更小;(2) 需要更少的数据移动;并且 (3) 简化了在多次调用内核时保持随机性不变的管理。
Triton 中的伪随机数生成很简单!在本教程中,我们将使用 triton.language.rand
函数,该函数给定一个种子和一个 int32
偏移块,可以生成一个在 [0, 1) 范围内均匀分布的 float32
值块。如果您需要,Triton 还提供了其他随机数生成策略。
注意
Triton 的伪随机数生成 (PRNG) 实现基于 Philox 算法(详见 [SALMON2011])。
让我们把它们组合起来。
@triton.jit
def _seeded_dropout(
x_ptr,
output_ptr,
n_elements,
p,
seed,
BLOCK_SIZE: tl.constexpr,
):
# compute memory offsets of elements handled by this instance
pid = tl.program_id(axis=0)
block_start = pid * BLOCK_SIZE
offsets = block_start + tl.arange(0, BLOCK_SIZE)
# load data from x
mask = offsets < n_elements
x = tl.load(x_ptr + offsets, mask=mask)
# randomly prune it
random = tl.rand(seed, offsets)
x_keep = random > p
# write-back
output = tl.where(x_keep, x / (1 - p), 0.0)
tl.store(output_ptr + offsets, output, mask=mask)
def seeded_dropout(x, p, seed):
output = torch.empty_like(x)
assert x.is_contiguous()
n_elements = x.numel()
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), )
_seeded_dropout[grid](x, output, n_elements, p, seed, BLOCK_SIZE=1024)
return output
x = torch.randn(size=(10, ), device=DEVICE)
# Compare this to the baseline - dropout mask is never instantiated!
output = seeded_dropout(x, p=0.5, seed=123)
output2 = seeded_dropout(x, p=0.5, seed=123)
output3 = seeded_dropout(x, p=0.5, seed=512)
print(
tabulate.tabulate([
["input"] + x.tolist(),
["output (seed = 123)"] + output.tolist(),
["output (seed = 123)"] + output2.tolist(),
["output (seed = 512)"] + output3.tolist(),
]))
------------------- ------- --------- --------- ------- -------- -------- ------- -------- ------- ---------
input 1.48333 -0.239537 -0.640795 1.62631 0.263036 -0.71516 1.99474 -1.09546 1.81107 -0.170083
output (seed = 123) 0 -0.479074 0 0 0 -1.43032 0 0 3.62215 -0.340165
output (seed = 123) 0 -0.479074 0 0 0 -1.43032 0 0 3.62215 -0.340165
output (seed = 512) 0 0 -1.28159 3.25261 0 -1.43032 3.98947 0 0 0
------------------- ------- --------- --------- ------- -------- -------- ------- -------- ------- ---------
Et Voilà! 我们有了一个 Triton 内核,只要种子相同,它就会应用相同的 Dropout 掩码!如果您想进一步探索伪随机数在 GPU 编程中的应用,我们鼓励您探索 python/triton/language/random.py!
练习¶
扩展内核以处理矩阵,并使用一个种子向量——每行一个种子。
添加对步进 (striding) 的支持。
(挑战) 实现一个用于稀疏 Johnson-Lindenstrauss 变换的内核,该内核每次都使用一个种子即时生成投影矩阵。
参考文献¶
John K. Salmon, Mark A. Moraes, Ron O. Dror, 和 David E. Shaw, “并行随机数:易如反掌”, 2011
Nitish Srivastava 和 Geoffrey Hinton 和 Alex Krizhevsky 和 Ilya Sutskever 和 Ruslan Salakhutdinov, “Dropout: 一种防止神经网络过拟合的简单方法”, JMLR 2014
脚本总运行时间: (0 分钟 0.740 秒)