arxiv:2603.15854

FlashSampling: Fast and Memory-Efficient Exact Sampling

Published on Mar 16 · Submitted by Yifan Zhang on Mar 18

Abstract

AI-generated summary: FlashSampling makes categorical sampling efficient by fusing it into the language-model head matmul. The logits tensor is never written to memory, and in end-to-end vLLM experiments, time per output token drops by up to 19%.

Sampling from a categorical distribution is mathematically simple, but in large-vocabulary decoding, it often triggers extra memory traffic and extra kernels after the LM head. We present FlashSampling, an exact sampling primitive that fuses sampling into the LM-head matmul and never materializes the logits tensor in HBM. The method is simple: compute logits tile-by-tile on chip, add Gumbel noise, keep only one maximizer per row and per vocabulary tile, and finish with a small reduction over tiles. The fused tiled kernel is exact because argmax decomposes over a partition; grouped variants for online and tensor-parallel settings are exact by hierarchical factorization of the categorical distribution. Across H100, H200, B200, and B300 GPUs, FlashSampling speeds up kernel-level decode workloads, and in end-to-end vLLM experiments, it reduces time per output token by up to 19% on the models we test. These results show that exact sampling, with no approximation, can be integrated into the matmul itself, turning a bandwidth-bound postprocessing step into a lightweight epilogue. Project Page: https://github.com/FlashSampling/FlashSampling.
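The exactness argument rests on the Gumbel-max trick. Add independent standard Gumbel noise to every logit, take the argmax, and the result is an exact sample from softmax(logits). The pure-Python sketch below illustrates that idea only; it is not the authors' kernel, which lives in the project repository.

It makes the following assumptions:
- The LM head is a list of V weight rows.
- The on-chip tiling is imitated with a plain loop.
- The function names and the `tile_size` parameter are hypothetical.

`fused_tiled_sample` keeps one winner per vocabulary tile and then reduces over tiles. `hierarchical_sample` shows one way to realize a grouped, hierarchical factorization. It assumes each group (for example, a tensor-parallel shard) reports its log-normalizer and a local sample.

```python
import math
import random


def gumbel(rng):
    """Standard Gumbel(0, 1) noise: -log(-log(U)) with U ~ Uniform(0, 1)."""
    u = rng.random()
    while u == 0.0:  # random() can return 0.0; log(0) is undefined
        u = rng.random()
    return -math.log(-math.log(u))


def fused_tiled_sample(h, W, tile_size, rng):
    """Hypothetical sketch: sample one token id for hidden state h.

    W holds V vocabulary rows of length d (the LM head). Logits are produced
    one tile at a time and immediately perturbed; only the best
    (score, index) pair per tile is kept, so no length-V logits list exists.
    """
    V = len(W)
    tile_winners = []
    for start in range(0, V, tile_size):
        best_score, best_idx = -math.inf, -1
        for v in range(start, min(start + tile_size, V)):
            logit = sum(a * b for a, b in zip(h, W[v]))  # one dot product of the matmul
            score = logit + gumbel(rng)
            if score > best_score:
                best_score, best_idx = score, v
        tile_winners.append((best_score, best_idx))
    # Small reduction over tiles: the max of the per-tile maxima is the
    # global argmax, because argmax decomposes over a partition.
    return max(tile_winners)[1]


def hierarchical_sample(shard_logits, rng):
    """Hypothetical grouped variant: p(v) = p(group k) * p(v | group k).

    shard_logits is a list of per-group logit lists, concatenated in order
    to form the full vocabulary. Returns a global token index.
    """
    group_scores = []
    offset = 0
    for logits in shard_logits:
        # Within the group: a local Gumbel-max sample from softmax(logits).
        best_score, best_local = -math.inf, -1
        for i, x in enumerate(logits):
            score = x + gumbel(rng)
            if score > best_score:
                best_score, best_local = score, i
        # The group's log-normalizer log(sum(exp(logits))), computed stably.
        m = max(logits)
        log_z = m + math.log(sum(math.exp(x - m) for x in logits))
        # Fresh noise on log_z selects group k with probability Z_k / Z.
        group_scores.append((log_z + gumbel(rng), offset + best_local))
        offset += len(logits)
    return max(group_scores)[1]
```

In both functions, each step keeps only a running maximum (and, in the grouped case, one normalizer per group), so nothing of size V needs to be stored. A real GPU kernel would replace the inner loop with a tile of the matmul computed on chip. This sketch does not model that, nor how the noise would be generated on the device.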

