PyTorch Benchmark

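Benchmarking is an important step in writing code. It helps us validate that our code meets performance expectations, compare different approaches to solving the same problem, and prevent performance regressions.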

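There are many options for benchmarking PyTorch code, including the Python built-in timeit module. However, benchmarking PyTorch code has several caveats that are easy to overlook, such as managing the number of threads and synchronizing CUDA devices. Moreover, generating tensor inputs for benchmarking can be quite tedious.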

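This tutorial demonstrates how to use the PyTorch benchmark module to avoid common mistakes. It also shows how the module makes it easier to compare the performance of different code, generate inputs for benchmarking, and more.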

Defining functions to benchmark

We compare two ways of implementing a batched torch.dot with existing torch operators. One approach combines mul and sum, while the other reduces the problem to bmm.

import torch


def batched_dot_mul_sum(a, b):
    '''Computes batched dot by multiplying and summing'''
    return a.mul(b).sum(-1)


def batched_dot_bmm(a, b):
    '''Computes batched dot by reducing to bmm'''
    a = a.reshape(-1, 1, a.shape[-1])
    b = b.reshape(-1, b.shape[-1], 1)
    return torch.bmm(a, b).flatten(-3)


# Input for benchmarking
x = torch.randn(10000, 64)

# 确保两个函数计算相同的输出
assert batched_dot_mul_sum(x, x).allclose(batched_dot_bmm(x, x))
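The bmm reduction is easier to follow when you trace the tensor shapes. The sketch below does this on a tiny batch of 4 vectors of length 3 (sizes chosen arbitrarily). It then checks the result against an explicit per-row torch.dot. It only restates what batched_dot_bmm does and is not meant as a faster alternative.

```python
import torch

a = torch.randn(4, 3)  # 4 vectors of length 3
b = torch.randn(4, 3)

a3 = a.reshape(-1, 1, a.shape[-1])  # (4, 1, 3): each row becomes a 1x3 matrix
b3 = b.reshape(-1, b.shape[-1], 1)  # (4, 3, 1): each row becomes a 3x1 matrix
out = torch.bmm(a3, b3)             # (4, 1, 1): one 1x1 product per batch entry
dots = out.flatten(-3)              # (4,): dims -3..-1 flattened into one

print(a3.shape, b3.shape, out.shape, dots.shape)

# The same numbers, computed one row at a time with torch.dot
expected = torch.stack([torch.dot(a[i], b[i]) for i in range(a.shape[0])])
assert torch.allclose(dots, expected)
```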

Benchmarking with timeit.Timer

import timeit

t0 = timeit.Timer(
    stmt='batched_dot_mul_sum(x, x)',
    setup='from __main__ import batched_dot_mul_sum',
    globals={'x': x})

t1 = timeit.Timer(
    stmt='batched_dot_bmm(x, x)',
    setup='from __main__ import batched_dot_bmm',
    globals={'x': x})

print(f'mul_sum(x, x):  {t0.timeit(100) / 100 * 1e6:>5.1f} us')
print(f'bmm(x, x):      {t1.timeit(100) / 100 * 1e6:>5.1f} us')
mul_sum(x, x):  158.7 us
bmm(x, x):      113.3 us

Benchmarking with torch.utils.benchmark.Timer

import torch.utils.benchmark as benchmark

t0 = benchmark.Timer(
    stmt='batched_dot_mul_sum(x, x)',
    setup='from __main__ import batched_dot_mul_sum',
    globals={'x': x})

t1 = benchmark.Timer(
    stmt='batched_dot_bmm(x, x)',
    setup='from __main__ import batched_dot_bmm',
    globals={'x': x})

print(t0.timeit(100))
print(t1.timeit(100))
<torch.utils.benchmark.utils.common.Measurement object at 0x7fc88e3b7d60>
batched_dot_mul_sum(x, x)
setup: from __main__ import batched_dot_mul_sum
  342.40 us
  1 measurement, 100 runs , 1 thread
<torch.utils.benchmark.utils.common.Measurement object at 0x7fc92d5ba8c0>
batched_dot_bmm(x, x)
setup: from __main__ import batched_dot_bmm
  800.27 us
  1 measurement, 100 runs , 1 thread

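Although the basic functionality of the two APIs is the same, there are some important differences. benchmark.Timer.timeit() returns the time per run, whereas timeit.Timer.timeit() returns the total runtime. The PyTorch benchmark module also provides a formatted string representation for printing results.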
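Here is a small standard-library illustration; the statement and repeat count are arbitrary choices. Turning timeit.Timer.timeit()'s total into a per-run figure just means dividing by number, which is what the f-strings above do with `/ 100`:

```python
import timeit

number = 1000
data = list(range(1000, 0, -1))
t = timeit.Timer('sorted(data)', globals={'data': data})

total = t.timeit(number)  # seconds for all `number` runs combined
per_run = total / number  # per-run time, the kind of figure benchmark.Timer.timeit() reports
print(f'total: {total * 1e3:.2f} ms for {number} runs, per run: {per_run * 1e6:.2f} us')
```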

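Another important difference explains why the results diverge: the PyTorch benchmark module runs in a single thread by default. The number of threads can be changed with the num_threads argument.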

num_threads = torch.get_num_threads()
print(f'Benchmarking on {num_threads} threads')

t0 = benchmark.Timer(
    stmt='batched_dot_mul_sum(x, x)',
    setup='from __main__ import batched_dot_mul_sum',
    globals={'x': x},
    num_threads=num_threads,
    label='Multithreaded batch dot',
    sub_label='Implemented using mul and sum')

t1 = benchmark.Timer(
    stmt='batched_dot_bmm(x, x)',
    setup='from __main__ import batched_dot_bmm',
    globals={'x': x},
    num_threads=num_threads,
    label='Multithreaded batch dot',
    sub_label='Implemented using bmm')

print(t0.timeit(100))
print(t1.timeit(100))
Benchmarking on 24 threads
<torch.utils.benchmark.utils.common.Measurement object at 0x7fc93e8aa7d0>
Multithreaded batch dot: Implemented using mul and sum
setup: from __main__ import batched_dot_mul_sum
  154.94 us
  1 measurement, 100 runs , 24 threads
<torch.utils.benchmark.utils.common.Measurement object at 0x7fc92d7d9390>
Multithreaded batch dot: Implemented using bmm
setup: from __main__ import batched_dot_bmm
  137.89 us
  1 measurement, 100 runs , 24 threads
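Which version wins depends on the thread count, so it can be worth sweeping num_threads for a single statement. The sketch below assumes CPU tensors, and the thread counts {1, 2, all available} are illustrative picks. The function is passed through globals instead of `from __main__ import`, which benchmark.Timer also accepts.

```python
import torch
import torch.utils.benchmark as benchmark


def batched_dot_mul_sum(a, b):
    return a.mul(b).sum(-1)


x = torch.randn(10000, 64)
medians = {}
for n in sorted({1, 2, torch.get_num_threads()}):
    m = benchmark.Timer(
        stmt='batched_dot_mul_sum(x, x)',
        globals={'x': x, 'batched_dot_mul_sum': batched_dot_mul_sum},
        num_threads=n,
        label='Batched dot',
        sub_label='mul/sum',
    ).blocked_autorange(min_run_time=0.2)
    medians[n] = m.median  # seconds per run

for n, t in medians.items():
    print(f'{n:>3} threads: {t * 1e6:8.1f} us')
```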

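With all threads available, the benchmark gives results similar to those of the timeit module. More importantly, which version is faster depends on how many threads the code runs with. This is why you should benchmark code with thread settings that represent real use cases. Also remember to synchronize CPU and CUDA when benchmarking on the GPU. Let's run the above benchmarks again on CUDA tensors and see what happens.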

x = torch.randn(10000, 1024, device='cuda')

t0 = timeit.Timer(
    stmt='batched_dot_mul_sum(x, x)',
    setup='from __main__ import batched_dot_mul_sum',
    globals={'x': x})

t1 = timeit.Timer(
    stmt='batched_dot_bmm(x, x)',
    setup='from __main__ import batched_dot_bmm',
    globals={'x': x})

# Run each twice to show the difference before/after warmup
print(f'mul_sum(x, x):  {t0.timeit(100) / 100 * 1e6:>5.1f} us')
print(f'mul_sum(x, x):  {t0.timeit(100) / 100 * 1e6:>5.1f} us')
print(f'bmm(x, x):      {t1.timeit(100) / 100 * 1e6:>5.1f} us')
print(f'bmm(x, x):      {t1.timeit(100) / 100 * 1e6:>5.1f} us')
mul_sum(x, x):  514.0 us
mul_sum(x, x):   27.7 us
bmm(x, x):      7679.2 us
bmm(x, x):       35.8 us
t0 = benchmark.Timer(
    stmt='batched_dot_mul_sum(x, x)',
    setup='from __main__ import batched_dot_mul_sum',
    globals={'x': x})

t1 = benchmark.Timer(
    stmt='batched_dot_bmm(x, x)',
    setup='from __main__ import batched_dot_bmm',
    globals={'x': x})

# Run only once since benchmark module does warmup for us
print(t0.timeit(100))
print(t1.timeit(100))
<torch.utils.benchmark.utils.common.Measurement object at 0x7fc88e3b7e20>
batched_dot_mul_sum(x, x)
setup: from __main__ import batched_dot_mul_sum
  231.28 us
  1 measurement, 100 runs , 1 thread
<torch.utils.benchmark.utils.common.Measurement object at 0x7fc88e3b7460>
batched_dot_bmm(x, x)
setup: from __main__ import batched_dot_bmm
  249.74 us
  1 measurement, 100 runs , 1 thread

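The results reveal something interesting. With the timeit module, the first run of the bmm version takes much longer than the second. This is because bmm calls into cuBLAS, which must be loaded the first time it is called, and loading takes some time. This is why a warmup run before benchmarking matters; fortunately, PyTorch's benchmark module takes care of this for us.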
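You can reproduce the warmup effect without a GPU. In this plain-Python sketch, the hypothetical `lazy_op` pays a one-time initialization cost only on its first call. A 50 ms sleep stands in for loading a library such as cuBLAS. Timing that first call alone badly overstates the steady-state cost.

```python
import time

_state = {'loaded': False}


def lazy_op(n):
    """Hypothetical op whose first call pays a one-time setup cost."""
    if not _state['loaded']:
        time.sleep(0.05)  # stand-in for loading a library on first use
        _state['loaded'] = True
    return sum(i * i for i in range(n))


def time_once(fn, *args):
    start = time.perf_counter()
    fn(*args)
    return time.perf_counter() - start


first = time_once(lazy_op, 1000)   # includes the one-time setup
second = time_once(lazy_op, 1000)  # steady state, i.e. after warmup
print(f'first call: {first * 1e3:.2f} ms, second call: {second * 1e3:.2f} ms')
```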

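The timeit and benchmark modules disagree because the timeit module does not synchronize CUDA, so it only times the kernel launches. PyTorch's benchmark module does the synchronization for us.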
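When you time CUDA code by hand instead of with torch.utils.benchmark, you must wait for the GPU explicitly with torch.cuda.synchronize(). The sketch below is a minimal example; the helper name `time_per_call` and its warmup/iteration counts are hypothetical. On a machine without CUDA it runs on the CPU and skips the synchronization.

```python
import time

import torch


def time_per_call(fn, *args, n_iter=100, n_warmup=10):
    """Average wall time per call of fn(*args), in seconds (hypothetical helper)."""
    on_cuda = any(isinstance(a, torch.Tensor) and a.is_cuda for a in args)
    for _ in range(n_warmup):  # absorb one-time costs such as library loading
        fn(*args)
    if on_cuda:
        torch.cuda.synchronize()  # finish warmup kernels before starting the clock
    start = time.perf_counter()
    for _ in range(n_iter):
        fn(*args)  # on CUDA this only queues kernels and returns early
    if on_cuda:
        torch.cuda.synchronize()  # wait until the queued kernels have actually run
    return (time.perf_counter() - start) / n_iter


device = 'cuda' if torch.cuda.is_available() else 'cpu'
x = torch.randn(10000, 64, device=device)
per_call = time_per_call(lambda a: a.mul(a).sum(-1), x)
print(f'{device}: {per_call * 1e6:.1f} us per call')
```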

Benchmarking with Blocked Autorange

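timeit.Timer.autorange takes a single continuous measurement of at least 0.2 seconds. In contrast, torch.utils.benchmark.Timer.blocked_autorange takes many measurements whose times total at least 0.2 seconds; this minimum can be changed with the min_run_time parameter. It does so subject to the constraint that timing overhead stays a small fraction of the overall measurement. It first runs with an increasing number of runs per loop until the runtime is much larger than the measurement overhead, which also serves as a warmup. It then keeps taking measurements until the target time is reached. This approach wastes less data and lets us compute statistics to estimate the reliability of the measurements.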
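The two phases described above can be sketched in plain Python on top of the standard timeit module. This is a simplified illustration of the idea, not PyTorch's actual implementation. The function name, the 10x growth factor, and the 1e-4 overhead threshold are assumptions of this sketch.

```python
import statistics
import timeit


def blocked_autorange_sketch(stmt, globals=None, min_run_time=0.2):
    """Simplified blocked-autorange idea (not PyTorch's implementation).

    Returns (number, per_run_times): the block size chosen in phase 1 and one
    per-run time in seconds for every block measured in phase 2.
    """
    timer = timeit.Timer(stmt, globals=globals)

    # Overhead of the timing loop itself: timeit(0) executes the statement zero times.
    overhead = statistics.median(timer.timeit(0) for _ in range(5))

    # Phase 1: grow the runs per block 10x at a time until the timing overhead is a
    # negligible fraction of a block. These runs double as warmup and are discarded.
    number = 1
    while True:
        block_time = timer.timeit(number)
        if block_time > 0 and overhead / block_time <= 1e-4 and block_time >= min_run_time / 1000:
            break
        if block_time > min_run_time:
            break
        number *= 10

    # Phase 2: time many fixed-size blocks until their total reaches min_run_time.
    per_run_times, total = [], 0.0
    while total < min_run_time:
        block_time = timer.timeit(number)
        per_run_times.append(block_time / number)
        total += block_time
    return number, per_run_times


number, times = blocked_autorange_sketch('sum(range(1000))', min_run_time=0.1)
print(f'{number} runs per block, {len(times)} blocks, '
      f'median {statistics.median(times) * 1e6:.2f} us per run')
```

Keeping the per-block times, rather than a single total, is what makes statistics such as the median and spread available afterwards.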

m0 = t0.blocked_autorange()
m1 = t1.blocked_autorange()

print(m0)
print(m1)
<torch.utils.benchmark.utils.common.Measurement object at 0x7fc93e8a9690>
batched_dot_mul_sum(x, x)
setup: from __main__ import batched_dot_mul_sum
  229.31 us
  1 measurement, 1000 runs , 1 thread
<torch.utils.benchmark.utils.common.Measurement object at 0x7fc88e3b75e0>
batched_dot_bmm(x, x)
setup: from __main__ import batched_dot_bmm
  Median: 182.84 us
  2 measurements, 1000 runs per measurement, 1 thread

We can also inspect individual statistics on the returned Measurement object.
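Besides mean and median, a Measurement also exposes the per-block times and their interquartile range. The IQR is a quick way to judge how noisy a measurement is. The sketch below runs on the CPU, and the statement and min_run_time are arbitrary.

```python
import torch
import torch.utils.benchmark as benchmark

x = torch.randn(10000, 64)
m = benchmark.Timer(
    stmt='x.mul(x).sum(-1)',
    globals={'x': x},
).blocked_autorange(min_run_time=0.5)

# m.times holds one per-run time (in seconds) for each measured block.
print(f"blocks measured: {len(m.times)}")
print(f"mean:   {m.mean * 1e6:8.2f} us")
print(f"median: {m.median * 1e6:8.2f} us")
print(f"IQR:    {m.iqr * 1e6:8.2f} us")  # spread between the 25th and 75th percentiles
```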

print(f"Mean:   {m0.mean * 1e6:6.2f} us")
print(f"Median: {m0.median * 1e6:6.2f} us")
Mean:   229.31 us
Median: 229.31 us

Comparing benchmark results

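So far we have compared the two versions of batched dot on a single input. In practice, we want to try a combination of inputs as well as different numbers of threads. The Compare class displays the results of many measurements in a formatted table. It groups and organizes the table using the annotations described above (label, sub_label, num_threads, etc.) as well as description. Let's use Compare to see how our functions perform across input sizes and thread counts.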

from itertools import product

# Compare takes a list of measurements which we'll save in results.
results = []

sizes = [1, 64, 1024, 10000]
for b, n in product(sizes, sizes):
    # label and sub_label are the rows
    # description is the column
    label = 'Batched dot'
    sub_label = f'[{b}, {n}]'
    x = torch.ones((b, n))
    for num_threads in [1, 4, 16, 32]:
        results.append(benchmark.Timer(
            stmt='batched_dot_mul_sum(x, x)',
            setup='from __main__ import batched_dot_mul_sum',
            globals={'x': x},
            num_threads=num_threads,
            label=label,
            sub_label=sub_label,
            description='mul/sum',
        ).blocked_autorange(min_run_time=1))
        results.append(benchmark.Timer(
            stmt='batched_dot_bmm(x, x)',
            setup='from __main__ import batched_dot_bmm',
            globals={'x': x},
            num_threads=num_threads,
            label=label,
            sub_label=sub_label,
            description='bmm',
        ).blocked_autorange(min_run_time=1))

compare = benchmark.Compare(results)
compare.print()
[--------------- Batched dot ----------------]
                      |  mul/sum   |    bmm   
1 threads: -----------------------------------
      [1, 1]          |       5.2  |       9.1
      [1, 64]         |       5.7  |       9.1
      [1, 1024]       |       5.8  |      10.6
      [1, 10000]      |      10.2  |      11.8
      [64, 1]         |       9.7  |      15.7
      [64, 64]        |       7.5  |      14.4
      [64, 1024]      |      36.5  |     230.4
      [64, 10000]     |     294.3  |    2105.8
      [1024, 1]       |       7.0  |      16.7
      [1024, 64]      |      42.5  |      88.0
      [1024, 1024]    |     482.4  |    3505.7
      [1024, 10000]   |   27680.6  |   34140.8
      [10000, 1]      |      16.9  |      81.1
      [10000, 64]     |     338.7  |     746.7
      [10000, 1024]   |   27325.7  |   34208.9
      [10000, 10000]  |  333021.2  |  333304.7
4 threads: -----------------------------------
      [1, 1]          |       5.2  |       9.3
      [1, 64]         |       5.7  |       9.1
      [1, 1024]       |       5.9  |      12.0
      [1, 10000]      |      11.9  |      19.7
      [64, 1]         |       9.8  |      20.6
      [64, 64]        |       7.6  |      15.7
      [64, 1024]      |      37.9  |     319.9
      [64, 10000]     |      95.1  |    2994.2
      [1024, 1]       |       6.8  |      15.7
      [1024, 64]      |      46.2  |      36.7
      [1024, 1024]    |     139.4  |     991.1
      [1024, 10000]   |   10841.3  |    9254.5
      [10000, 1]      |      16.9  |      32.4
      [10000, 64]     |     105.1  |     217.0
      [10000, 1024]   |   11494.8  |    9471.9
      [10000, 10000]  |  108579.2  |   92878.2
16 threads: ----------------------------------
      [1, 1]          |       5.2  |       9.2
      [1, 64]         |       5.6  |      15.5
      [1, 1024]       |       6.3  |      14.3
      [1, 10000]      |      15.0  |      21.2
      [64, 1]         |       9.9  |      17.1
      [64, 64]        |      12.4  |      23.9
      [64, 1024]      |      49.7  |     717.2
      [64, 10000]     |      56.9  |    5740.0
      [1024, 1]       |       7.3  |      24.7
      [1024, 64]      |      43.0  |      24.7
      [1024, 1024]    |      75.3  |     299.0
      [1024, 10000]   |    8515.2  |    2623.2
      [10000, 1]      |      22.8  |      31.5
      [10000, 64]     |      62.6  |      83.5
      [10000, 1024]   |    8212.7  |    2638.4
      [10000, 10000]  |   85685.6  |   24352.7
32 threads: ----------------------------------
      [1, 1]          |       5.4  |      10.0
      [1, 64]         |       5.8  |       9.3
      [1, 1024]       |       6.3  |      11.3
      [1, 10000]      |      15.2  |      19.1
      [64, 1]         |       5.8  |      29.5
      [64, 64]        |       7.5  |      27.2
      [64, 1024]      |      65.6  |     643.7
      [64, 10000]     |     157.0  |    8721.4
      [1024, 1]       |       6.9  |      27.5
      [1024, 64]      |      65.8  |      33.1
      [1024, 1024]    |      95.3  |     213.9
      [1024, 10000]   |    7574.0  |    1503.8
      [10000, 1]      |      17.2  |      31.1
      [10000, 64]     |     140.7  |      93.3
      [10000, 1024]   |    7618.5  |    1396.9
      [10000, 10000]  |   77288.0  |   18593.7

Times are in microseconds (us).

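The results above indicate that the bmm version is better for larger tensors running on multiple threads. For smaller tensors and/or single-threaded code, the mul/sum version is better.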

Compare also provides functions for changing the table format:

compare.trim_significant_figures()
compare.colorize()
compare.print()
[-------------- Batched dot --------------]
                      |  mul/sum  |   bmm  
1 threads: --------------------------------
      [1, 1]          |        5  |       9
      [1, 64]         |        6  |       9
      [1, 1024]       |        6  |      11
      [1, 10000]      |       10  |      12
      [64, 1]         |       10  |      16
      [64, 64]        |        8  |      10
      [64, 1024]      |       36  |     230
      [64, 10000]     |      294  |    2100
      [1024, 1]       |        7  |      17
      [1024, 64]      |       42  |      88
      [1024, 1024]    |      480  |    3500
      [1024, 10000]   |    28000  |   30000
      [10000, 1]      |       17  |      81
      [10000, 64]     |      340  |     750
      [10000, 1024]   |    27000  |   30000
      [10000, 10000]  |   300000  |  300000
4 threads: --------------------------------
      [1, 1]          |        5  |       9
      [1, 64]         |        6  |       9
      [1, 1024]       |        6  |      12
      [1, 10000]      |       10  |      20
      [64, 1]         |       10  |      21
      [64, 64]        |        8  |      16
      [64, 1024]      |       38  |     320
      [64, 10000]     |       95  |    3000
      [1024, 1]       |        7  |      16
      [1024, 64]      |       46  |      40
      [1024, 1024]    |      139  |     990
      [1024, 10000]   |    11000  |    9300
      [10000, 1]      |       17  |      30
      [10000, 64]     |      105  |     220
      [10000, 1024]   |    11490  |    9470
      [10000, 10000]  |   100000  |   90000
16 threads: -------------------------------
      [1, 1]          |        5  |       9
      [1, 64]         |        6  |      16
      [1, 1024]       |        6  |      14
      [1, 10000]      |       10  |      21
      [64, 1]         |       10  |      17
      [64, 64]        |       12  |      20
      [64, 1024]      |       50  |     700
      [64, 10000]     |       60  |    6000
      [1024, 1]       |        7  |      25
      [1024, 64]      |       40  |      20
      [1024, 1024]    |       80  |     300
      [1024, 10000]   |     8520  |    2620
      [10000, 1]      |       23  |      30
      [10000, 64]     |       63  |      80
      [10000, 1024]   |     8210  |    2600
      [10000, 10000]  |    86000  |   24400
32 threads: -------------------------------
      [1, 1]          |        5  |      10
      [1, 64]         |        6  |       9
      [1, 1024]       |        6  |      11
      [1, 10000]      |       20  |      19
      [64, 1]         |        6  |      30
      [64, 64]        |        8  |      27
      [64, 1024]      |       66  |     600
      [64, 10000]     |      160  |    9000
      [1024, 1]       |        7  |      30
      [1024, 64]      |       70  |      33
      [1024, 1024]    |       95  |     210
      [1024, 10000]   |     7600  |    2000
      [10000, 1]      |       17  |      31
      [10000, 64]     |      100  |      90
      [10000, 1024]   |     7600  |    1000
      [10000, 10000]  |    80000  |   20000

Times are in microseconds (us).

Saving and loading benchmark results

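Measurement objects (as well as CallgrindStats) are picklable. This makes A/B testing easy: you can collect measurements in two separate environments, pickle them, and then load both into a single environment. Timer even accepts an env constructor argument so that such A/B testing works seamlessly.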
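The example in this section round-trips through pickle in memory and skips the disk step. The sketch below shows the full save/load workflow with a real file. The environment name and file name are hypothetical placeholders.

```python
import os
import pickle
import tempfile

import torch
import torch.utils.benchmark as benchmark

x = torch.ones(1024, 64)
m = benchmark.Timer(
    stmt='x.mul(x).sum(-1)',
    globals={'x': x},
    label='Batched dot',
    description='mul/sum',
    env='environment A',  # hypothetical environment name
).blocked_autorange(min_run_time=0.2)

with tempfile.TemporaryDirectory() as tmpdir:
    path = os.path.join(tmpdir, 'env_a_results.pkl')  # hypothetical file name
    # In environment A: write the measurements to disk.
    with open(path, 'wb') as f:
        pickle.dump([m], f)
    # Later, possibly in another process or environment: read them back.
    with open(path, 'rb') as f:
        loaded = pickle.load(f)

print(benchmark.Compare(loaded))
```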

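Imagine that the mul/sum and bmm approaches lived in two different builds of PyTorch rather than in two Python functions. The example below shows how one might A/B test them. For simplicity, it uses only a subset of shapes. It also round-trips the results through pickle in memory instead of using multiple environments and writing results to disk.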

import pickle

ab_test_results = []
for env in ('environment A: mul/sum', 'environment B: bmm'):
    for b, n in ((1, 1), (1024, 10000), (10000, 1)):
        x = torch.ones((b, n))
        dot_fn = (batched_dot_mul_sum if env == 'environment A: mul/sum' else batched_dot_bmm)
        m = benchmark.Timer(
            stmt='batched_dot(x, x)',
            globals={'x': x, 'batched_dot': dot_fn},
            num_threads=1,
            label='Batched dot',
            description=f'[{b}, {n}]',
            env=env,
        ).blocked_autorange(min_run_time=1)
        ab_test_results.append(pickle.dumps(m))

ab_results = [pickle.loads(i) for i in ab_test_results]
compare = benchmark.Compare(ab_results)
compare.trim_significant_figures()
compare.colorize()
compare.print()
[------------------------------------- Batched dot -------------------------------------]
                                               |  [1, 1]  |  [1024, 10000]  |  [10000, 1]
1 threads: ------------------------------------------------------------------------------
  (environment A: mul/sum)  batched_dot(x, x)  |   5.3    |      28000      |      17    
  (environment B: bmm)      batched_dot(x, x)  |   9.3    |      34000      |      89    

Times are in microseconds (us).
# And just to show that we can round trip all of the results from earlier:
round_tripped_results = pickle.loads(pickle.dumps(results))
assert str(benchmark.Compare(results)) == str(benchmark.Compare(round_tripped_results))

Generating inputs with fuzzed parameters

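As the previous section showed, performance can differ sharply depending on the input tensors. Hence, it is a good idea to run benchmarks on many different inputs. However, creating all these input tensors can be tedious, which is where torch.utils.benchmark.Fuzzer and related classes come in. Let's look at how to use a Fuzzer to create some test cases for the benchmark.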

from torch.utils.benchmark import Fuzzer, FuzzedParameter, FuzzedTensor, ParameterAlias

# Generates random tensors with 128 to 10000000 elements and sizes k0 and k1 chosen from a
# loguniform distribution in [1, 10000], 40% of which will be discontiguous on average.
example_fuzzer = Fuzzer(
    parameters = [
        FuzzedParameter('k0', minval=1, maxval=10000, distribution='loguniform'),
        FuzzedParameter('k1', minval=1, maxval=10000, distribution='loguniform'),
    ],
    tensors = [
        FuzzedTensor('x', size=('k0', 'k1'), min_elements=128, max_elements=10000000, probability_contiguous=0.6)
    ],
    seed=0,
)

results = []
for tensors, tensor_params, params in example_fuzzer.take(10):
    # description is the column label
    sub_label=f"{params['k0']:<6} x {params['k1']:<4} {'' if tensor_params['x']['is_contiguous'] else '(discontiguous)'}"
    results.append(benchmark.Timer(
        stmt='batched_dot_mul_sum(x, x)',
        setup='from __main__ import batched_dot_mul_sum',
        globals=tensors,
        label='Batched dot',
        sub_label=sub_label,
        description='mul/sum',
    ).blocked_autorange(min_run_time=1))
    results.append(benchmark.Timer(
        stmt='batched_dot_bmm(x, x)',
        setup='from __main__ import batched_dot_bmm',
        globals=tensors,
        label='Batched dot',
        sub_label=sub_label,
        description='bmm',
    ).blocked_autorange(min_run_time=1))

compare = benchmark.Compare(results)
compare.trim_significant_figures()
compare.print()
[--------------------- Batched dot ---------------------]
                                     |  mul/sum  |   bmm 
1 threads: ----------------------------------------------
      725    x 257                   |      94   |    200
      49     x 383                   |      14   |     31
      34     x 1468                  |      32   |    180
      187    x 5039                  |     430   |   3100
      2140   x 1296 (discontiguous)  |    1900   |  73000
      78     x 1598                  |      62   |    430
      519    x 763                   |     180   |   1300
      141    x 1082                  |      74   |    540
      78     x 5    (discontiguous)  |       7   |     11
      187    x 1                     |       7   |     12

Times are in microseconds (us).
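The fuzzer above asks for sizes drawn from a loguniform distribution, which means each order of magnitude in [1, 10000] is about equally likely. The plain-Python sketch below illustrates this; it is not the Fuzzer's own sampling code, and the helper name is hypothetical.

```python
import math
import random
from collections import Counter


def sample_loguniform(minval, maxval, rng):
    """Draw a value whose logarithm is uniform in [log(minval), log(maxval)]."""
    return math.exp(rng.uniform(math.log(minval), math.log(maxval)))


rng = random.Random(0)
samples = [sample_loguniform(1, 10000, rng) for _ in range(100_000)]

# Bucket by order of magnitude: [1, 10), [10, 100), [100, 1000), [1000, 10000]
decades = Counter(min(int(math.log10(s)), 3) for s in samples)
for d in sorted(decades):
    print(f"[1e{d}, 1e{d + 1}): {decades[d] / len(samples):.1%}")
```

A plain uniform draw over the same range would instead put about 90% of the sizes in the top decade, so small shapes would rarely be benchmarked.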

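You have a lot of flexibility when defining your own Fuzzers, which is useful for building a strong set of benchmark inputs. To make things even simpler, the PyTorch benchmark module also provides some built-in Fuzzers for common benchmarking needs. Let's look at how to use one of these built-in fuzzers.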

from torch.utils.benchmark.op_fuzzers import binary

results = []
for tensors, tensor_params, params in binary.BinaryOpFuzzer(seed=0).take(10):
    sub_label=f"{params['k0']:<6} x {params['k1']:<4} {'' if tensor_params['x']['is_contiguous'] else '(discontiguous)'}"
    results.append(benchmark.Timer(
        stmt='batched_dot_mul_sum(x, x)',
        setup='from __main__ import batched_dot_mul_sum',
        globals=tensors,
        label='Batched dot',
        sub_label=sub_label,
        description='mul/sum',
    ).blocked_autorange(min_run_time=1))
    results.append(benchmark.Timer(
        stmt='batched_dot_bmm(x, x)',
        setup='from __main__ import batched_dot_bmm',
        globals=tensors,
        label='Batched dot',
        sub_label=sub_label,
        description='bmm',
    ).blocked_autorange(min_run_time=1))

compare = benchmark.Compare(results)
compare.trim_significant_figures()
compare.colorize(rowwise=True)
compare.print()
[----------------------- Batched dot ------------------------]
                                         |  mul/sum  |   bmm  
1 threads: ---------------------------------------------------
      64     x 473  (discontiguous)      |   14000   |   80000
      16384  x 12642115 (discontiguous)  |      32   |     100
      8192   x 892                       |    7100   |   24300
      512    x 64   (discontiguous)      |   98000   |  340000
      493    x 27   (discontiguous)      |    2100   |    4950
      118    x 32   (discontiguous)      |     890   |    2790
      16     x 495  (discontiguous)      |   28000   |   37000
      488    x 62374                     |   93000   |  100000
      240372 x 69                        |   47000   |   20000
      40156  x 32   (discontiguous)      |    1800   |    5300

Times are in microseconds (us).

Collecting instruction counts with Callgrind

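One challenge of optimizing code is that wall time is both variable and opaque. Non-determinism has many sources, from adaptive clock speeds to resource contention with other processes. Furthermore, end-to-end time gives no insight into where time is spent, which is what we really care about when optimizing code.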
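Wall-time variability is easy to observe even for pure Python code. The sketch below repeats an identical measurement with timeit.repeat and reports the spread between the fastest and slowest repetition. The statement and repeat counts are arbitrary, and the exact numbers will differ between machines and runs.

```python
import statistics
import timeit

number = 200
samples = timeit.repeat('sum(range(10_000))', repeat=10, number=number)
per_run_us = sorted(s / number * 1e6 for s in samples)

median = statistics.median(per_run_us)
spread = (per_run_us[-1] - per_run_us[0]) / median
print(f'fastest {per_run_us[0]:.1f} us, median {median:.1f} us, slowest {per_run_us[-1]:.1f} us')
print(f'relative spread between repetitions: {spread:.1%}')
```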

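A complementary approach is to collect instruction counts. These counts are a proxy metric and do not capture every aspect of performance (for example, memory- or I/O-bound tasks), but they have several useful properties. Instruction counts are reproducible, largely insensitive to environmental variation, and offer fine-grained insight into where a program spends its cycles.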

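To see how useful instruction counts are, let's look at how we might reduce the overhead of batched_dot_mul_sum. The obvious solution is to move it to C++, so that we avoid switching between Python and C++ multiple times.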

Fortunately, the source is nearly identical. One question we must answer in C++ is whether to take arguments by value or by reference.

batched_dot_src = """\
/* ---- Python ---- */
// def batched_dot_mul_sum(a, b):
//     return a.mul(b).sum(-1)

torch::Tensor batched_dot_mul_sum_v0(
    const torch::Tensor a,
    const torch::Tensor b) {
  return a.mul(b).sum(-1);
}

torch::Tensor batched_dot_mul_sum_v1(
    const torch::Tensor& a,
    const torch::Tensor& b) {
  return a.mul(b).sum(-1);
}
"""


# PyTorch makes it easy to test our C++ implementations by providing a utility
# to JIT compile C++ source into Python extensions:
import os
import sys
from torch.utils import cpp_extension
cpp_lib = cpp_extension.load_inline(
    name='cpp_lib',
    cpp_sources=batched_dot_src,
    extra_cflags=['-O3'],
    extra_include_paths=[
        # `load_inline` needs to know where to find Pybind11 headers; fall back to
        # the current interpreter prefix when CONDA_PREFIX is not set.
        os.path.join(os.getenv('CONDA_PREFIX', sys.prefix), 'include')
    ],
    functions=['batched_dot_mul_sum_v0', 'batched_dot_mul_sum_v1']
)

# `load_inline` will create a shared object that is loaded into Python. When we collect
# instruction counts Timer will create a subprocess, so we need to re-import it. The
# import process is slightly more complicated for C extensions, but that's all we're
# doing here.
module_import_str = f"""\
# https://stackoverflow.com/questions/67631/how-to-import-a-module-given-the-full-path
import importlib.util
spec = importlib.util.spec_from_file_location("cpp_lib", {repr(cpp_lib.__file__)})
cpp_lib = importlib.util.module_from_spec(spec)
spec.loader.exec_module(cpp_lib)"""

import textwrap
def pretty_print(result):
    """Import machinery for cpp_lib.so can get repetitive to look at."""
    print(repr(result).replace(textwrap.indent(module_import_str, "  "), "  import cpp_lib"))


t_baseline = benchmark.Timer(
    stmt='batched_dot_mul_sum(x, x)',
    setup='''\
from __main__ import batched_dot_mul_sum
x = torch.randn(2, 2)''')

t0 = benchmark.Timer(
    stmt='cpp_lib.batched_dot_mul_sum_v0(x, x)',
    setup=f'''\
{module_import_str}
x = torch.randn(2, 2)''')

t1 = benchmark.Timer(
    stmt='cpp_lib.batched_dot_mul_sum_v1(x, x)',
    setup=f'''\
{module_import_str}
x = torch.randn(2, 2)''')

# Moving to C++ did indeed reduce overhead, but it's hard to tell which
# calling convention is more efficient. v1 (call with references) seems to
# be a bit faster, but it's within measurement error.
pretty_print(t_baseline.blocked_autorange())
pretty_print(t0.blocked_autorange())
pretty_print(t1.blocked_autorange())
<torch.utils.benchmark.utils.common.Measurement object at 0x7fc8811b4dc0>
batched_dot_mul_sum(x, x)
setup:
  from __main__ import batched_dot_mul_sum
  x = torch.randn(2, 2)

  5.22 us
  1 measurement, 100000 runs , 1 thread
<torch.utils.benchmark.utils.common.Measurement object at 0x7fc92d5ba830>
cpp_lib.batched_dot_mul_sum_v0(x, x)
setup:
  import cpp_lib
  x = torch.randn(2, 2)

  4.35 us
  1 measurement, 100000 runs , 1 thread
<torch.utils.benchmark.utils.common.Measurement object at 0x7fc88118aef0>
cpp_lib.batched_dot_mul_sum_v1(x, x)
setup:
  import cpp_lib
  x = torch.randn(2, 2)

  4.11 us
  1 measurement, 100000 runs , 1 thread
# Let's use Callgrind to determine which is better.
stats_v0 = t0.collect_callgrind()
stats_v1 = t1.collect_callgrind()

pretty_print(stats_v0)
pretty_print(stats_v1)

# `.as_standardized` removes file names and some path prefixes, and makes
# it easier to read the function symbols.
stats_v0 = stats_v0.as_standardized()
stats_v1 = stats_v1.as_standardized()

# `.delta` diffs the instruction counts, and `.denoise` removes several
# functions in the Python interpreter that are known to have significant
# jitter.
delta = stats_v1.delta(stats_v0).denoise()

# `.transform` is a convenience API for transforming function names. It is
# useful for increasing cancelation when diff-ing instructions, as well as
# just generally improving readability.
replacements = (
    ("???:void pybind11", "pybind11"),
    ("batched_dot_mul_sum_v0", "batched_dot_mul_sum_v1"),
    ("at::Tensor, at::Tensor", "..."),
    ("at::Tensor const&, at::Tensor const&", "..."),
    ("auto torch::detail::wrap_pybind_function_impl_", "wrap_pybind_function_impl_"),
)
for before, after in replacements:
    delta = delta.transform(lambda l: l.replace(before, after))

# We can use print options to control how much of the function to display.
torch.set_printoptions(linewidth=160)

# Once parsed, the instruction counts make clear that passing `a` and `b`
# by reference is more efficient as it skips some c10::TensorImpl bookkeeping
# for the intermediate Tensors, and also works better with PyBind11. This
# is consistent with our noisy wall time observations.
print(delta)