Memory Efficient Attention Implementation
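
Conceptually, memory-efficient attention avoids materialising the full L×L score matrix. It streams over blocks of keys and values and keeps a running softmax maximum and sum for each query row. The following is a minimal NumPy sketch of that idea for a single head. It assumes float64 inputs of shape (seq_len, head_dim), and the function names `naive_attention` and `chunked_attention` are hypothetical. It illustrates the technique only. It is not the fused CUTLASS kernel (`fmha_cutlassF_...`) that PyTorch dispatches to in the profiles below.

```python
import numpy as np


def naive_attention(q, k, v):
    """Reference (hypothetical helper): builds the full (L_q, L_k) score matrix."""
    scores = q @ k.T / np.sqrt(q.shape[-1])
    scores -= scores.max(axis=-1, keepdims=True)  # numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ v


def chunked_attention(q, k, v, chunk=64):
    """Sketch (hypothetical helper): online softmax over key/value chunks.

    Only an (L_q, chunk) block of scores exists at any time.
    """
    scale = 1.0 / np.sqrt(q.shape[-1])
    out = np.zeros((q.shape[0], v.shape[-1]))       # unnormalised output accumulator
    row_max = np.full((q.shape[0], 1), -np.inf)     # running max of scores per query
    row_sum = np.zeros((q.shape[0], 1))             # running softmax denominator
    for start in range(0, k.shape[0], chunk):
        s = q @ k[start:start + chunk].T * scale    # scores for this key block
        new_max = np.maximum(row_max, s.max(axis=-1, keepdims=True))
        correction = np.exp(row_max - new_max)      # rescales earlier partial sums
        p = np.exp(s - new_max)
        row_sum = row_sum * correction + p.sum(axis=-1, keepdims=True)
        out = out * correction + p @ v[start:start + chunk]
        row_max = new_max
    return out / row_sum


# Usage: the chunked result matches the reference, including a ragged last chunk.
rng = np.random.default_rng(0)
q = rng.standard_normal((200, 64))
k = rng.standard_normal((300, 64))
v = rng.standard_normal((300, 64))
assert np.allclose(chunked_attention(q, k, v, chunk=64), naive_attention(q, k, v))
```

The rescaling by `correction` makes the result exact rather than approximate. Whenever a later block raises a row's maximum, the previously accumulated numerator and denominator are scaled down to the new reference point.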

Memory Efficient SDPA Benchmark

# /// script
# requires-python = ">=3.10"
# dependencies = [
#     "numpy",
#     "torch==2.8.0",
#     "kernels-benchmark-tools",
# ]
#
# [tool.uv.sources]
# kernels-benchmark-tools = { path = "../../../../../tools", editable = true }
# ///
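import torch
from kernels_benchmark_tools import KernelTypeEnum, run_benchmark


def torch_mem_eff(q, k, v):
    # Swap dims 1 and 2: the benchmark appears to pass q, k, v as
    # (batch, seq_len, heads, head_dim), while SDPA expects
    # (batch, heads, seq_len, head_dim). .contiguous() copies each transposed
    # view into a new buffer (the aten::copy_ rows in the profiles below).
    qt, kt, vt = (x.transpose(1, 2).contiguous() for x in (q, k, v))
    # Restrict SDPA dispatch to the memory-efficient attention backend only.
    with torch.nn.attention.sdpa_kernel(
        torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION
    ):
        o = torch.nn.functional.scaled_dot_product_attention(qt, kt, vt)
    # Transpose back to the input layout; .contiguous() copies only if the
    # transposed view is not already contiguous.
    return o.transpose(1, 2).contiguous()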


run_benchmark(
    kernel_type=KernelTypeEnum.ATTENTION,
    impl_name="torch_mem_eff",
    impl_tags={"family": "torch-sdpa", "backend": "EFFICIENT", "compile": "none"},
    impl_func=torch_mem_eff,
)
Running attention benchmark on cuda with 6 workloads.

======================================================================
PROFILE TRACE: torch_mem_eff | cuda_attn_L128_bfloat16
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                          torch_mem_eff         5.20%     361.468us        33.36%       2.319ms       2.319ms       0.000us         0.00%       5.387ms       5.387ms             1  
                                          torch_mem_eff         0.00%       0.000us         0.00%       0.000us       0.000us       5.370ms       100.63%       5.370ms       5.370ms             1  
                     aten::scaled_dot_product_attention         0.48%      33.240us         2.68%     186.333us      62.111us       0.000us         0.00%       4.719ms       1.573ms             3  
          aten::_scaled_dot_product_efficient_attention         0.35%      24.389us         2.20%     153.093us      51.031us       0.000us         0.00%       4.719ms       1.573ms             3  
                     aten::_efficient_attention_forward         0.53%      37.120us         1.50%     104.111us      34.704us       4.719ms        88.44%       4.719ms       1.573ms             3  
fmha_cutlassF_bf16_aligned_64x128_rf_sm80(PyTorchMem...         0.00%       0.000us         0.00%       0.000us       0.000us       4.719ms        88.44%       4.719ms       1.573ms             3  
                                       aten::contiguous         0.18%      12.841us        24.53%       1.706ms     189.522us       0.000us         0.00%     667.809us      74.201us             9  
                                            aten::clone         0.46%      31.899us        24.35%       1.693ms     188.095us       0.000us         0.00%     667.809us      74.201us             9  
                                            aten::copy_         1.13%      78.352us        22.86%       1.589ms     176.604us     617.121us        11.56%     667.809us      74.201us             9  
void at::native::elementwise_kernel<128, 4, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     617.121us        11.56%     617.121us      68.569us             9  
                                Activity Buffer Request        20.52%       1.427ms        20.52%       1.427ms       1.427ms      50.688us         0.95%      50.688us      50.688us             1  
                                        aten::transpose         0.98%      68.237us         1.30%      90.074us       3.753us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::as_strided         0.31%      21.837us         0.31%      21.837us       0.910us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::empty_like         0.25%      17.541us         1.03%      71.521us       7.947us       0.000us         0.00%       0.000us       0.000us             9  
                                            aten::empty         1.19%      82.429us         1.19%      82.429us       3.925us       0.000us         0.00%       0.000us       0.000us            21  
                                       cudaLaunchKernel         1.61%     111.770us         1.61%     111.770us       9.314us       0.000us         0.00%       0.000us       0.000us            12  
                                  cudaStreamIsCapturing         0.05%       3.512us         0.05%       3.512us       1.171us       0.000us         0.00%       0.000us       0.000us             3  
                                   cudaFuncSetAttribute         0.11%       7.660us         0.11%       7.660us       2.553us       0.000us         0.00%       0.000us       0.000us             3  
                                  cudaDeviceSynchronize        66.64%       4.633ms        66.64%       4.633ms       4.633ms       0.000us         0.00%       0.000us       0.000us             1  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 6.952ms
Self CUDA time total: 5.336ms



======================================================================
PROFILE TRACE: torch_mem_eff | cuda_attn_L256_bfloat16
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                          torch_mem_eff         3.61%     259.378us        29.44%       2.116ms       2.116ms       0.000us         0.00%       5.734ms       5.734ms             1  
                                          torch_mem_eff         0.00%       0.000us         0.00%       0.000us       0.000us       5.688ms       100.15%       5.688ms       5.688ms             1  
                     aten::scaled_dot_product_attention         0.27%      19.560us         2.06%     147.832us      49.277us       0.000us         0.00%       5.042ms       1.681ms             3  
          aten::_scaled_dot_product_efficient_attention         0.27%      19.340us         1.78%     128.272us      42.757us       0.000us         0.00%       5.042ms       1.681ms             3  
                     aten::_efficient_attention_forward         0.39%      28.380us         1.18%      84.990us      28.330us       5.042ms        88.79%       5.042ms       1.681ms             3  
fmha_cutlassF_bf16_aligned_64x128_rf_sm80(PyTorchMem...         0.00%       0.000us         0.00%       0.000us       0.000us       5.042ms        88.79%       5.042ms       1.681ms             3  
                                       aten::contiguous         0.11%       8.118us        23.11%       1.661ms     184.525us       0.000us         0.00%     691.453us      76.828us             9  
                                            aten::clone         0.32%      22.761us        23.00%       1.653ms     183.623us       0.000us         0.00%     691.453us      76.828us             9  
                                            aten::copy_         0.95%      68.519us        21.65%       1.556ms     172.887us     636.925us        11.21%     691.453us      76.828us             9  
void at::native::elementwise_kernel<128, 4, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     636.925us        11.21%     636.925us      70.769us             9  
                                Activity Buffer Request        19.69%       1.415ms        19.69%       1.415ms       1.415ms      54.528us         0.96%      54.528us      54.528us             1  
                                        aten::transpose         0.75%      54.034us         1.00%      71.792us       2.991us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::as_strided         0.25%      17.758us         0.25%      17.758us       0.740us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::empty_like         0.18%      12.992us         1.03%      73.863us       8.207us       0.000us         0.00%       0.000us       0.000us             9  
                                            aten::empty         1.22%      87.512us         1.22%      87.512us       4.167us       0.000us         0.00%       0.000us       0.000us            21  
                                       cudaLaunchKernel         1.35%      96.951us         1.35%      96.951us       8.079us       0.000us         0.00%       0.000us       0.000us            12  
                                  cudaStreamIsCapturing         0.03%       2.489us         0.03%       2.489us       0.830us       0.000us         0.00%       0.000us       0.000us             3  
                                   cudaFuncSetAttribute         0.04%       3.130us         0.04%       3.130us       1.043us       0.000us         0.00%       0.000us       0.000us             3  
                                  cudaDeviceSynchronize        70.56%       5.071ms        70.56%       5.071ms       5.071ms       0.000us         0.00%       0.000us       0.000us             1  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 7.187ms
Self CUDA time total: 5.679ms



======================================================================
PROFILE TRACE: torch_mem_eff | cuda_attn_L320_bfloat16
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                          torch_mem_eff         3.31%     247.873us        28.16%       2.111ms       2.111ms       0.000us         0.00%       6.014ms       6.014ms             1  
                                          torch_mem_eff         0.00%       0.000us         0.00%       0.000us       0.000us       5.964ms       100.13%       5.964ms       5.964ms             1  
                     aten::scaled_dot_product_attention         0.26%      19.681us         1.94%     145.404us      48.468us       0.000us         0.00%       5.300ms       1.767ms             3  
          aten::_scaled_dot_product_efficient_attention         0.25%      18.780us         1.68%     125.723us      41.908us       0.000us         0.00%       5.300ms       1.767ms             3  
                     aten::_efficient_attention_forward         0.40%      29.910us         1.12%      83.752us      27.917us       5.300ms        89.00%       5.300ms       1.767ms             3  
fmha_cutlassF_bf16_aligned_64x128_rf_sm80(PyTorchMem...         0.00%       0.000us         0.00%       0.000us       0.000us       5.300ms        89.00%       5.300ms       1.767ms             3  
                                       aten::contiguous         0.10%       7.548us        22.32%       1.673ms     185.921us       0.000us         0.00%     713.444us      79.272us             9  
                                            aten::clone         0.29%      21.851us        22.22%       1.666ms     185.082us       0.000us         0.00%     713.444us      79.272us             9  
                                            aten::copy_         0.89%      66.441us        21.22%       1.591ms     176.813us     655.331us        11.00%     713.444us      79.272us             9  
void at::native::elementwise_kernel<128, 4, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     655.331us        11.00%     655.331us      72.815us             9  
                                Activity Buffer Request        19.37%       1.452ms        19.37%       1.452ms       1.452ms      58.113us         0.98%      58.113us      58.113us             1  
                                        aten::transpose         0.68%      50.773us         0.90%      67.843us       2.827us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::as_strided         0.23%      17.070us         0.23%      17.070us       0.711us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::empty_like         0.16%      12.290us         0.70%      52.570us       5.841us       0.000us         0.00%       0.000us       0.000us             9  
                                            aten::empty         0.87%      64.980us         0.87%      64.980us       3.094us       0.000us         0.00%       0.000us       0.000us            21  
                                       cudaLaunchKernel         1.28%      96.085us         1.28%      96.085us       8.007us       0.000us         0.00%       0.000us       0.000us            12  
                                  cudaStreamIsCapturing         0.03%       2.520us         0.03%       2.520us       0.840us       0.000us         0.00%       0.000us       0.000us             3  
                                   cudaFuncSetAttribute         0.04%       3.050us         0.04%       3.050us       1.017us       0.000us         0.00%       0.000us       0.000us             3  
                                  cudaDeviceSynchronize        71.84%       5.386ms        71.84%       5.386ms       5.386ms       0.000us         0.00%       0.000us       0.000us             1  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 7.498ms
Self CUDA time total: 5.956ms



======================================================================
PROFILE TRACE: torch_mem_eff | cuda_attn_L384_bfloat16
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                          torch_mem_eff         3.20%     247.803us        30.17%       2.338ms       2.338ms       0.000us         0.00%       6.050ms       6.050ms             1  
                                          torch_mem_eff         0.00%       0.000us         0.00%       0.000us       0.000us       6.000ms       100.13%       6.000ms       6.000ms             1  
                     aten::scaled_dot_product_attention         0.37%      28.670us         2.04%     158.093us      52.698us       0.000us         0.00%       5.339ms       1.780ms             3  
          aten::_scaled_dot_product_efficient_attention         0.26%      20.220us         1.67%     129.423us      43.141us       0.000us         0.00%       5.339ms       1.780ms             3  
                     aten::_efficient_attention_forward         0.38%      29.560us         1.08%      83.863us      27.954us       5.339ms        89.10%       5.339ms       1.780ms             3  
fmha_cutlassF_bf16_aligned_64x128_rf_sm80(PyTorchMem...         0.00%       0.000us         0.00%       0.000us       0.000us       5.339ms        89.10%       5.339ms       1.780ms             3  
                                       aten::contiguous         0.10%       7.610us        24.36%       1.887ms     209.722us       0.000us         0.00%     711.328us      79.036us             9  
                                            aten::clone         0.28%      21.914us        24.26%       1.880ms     208.876us       0.000us         0.00%     711.328us      79.036us             9  
                                            aten::copy_         0.87%      67.261us        23.30%       1.806ms     200.640us     653.248us        10.90%     711.328us      79.036us             9  
void at::native::elementwise_kernel<128, 4, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     653.248us        10.90%     653.248us      72.583us             9  
                                Activity Buffer Request        18.39%       1.425ms        18.39%       1.425ms       1.425ms      58.080us         0.97%      58.080us      58.080us             1  
                                        aten::transpose         0.68%      52.310us         0.90%      69.650us       2.902us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::as_strided         0.22%      17.340us         0.22%      17.340us       0.723us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::empty_like         0.16%      12.088us         0.67%      52.209us       5.801us       0.000us         0.00%       0.000us       0.000us             9  
                                            aten::empty         0.84%      64.993us         0.84%      64.993us       3.095us       0.000us         0.00%       0.000us       0.000us            21  
                                       cudaLaunchKernel         4.36%     337.546us         4.36%     337.546us      28.129us       0.000us         0.00%       0.000us       0.000us            12  
                                  cudaStreamIsCapturing         0.03%       2.491us         0.03%       2.491us       0.830us       0.000us         0.00%       0.000us       0.000us             3  
                                   cudaFuncSetAttribute         0.04%       3.020us         0.04%       3.020us       1.007us       0.000us         0.00%       0.000us       0.000us             3  
                                  cudaDeviceSynchronize        69.83%       5.411ms        69.83%       5.411ms       5.411ms       0.000us         0.00%       0.000us       0.000us             1  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 7.749ms
Self CUDA time total: 5.992ms



======================================================================
PROFILE TRACE: torch_mem_eff | cuda_attn_L448_bfloat16
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                          torch_mem_eff         3.22%     253.272us        29.03%       2.283ms       2.283ms       0.000us         0.00%       6.248ms       6.248ms             1  
                                          torch_mem_eff         0.00%       0.000us         0.00%       0.000us       0.000us       6.196ms       100.13%       6.196ms       6.196ms             1  
                     aten::scaled_dot_product_attention         0.25%      19.441us         2.25%     176.884us      58.961us       0.000us         0.00%       5.524ms       1.841ms             3  
          aten::_scaled_dot_product_efficient_attention         0.26%      20.811us         2.00%     157.443us      52.481us       0.000us         0.00%       5.524ms       1.841ms             3  
                     aten::_efficient_attention_forward         0.41%      31.883us         1.42%     111.902us      37.301us       5.524ms        89.27%       5.524ms       1.841ms             3  
fmha_cutlassF_bf16_aligned_64x128_rf_sm80(PyTorchMem...         0.00%       0.000us         0.00%       0.000us       0.000us       5.524ms        89.27%       5.524ms       1.841ms             3  
                                       aten::contiguous         0.10%       7.580us        22.97%       1.807ms     200.732us       0.000us         0.00%     724.035us      80.448us             9  
                                            aten::clone         0.28%      22.150us        22.88%       1.799ms     199.890us       0.000us         0.00%     724.035us      80.448us             9  
                                            aten::copy_         0.85%      67.019us        21.94%       1.725ms     191.709us     664.226us        10.73%     724.035us      80.448us             9  
void at::native::elementwise_kernel<128, 4, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     664.226us        10.73%     664.226us      73.803us             9  
                                Activity Buffer Request        18.12%       1.425ms        18.12%       1.425ms       1.425ms      59.809us         0.97%      59.809us      59.809us             1  
                                        aten::transpose         0.68%      53.201us         0.91%      71.182us       2.966us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::as_strided         0.23%      17.981us         0.23%      17.981us       0.749us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::empty_like         0.15%      12.001us         0.65%      51.482us       5.720us       0.000us         0.00%       0.000us       0.000us             9  
                                            aten::empty         0.81%      63.729us         0.81%      63.729us       3.035us       0.000us         0.00%       0.000us       0.000us            21  
                                       cudaLaunchKernel         3.60%     283.426us         3.60%     283.426us      23.619us       0.000us         0.00%       0.000us       0.000us            12  
                                  cudaStreamIsCapturing         0.03%       2.490us         0.03%       2.490us       0.830us       0.000us         0.00%       0.000us       0.000us             3  
                                   cudaFuncSetAttribute         0.04%       2.980us         0.04%       2.980us       0.993us       0.000us         0.00%       0.000us       0.000us             3  
                                  cudaDeviceSynchronize        70.97%       5.581ms        70.97%       5.581ms       5.581ms       0.000us         0.00%       0.000us       0.000us             1  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 7.864ms
Self CUDA time total: 6.188ms



======================================================================
PROFILE TRACE: torch_mem_eff | cuda_attn_L512_bfloat16
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                          torch_mem_eff         3.10%     256.636us        27.41%       2.272ms       2.272ms       0.000us         0.00%       6.685ms       6.685ms             1  
                                          torch_mem_eff         0.00%       0.000us         0.00%       0.000us       0.000us       6.632ms       100.12%       6.632ms       6.632ms             1  
                     aten::scaled_dot_product_attention         0.23%      18.791us         1.80%     149.483us      49.828us       0.000us         0.00%       5.954ms       1.985ms             3  
          aten::_scaled_dot_product_efficient_attention         0.24%      19.642us         1.58%     130.692us      43.564us       0.000us         0.00%       5.954ms       1.985ms             3  
                     aten::_efficient_attention_forward         0.40%      33.027us         1.05%      86.901us      28.967us       5.954ms        89.88%       5.954ms       1.985ms             3  
fmha_cutlassF_bf16_aligned_64x128_rf_sm80(PyTorchMem...         0.00%       0.000us         0.00%       0.000us       0.000us       5.954ms        89.88%       5.954ms       1.985ms             3  
                                       aten::contiguous         0.09%       7.531us        21.68%       1.797ms     199.660us       0.000us         0.00%     731.136us      81.237us             9  
                                            aten::clone         0.27%      22.649us        21.59%       1.789ms     198.823us       0.000us         0.00%     731.136us      81.237us             9  
                                            aten::copy_         0.82%      67.700us        20.66%       1.712ms     190.261us     670.176us        10.12%     731.136us      81.237us             9  
void at::native::elementwise_kernel<128, 4, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     670.176us        10.12%     670.176us      74.464us             9  
                                Activity Buffer Request        17.30%       1.434ms        17.30%       1.434ms       1.434ms      60.960us         0.92%      60.960us      60.960us             1  
                                        aten::transpose         0.90%      75.001us         1.12%      92.890us       3.870us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::as_strided         0.22%      17.889us         0.22%      17.889us       0.745us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::empty_like         0.15%      12.259us         0.66%      54.410us       6.046us       0.000us         0.00%       0.000us       0.000us             9  
                                            aten::empty         0.81%      67.133us         0.81%      67.133us       3.197us       0.000us         0.00%       0.000us       0.000us            21  
                                       cudaLaunchKernel         2.82%     234.057us         2.82%     234.057us      19.505us       0.000us         0.00%       0.000us       0.000us            12  
                                  cudaStreamIsCapturing         0.03%       2.420us         0.03%       2.420us       0.807us       0.000us         0.00%       0.000us       0.000us             3  
                                   cudaFuncSetAttribute         0.04%       3.430us         0.04%       3.430us       1.143us       0.000us         0.00%       0.000us       0.000us             3  
                                  cudaDeviceSynchronize        72.59%       6.017ms        72.59%       6.017ms       6.017ms       0.000us         0.00%       0.000us       0.000us             1  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 8.289ms
Self CUDA time total: 6.624ms
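
The `aten::copy_` rows come from the layout copies in `torch_mem_eff`. The short sketch below compares that copy time with the attention-kernel time across the six traces. It uses only the Self CUDA values copied from the tables above, converted to microseconds. Both values are totals over the 3 profiled calls, so the ratio is unaffected. The helper name `copy_overhead` is hypothetical.

```python
# Self CUDA times copied from the profile traces above, in microseconds:
# seq_len -> (aten::copy_ self CUDA, fmha_cutlassF attention kernel self CUDA).
PROFILE_US = {
    128: (617.121, 4719.0),
    256: (636.925, 5042.0),
    320: (655.331, 5300.0),
    384: (653.248, 5339.0),
    448: (664.226, 5524.0),
    512: (670.176, 5954.0),
}


def copy_overhead(profile_us):
    """Layout-copy time as a fraction of attention-kernel time, per sequence length."""
    return {seq_len: copy / attn for seq_len, (copy, attn) in sorted(profile_us.items())}


for seq_len, frac in copy_overhead(PROFILE_US).items():
    print(f"L={seq_len:4d}  copies add {frac:.1%} on top of the attention kernel")
```

The ratio falls from about 13.1% at L=128 to about 11.3% at L=512. Over that range, the copy time grows by less than 9% (617 to 670 us), while the attention-kernel time grows by about 26% (4.72 to 5.95 ms).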


impl           wl                       p50(ms)  ok
torch_mem_eff  cuda_attn_L128_bfloat16     1.81  True
torch_mem_eff  cuda_attn_L256_bfloat16     1.88  True
torch_mem_eff  cuda_attn_L320_bfloat16     1.97  True
torch_mem_eff  cuda_attn_L384_bfloat16     1.97  True
torch_mem_eff  cuda_attn_L448_bfloat16     2.09  True
torch_mem_eff  cuda_attn_L512_bfloat16     2.22  True
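
The summary table can also be read as a scaling curve. The sketch below uses only the p50 values above. It assumes that the workloads differ only in sequence length L: the names encode L and dtype, but the other dimensions are not shown here. The helper name `scaling_table` is hypothetical.

```python
# p50 latencies (ms) per sequence length, copied from the summary table above.
P50_MS = {128: 1.81, 256: 1.88, 320: 1.97, 384: 1.97, 448: 2.09, 512: 2.22}


def scaling_table(p50_ms):
    """Return (seq_len, length ratio, L^2 ratio, p50 ratio) relative to the shortest L."""
    base_len = min(p50_ms)
    base_ms = p50_ms[base_len]
    rows = []
    for seq_len in sorted(p50_ms):
        len_ratio = seq_len / base_len
        rows.append((seq_len, len_ratio, len_ratio ** 2, p50_ms[seq_len] / base_ms))
    return rows


for seq_len, len_x, sq_x, lat_x in scaling_table(P50_MS):
    print(f"L={seq_len:4d}  length x{len_x:.2f}  L^2 x{sq_x:5.2f}  p50 x{lat_x:.2f}")
```

For the last row, this prints a 4x increase in length and a 16x increase in L² (the size of the L×L score matrix). The measured p50 grows by only about 1.23x.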

Artifacts:

attention.jsonl