Skip to content

Commit a73a346

Browse files
committed
add proof
1 parent c12fbf7 commit a73a346

File tree

1 file changed

+119
-0
lines changed

1 file changed

+119
-0
lines changed

_posts/2025-03-10-sampling.md

Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,125 @@ To address this issue, in FlashInfer [v0.2.3](https://github.com/flashinfer-ai/f
152152

153153
Figure 4 shows the transition from round(i) to round(i+1) in Dual Pivot Rejection Sampling, in each round, if the sampled token is accepted, we return the token, otherwise, the new range's extent is $\frac{\text{high}-\text{pivot}_1}{2} < \frac{\text{high}-\text{low}}{2}$, which is at least half of the previous range. Thus it's guaranteed that the number of rounds is $O(\log(1/\epsilon))$ where $\epsilon$ is the minimal possible value in floating point representation.
154154

155+
## Theoretical Proof of the Correctness of Rejection Sampler
156+
157+
In this section, we provide a theoretical proof of the correctness of the rejection sampler, we choose the top-k sampling as an example, and
158+
other samplers can be proved in a similar way.
159+
160+
### Nomenclature
161+
162+
| Symbol | Meaning |
163+
|--------|---------|
164+
| $p_i > 0$ | Un‑normalised score (unnormalised probability mass) of item $i$ |
165+
| $T = \operatorname{Top}k = \{i_1,\dots,i_k\}$ | Indices of the **k** largest scores |
166+
| $Z = \sum_{j \in T} p_j$ | Total mass of the top‑k items |
167+
| $\tau$ | Current **pivot** (threshold) value |
168+
169+
### Theorem
170+
171+
The algorithm outputs each top‑k index $j \in T$ with probability
172+
173+
$$
174+
\Pr[\text{output}=j] \;=\; \frac{p_j}{Z},
175+
$$
176+
177+
i.e. **exactly the distribution obtained by first discarding all non‑top‑k items and then sampling categorically inside the top‑k set**.
178+
179+
---
180+
181+
### Proof
182+
183+
Fix any pivot $\tau < \min_{j \in T} p_j$ (true at every step because $\tau$ is always taken from a rejected non‑top‑k item).
184+
Define
185+
186+
$$
187+
Q_j(\tau) \;=\; \Pr[\text{algorithm eventually returns } j \mid \text{current pivot } \tau],
188+
\quad j \in T .
189+
$$
190+
191+
With
192+
193+
$$
194+
S(\tau) \;=\; \sum_{m : p_m > \tau} p_m
195+
\;=\; Z \;+\; W(\tau),\qquad
196+
W(\tau) \;=\!\!\!\! \sum_{r \notin T,\, p_r > \tau}\!\!\!\! p_r ,
197+
$$
198+
199+
where $S(\tau)$ is the sum of all the probabilities of the tokens that are greater than $\tau$, and $W(\tau)$ is the remaining mass of "bad" items still above the threshold.
200+
201+
The next draw obeys
202+
203+
$$
204+
\Pr[i \mid \tau] \;=\; \frac{p_i}{S(\tau)}.
205+
$$
206+
207+
Hence
208+
209+
$$
210+
Q_j(\tau)
211+
\;=\;
212+
\underbrace{\frac{p_j}{S(\tau)}}_{\text{accept immediately}}
213+
\;+\;
214+
\sum_{\substack{r \notin T \\ p_r > \tau}}
215+
\underbrace{\frac{p_r}{S(\tau)}}_{\text{draw } r}\;
216+
\underbrace{Q_j\!\bigl(p_r\bigr)}_{\text{pivot becomes } p_r}
217+
\tag{★}
218+
$$
219+
220+
We show that the following formula is a valid solution to (★):
221+
222+
$$
223+
\boxed{\,Q_j(\tau) \;=\; \dfrac{p_j}{Z}\,}
224+
\qquad\text{for every }\tau < \min_{j \in T} p_j .
225+
$$
226+
227+
We can verify the solution by substituting it into (★):
228+
229+
$$
230+
\begin{aligned}
231+
\text{RHS}
232+
&= \frac{p_j}{S(\tau)}
233+
\;+\; \frac{p_j}{Z} \frac{W(\tau)}{S(\tau)} \\
234+
&= \frac{p_j}{S(\tau)}\!\Bigl(1+\frac{W(\tau)}{Z}\Bigr) \\
235+
&= \frac{p_j}{Z} \frac{Z+W(\tau)}{S(\tau)} \\
236+
&= \frac{p_j}{Z},
237+
\end{aligned}
238+
$$
239+
240+
because $S(\tau) = Z + W(\tau)$.
241+
Thus the claimed form satisfies the recurrence, so $Q_j(\tau) \equiv p_j/Z$.
242+
243+
Now let's show that the solution is unique.
244+
Suppose there is another solution $Q_j'(\tau)$ satisfies (★), let's define
245+
$\Delta_j(\tau) = Q_j(\tau) - Q_j'(\tau)$, we have:
246+
247+
$$
248+
\Delta_j(\tau) = \sum_{\substack{r \notin T \\ p_r > \tau}}
249+
\frac{p_r}{S(\tau)} \Delta_j(p_r)
250+
$$
251+
252+
The sum of the coefficient $\sum_{\substack{r \notin T \\ p_r > \tau}} \frac{p_r}{S(\tau)} = \frac{W(\tau)}{S(\tau)}$, which satisfies:
253+
254+
$$0 \leq \frac{W(\tau)}{S(\tau)} < 1$$
255+
256+
Suppose $\tau^*$ is the pivot where $|\Delta_j(\tau)|$ reach its maximum, if it's positive, we have:
257+
258+
$$
259+
|\Delta_j(\tau^*)| \leq \sum_{\substack{r \notin T \\ p_r > \tau^*}}
260+
\frac{p_r}{S(\tau)} |\Delta_j(p_r)| \leq \sum_{\substack{r \notin T \\ p_r > \tau^*}}
261+
\frac{p_r}{S(\tau)} |\Delta_j(\tau^*)| = \frac{W(\tau^*)}{S(\tau^*)} |\Delta_j(\tau^*)| < |\Delta_j(\tau^*)|
262+
$$
263+
264+
which is contradiction, which means $\Delta_j(\tau^*) = 0$, and our solution is unique.
265+
266+
The algorithm starts with $\tau = 0$; therefore
267+
268+
$$
269+
\Pr[\text{output}=j] = Q_j(0) = \frac{p_j}{Z},
270+
$$
271+
272+
exactly the desired top‑k categorical distribution.
273+
155274
## Evaluation
156275

157276
Our evaluation demonstrates that FlashInfer's sampling kernel delivers substantial improvements in both kernel-level latency and end-to-end throughput compared to traditional sorting-based implementations.

0 commit comments

Comments
 (0)