SageAttention Implementation

SageAttention Benchmark (INT8 Quantized)

# /// script
# requires-python = ">=3.10"
# dependencies = [
#     "numpy",
#     "torch==2.8.0",
#     "kernels",
#     "kernels-benchmark-tools",
# ]
#
# [tool.uv.sources]
# kernels-benchmark-tools = { path = "../../../../../tools", editable = true }
# ///
import torch
from kernels_benchmark_tools import KernelTypeEnum, run_benchmark
from kernels import get_kernel

# Load the SageAttention kernel from the Hugging Face Hub
hf_kernels_sage_attn = get_kernel("kernels-community/sage_attention")


def sage_attention(query, key, value):
    """SageAttention with INT8 Q/K quantization and FP16 P/V."""
    # NOTE: in the run logged below, the loaded module exposed no `fwd` attribute,
    # so every workload failed; check the module's exported functions before benchmarking.
    return hf_kernels_sage_attn.fwd(query, key, value, is_causal=False)[0]


run_benchmark(
    kernel_type=KernelTypeEnum.ATTENTION,
    impl_name="sage_int8_fp16",
    impl_tags={"family": "sageattention", "backend": "int8_fp16_cuda", "compile": "none"},
    impl_func=sage_attention,
)
Running attention benchmark on cuda with 6 workloads.
impl            wl                       p50(ms)  ok
sage_int8_fp16  cuda_attn_L128_bfloat16  FAIL     False
  Error: module 'sage_attention_cb34d81dafacbad9' has no attribute 'fwd'
sage_int8_fp16  cuda_attn_L256_bfloat16  FAIL     False
  Error: module 'sage_attention_cb34d81dafacbad9' has no attribute 'fwd'
sage_int8_fp16  cuda_attn_L320_bfloat16  FAIL     False
  Error: module 'sage_attention_cb34d81dafacbad9' has no attribute 'fwd'
sage_int8_fp16  cuda_attn_L384_bfloat16  FAIL     False
  Error: module 'sage_attention_cb34d81dafacbad9' has no attribute 'fwd'
sage_int8_fp16  cuda_attn_L448_bfloat16  FAIL     False
  Error: module 'sage_attention_cb34d81dafacbad9' has no attribute 'fwd'
sage_int8_fp16  cuda_attn_L512_bfloat16  FAIL     False
  Error: module 'sage_attention_cb34d81dafacbad9' has no attribute 'fwd'
Fetching 11 files: 100%|██████████| 11/11 [00:00<00:00, 21.06it/s]
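All six workloads fail because the loaded kernel module has no `fwd` attribute. A quick way to see what a loaded module actually exports is to list its public callables. The helper below is a hypothetical sketch that uses only the standard library's `inspect`. The name `public_callables` is illustrative and is not part of `kernels`. This document does not establish the correct SageAttention entry point, so the sketch only helps you find it.

```python
import inspect
import types


def public_callables(module: types.ModuleType) -> list[str]:
    """Return the sorted names of public (non-underscore), callable attributes of a module."""
    return sorted(
        name
        for name, obj in inspect.getmembers(module)
        if not name.startswith("_") and callable(obj)
    )


# Hypothetical usage against the kernel from the benchmark cell (needs network access):
#   from kernels import get_kernel
#   sage = get_kernel("kernels-community/sage_attention")
#   print(public_callables(sage))
```

Filtering on `callable` and a leading underscore hides constants and private helpers. What remains is the set of candidate functions that `impl_func` could wrap.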

Artifacts:

attention.jsonl
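The benchmarked function describes its scheme as INT8 quantization of Q and K, with FP16 for P and V. The NumPy sketch below illustrates that idea on the CPU under stated assumptions. It is not the CUDA kernel's code. It assumes symmetric INT8 scales with one scale per block of 64 tokens. It also assumes the K-smoothing step from the SageAttention paper, which subtracts K's per-channel mean over tokens. The FP16 stage is simplified to float16-rounded operands with float32 accumulation. The helper names `quantize_int8_per_block`, `sage_like_attention` and `reference_attention` are hypothetical.

```python
import numpy as np


def quantize_int8_per_block(x: np.ndarray, block: int):
    """Symmetric INT8 quantization with one scale per block of `block` rows (tokens)."""
    n = x.shape[0]
    q = np.empty(x.shape, dtype=np.int8)
    scales = np.empty((n + block - 1) // block, dtype=np.float32)
    for i, start in enumerate(range(0, n, block)):
        blk = x[start:start + block]
        amax = float(np.abs(blk).max())
        scale = amax / 127.0 if amax > 0 else 1.0
        q[start:start + block] = np.clip(np.round(blk / scale), -127, 127).astype(np.int8)
        scales[i] = scale
    return q, scales


def sage_like_attention(Q, K, V, block=64):
    """Non-causal attention with INT8 Q/K scores and FP16-rounded P/V (illustrative)."""
    n_q, d = Q.shape
    n_k = K.shape[0]
    # Smooth K: subtracting the per-channel mean over tokens shifts every score in a
    # row of Q K^T by the same constant, so the softmax result is unchanged.
    K_s = K - K.mean(axis=0, keepdims=True)
    q8, sq = quantize_int8_per_block(Q, block)
    k8, sk = quantize_int8_per_block(K_s, block)
    # Integer matmul accumulated in int32, then dequantized with the block scales.
    s_int = q8.astype(np.int32) @ k8.astype(np.int32).T
    row_scale = np.repeat(sq, block)[:n_q]
    col_scale = np.repeat(sk, block)[:n_k]
    S = s_int.astype(np.float32) * row_scale[:, None] * col_scale[None, :] / np.sqrt(d)
    # Numerically stable softmax over keys.
    P = np.exp(S - S.max(axis=1, keepdims=True))
    P /= P.sum(axis=1, keepdims=True)
    # Round P and V to float16, then multiply with float32 accumulation.
    P16 = P.astype(np.float16).astype(np.float32)
    V16 = V.astype(np.float16).astype(np.float32)
    return P16 @ V16


def reference_attention(Q, K, V):
    """Full-precision softmax(Q K^T / sqrt(d)) V for comparison."""
    S = (Q @ K.T) / np.sqrt(Q.shape[1])
    P = np.exp(S - S.max(axis=1, keepdims=True))
    return (P / P.sum(axis=1, keepdims=True)) @ V


# Usage: compare against full precision, with a channel-wise offset in K that smoothing removes.
rng = np.random.default_rng(0)
Q = rng.standard_normal((256, 64))
K = rng.standard_normal((256, 64)) + 3.0
V = rng.standard_normal((256, 64))
print(np.abs(sage_like_attention(Q, K, V) - reference_attention(Q, K, V)).max())
```

Smoothing matters here because a large shared offset in K would otherwise dominate each block's INT8 range. Removing it leaves the quantization step sized to the informative part of the scores.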