_posts/2025-03-10-sampling.md: 119 additions & 0 deletions
@@ -152,6 +152,125 @@ To address this issue, in FlashInfer [v0.2.3](https://github.com/flashinfer-ai/f
Figure 4 shows the transition from round(i) to round(i+1) in Dual Pivot Rejection Sampling. In each round, if the sampled token is accepted, we return it; otherwise, the new range's extent is $\frac{\text{high}-\text{pivot}_1}{2} < \frac{\text{high}-\text{low}}{2}$, which is less than half of the previous range. The number of rounds is therefore guaranteed to be $O(\log(1/\epsilon))$, where $\epsilon$ is the smallest positive value representable in the floating-point format.
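To make the $O(\log(1/\epsilon))$ bound concrete, here is a minimal sketch that counts worst-case rounds under two assumptions that are not stated in the post: the initial range has extent 1 (scores in $[0, 1]$), and $\epsilon$ is float32's smallest positive subnormal, $2^{-149}$. The helper name `rounds_until_exhausted` is hypothetical.

```python
def rounds_until_exhausted(extent: float, eps: float) -> int:
    """Count how many halvings it takes for `extent` to drop below `eps`."""
    rounds = 0
    while extent >= eps:
        extent /= 2.0  # worst case: the range shrinks by exactly one half
        rounds += 1
    return rounds


# Assumed values: initial extent 1.0 and float32's smallest positive subnormal.
FLOAT32_MIN_SUBNORMAL = 2.0 ** -149
print(rounds_until_exhausted(1.0, FLOAT32_MIN_SUBNORMAL))
```

Because each rejected round shrinks the range by more than half, this count is an upper bound: under these assumptions the loop cannot run for more than 150 rounds.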

## Theoretical Proof of the Correctness of Rejection Sampler

In this section, we prove the correctness of the rejection sampler. We use top-k sampling as the example; the other samplers can be proved in a similar way.

### Nomenclature

| Symbol | Meaning |
|--------|---------|
| $p_i > 0$ | Unnormalised score (unnormalised probability mass) of item $i$ |
| $T = \operatorname{Top}k = \{i_1,\dots,i_k\}$ | Indices of the **k** largest scores |
| $Z = \sum_{j \in T} p_j$ | Total mass of the top‑k items |
| $\tau$ | Current **pivot** (threshold) value |

### Theorem
The algorithm outputs each top‑k index $j \in T$ with probability

$$
\Pr[\text{output}=j] \;=\; \frac{p_j}{Z},
$$

i.e. **exactly the distribution obtained by first discarding all non‑top‑k items and then sampling categorically inside the top‑k set**.

---
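As a concrete reading of the theorem's target distribution, the following minimal sketch computes $p_j / Z$ for each $j \in T$ directly, by discarding the non‑top‑k items and renormalising. The function name and the example scores are hypothetical, and the sketch assumes distinct positive scores so that the top‑k set is unambiguous.

```python
def top_k_target_distribution(scores, k):
    """Map each top-k index j to p_j / Z; assumes distinct positive scores."""
    # T: indices of the k largest scores.
    top = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:k]
    z = sum(scores[j] for j in top)  # Z: total mass of the top-k items
    return {j: scores[j] / z for j in top}


scores = [5.0, 1.0, 3.0, 2.0, 8.0, 4.0]  # hypothetical unnormalised scores
print(top_k_target_distribution(scores, k=3))
```

With these scores the top‑3 set is $\{4, 0, 5\}$ and $Z = 17$, so the rejection sampler is expected to return index 4 with probability $8/17$, index 0 with probability $5/17$, and index 5 with probability $4/17$.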
### Proof

Fix any pivot $\tau < \min_{j \in T} p_j$ (this holds at every step: $\tau$ starts at $0$, and afterwards it is always the score of a rejected non‑top‑k item).

where $S(\tau)$ is the total mass of the tokens whose scores are greater than $\tau$, and $W(\tau)$ is the remaining mass of "bad" (non‑top‑k) items still above the threshold.

which is a contradiction; hence $\Delta_j(\tau^*) = 0$, and our solution is unique.

The algorithm starts with $\tau = 0$; therefore

$$
\Pr[\text{output}=j] = Q_j(0) = \frac{p_j}{Z},
$$

exactly the desired top‑k categorical distribution.
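The result can be checked numerically on a small example. The sketch below is not FlashInfer's kernel; it assumes one reading of the sampler implied by the proof: each round draws an item with probability proportional to its score among items whose score exceeds $\tau$, returns it if it lies in $T$, and otherwise sets $\tau$ to the rejected item's score. Under that assumption it evaluates $Q_j(0)$ exactly with rational arithmetic. The function name is hypothetical, and the scores are assumed to be distinct.

```python
from fractions import Fraction
from functools import lru_cache


def exact_output_distribution(scores, k):
    """Exact Pr[output = j] for every top-k index j, starting from tau = 0.

    Assumes distinct positive scores and the round structure described above.
    """
    p = [Fraction(s) for s in scores]
    top = set(sorted(range(len(p)), key=lambda i: p[i], reverse=True)[:k])

    @lru_cache(maxsize=None)
    def q(j, tau):
        # Items still eligible to be drawn at pivot tau.
        alive = [i for i in range(len(p)) if p[i] > tau]
        s = sum(p[i] for i in alive)  # S(tau)
        prob = p[j] / s  # j is drawn and accepted in this round
        for r in alive:
            if r not in top:
                # r is drawn and rejected; the pivot moves up to p[r].
                prob += p[r] / s * q(j, p[r])
        return prob

    return {j: q(j, Fraction(0)) for j in sorted(top)}


dist = exact_output_distribution([5, 1, 3, 2, 8, 4], k=3)
assert all(dist[j] == Fraction(p_j, 17) for j, p_j in [(0, 5), (4, 8), (5, 4)])
```

The recursion terminates because the pivot strictly increases with every rejection. Using `Fraction` avoids floating-point error, so the final assertion checks equality with $p_j / Z$ exactly rather than approximately.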

## Evaluation

Our evaluation demonstrates that FlashInfer's sampling kernel delivers substantial improvements in both kernel-level latency and end-to-end throughput compared to traditional sorting-based implementations.