Introduction to Speculative Decoding
The standard way of generating text from a language model (LM) is autoregressive sampling, where generating \(K\) tokens requires \(K\) sequential runs of the LM. Speculative decoding speeds up LM inference by running two models together:
- Target Model: the LM we actually want to sample from, used for token decoding/generation
- Draft Model: a smaller (and faster) LM whose proposals are used to speed up inference
At a high level, tokens are sampled as follows:
- Generate \(K\) tokens from the draft model. Call these draft tokens.
- Score the draft tokens using the target model.
- Using a rejection sampling scheme, accept a subset of the \(K\) draft tokens in order: the \((j+1)\)-th token cannot be accepted if the \(j\)-th token is rejected.
If there is strong agreement between the target and draft models, this approach generates multiple tokens for each call to the target model. In the worst case, it reduces to classic autoregressive decoding.
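To make the control flow concrete, here is a structural Python sketch of the loop. The callables `draft_sample`, `target_probs`, and `verify` are hypothetical stand-ins, not a real API; the `verify` step is the rejection sampling scheme described in the next section.

```python
def speculative_decode(tokens, draft_sample, target_probs, verify, K, T):
    """Structural sketch; draft_sample, target_probs, and verify are
    hypothetical callables standing in for real model code."""
    while len(tokens) < T:
        # 1. The draft model proposes K tokens autoregressively.
        drafts = []
        for _ in range(K):
            drafts.append(draft_sample(tokens + drafts))
        # 2. The target model scores prompt + drafts (one forward pass
        #    yields K + 1 next-token distributions).
        q = target_probs(tokens + drafts)
        # 3. Rejection sampling keeps a prefix of the drafts and emits
        #    one corrected (or bonus) token, extending the sequence.
        tokens = verify(tokens, drafts, q)
    return tokens
```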
Rejection Sampling
Denote the draft model by \(p(\cdot|\cdot)\) and the target model by \(q(\cdot|\cdot)\). Given a prompt sequence of tokens \(x_{1},...,x_{n}\) and \(K\) draft tokens \(\tilde{x}_{n+1},...,\tilde{x}_{n+K}\) generated from the draft model \(p\), we accept the token \(\tilde{x}_{n+1}\) with probability
\[\min\left(1, \frac{q(\tilde{x}_{n+1}|x_1, \dots, x_n)}{p(\tilde{x}_{n+1}|x_1, \dots, x_n)}\right)\]
If token \(\tilde{x}_{n+1}\) is accepted, it is no longer a draft token and becomes part of the prompt sequence of tokens: \(x_{1}, ..., x_{n}, x_{n+1}\). If token \(\tilde{x}_{n+1}\) is rejected, the token \(x_{n+1}\) is instead resampled from the following distribution
\[x_{n+1} \sim (q(x| x_{1}, ..., x_{n}) - p(x| x_{1}, ..., x_{n}))_{+}\]
where \[(f(x))_{+} = \frac{\max(0, f(x))}{\sum_{x} \max(0, f(x))}\]
Here \(f(x)\) is the score assigned to token \(x\) (in our case the difference of token probabilities, \(q(x|x_{1}, ..., x_{n}) - p(x|x_{1}, ..., x_{n})\)), and \(\sum_{x} \max(0, f(x))\) sums the positive parts over all tokens in the vocabulary. Written another way, if \(|V|\) is the number of tokens in the vocabulary, we can rewrite \((f(x))_{+}\) more explicitly as
\[ (f(x_{i}))_{+} = \frac{\max(0, f(x_{i}))}{\sum_{j=1}^{|V|} \max(0, f(x_{j}))} \]
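As a minimal NumPy sketch of this accept/resample step (assuming `q` and `p` are the target and draft probability vectors over the vocabulary at the current position; the function name is ours, not from a library):

```python
import numpy as np

rng = np.random.default_rng(0)

def accept_or_resample(q, p, draft_token):
    """q, p: probability vectors over the vocabulary; draft_token: the
    token index sampled from p. Returns (token, was_accepted)."""
    # Accept the draft token with probability min(1, q/p).
    if rng.uniform() < min(1.0, q[draft_token] / p[draft_token]):
        return draft_token, True
    # Otherwise resample from the normalized positive part (q - p)_+.
    residual = np.maximum(q - p, 0.0)
    residual /= residual.sum()
    return int(rng.choice(len(q), p=residual)), False
```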
Looking closer at this resampling distribution, note that the difference \((q(x| x_{1}, ..., x_{n}) - p(x| x_{1}, ..., x_{n}))\) between the target and draft distributions is computed for every token in the vocabulary.
Consider an illustrative example: suppose we are generating a token \(x_{n+1}\) (following \(x_{1}, ..., x_{n}\)) and the vocabulary consists of three possible tokens: Token A, Token B, and Token C. Let the token probabilities from the target and draft models be:
| Token (\(x\)) | Target Model \(q(x)\) | Draft Model \(p(x)\) |
|---|---|---|
| A | 0.6 | 0.4 |
| B | 0.3 | 0.5 |
| C | 0.1 | 0.1 |
Suppose Token B is sampled from the draft model. Its acceptance probability is:
\[\min\left(1, \frac{q(\tilde{x}_{n+1}|x_1, \dots, x_n)}{p(\tilde{x}_{n+1}|x_1, \dots, x_n)}\right) = \min\left(1, \frac{q(B)}{p(B)} \right) = \min\left(1, \frac{0.3}{0.5} \right) = \min(1, 0.6) = 0.6\]
Sample \(r \sim U[0, 1]\); in this case let \(r=0.7\). Since \(r > 0.6\), token B is rejected, and we resample a token from the distribution \((q(x) - p(x))_{+}\).
| Token (\(x\)) | \(q(x) - p(x)\) | \(\max(0, q(x) - p(x))\) | \(\sum_{x} \max(0, q(x) - p(x))\) | \((q(x) - p(x))_{+}\) |
|---|---|---|---|---|
| A | \(0.6 - 0.4 = 0.2\) | \(\max(0, 0.2) = 0.2\) | \(0.2 + 0 + 0 = 0.2\) | \(0.2 / 0.2 = 1\) |
| B | \(0.3 - 0.5 = -0.2\) | \(\max(0, -0.2) = 0\) | \(0.2 + 0 + 0 = 0.2\) | \(0 / 0.2 = 0\) |
| C | \(0.1 - 0.1 = 0.0\) | \(\max(0, 0) = 0\) | \(0.2 + 0 + 0 = 0.2\) | \(0 / 0.2 = 0\) |
The new token \(x_{n+1}\) is sampled from the distribution in the final column of the table above; with these values, Token A is chosen with probability 1.
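The arithmetic above can be verified numerically in a few lines of NumPy (values taken directly from the tables):

```python
import numpy as np

q = np.array([0.6, 0.3, 0.1])  # target probabilities for A, B, C
p = np.array([0.4, 0.5, 0.1])  # draft probabilities for A, B, C

b = 1  # index of Token B
print(min(1.0, q[b] / p[b]))   # acceptance probability: 0.6

residual = np.maximum(q - p, 0.0)
residual /= residual.sum()
print(residual)                # [1. 0. 0.]: Token A with probability 1
```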
Speculative Sampling with Autoregressive Target and Draft Models
The DeepMind paper on speculative sampling (Chen et al., 2023) gives the full algorithm, which combines the draft generation, rejection sampling, and resampling steps above.
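The following is a minimal Python sketch of that procedure, paraphrasing the paper's pseudocode. The `target` and `draft` callables, which return the next-token probability vector for a given token sequence, are hypothetical stand-ins for real model calls.

```python
import numpy as np

rng = np.random.default_rng(0)

def speculative_sampling(x, target, draft, K, T):
    """x: list of prompt token ids. target(seq) and draft(seq) are
    assumed to return the next-token probability vector for seq."""
    while len(x) < T:
        # Sample K draft tokens autoregressively from the draft model,
        # remembering the distribution each was drawn from.
        drafts, p_probs = [], []
        for _ in range(K):
            p = draft(x + drafts)
            drafts.append(int(rng.choice(len(p), p=p)))
            p_probs.append(p)
        # The K + 1 target distributions; in practice these come from a
        # single forward pass over x + drafts (see Remarks).
        q_probs = [target(x + drafts[:j]) for j in range(K + 1)]
        all_accepted = True
        for j in range(K):
            q_j, p_j, tok = q_probs[j], p_probs[j], drafts[j]
            if rng.uniform() < min(1.0, q_j[tok] / p_j[tok]):
                x.append(tok)  # accept the j-th draft token
            else:
                # Reject: resample from (q - p)_+ and drop later drafts.
                residual = np.maximum(q_j - p_j, 0.0)
                residual /= residual.sum()
                x.append(int(rng.choice(len(residual), p=residual)))
                all_accepted = False
                break
        if all_accepted:
            # All K drafts accepted: sample one bonus token from the target.
            x.append(int(rng.choice(len(q_probs[K]), p=q_probs[K])))
    return x
```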

Remarks
- For every iteration of the outer loop, \(K\) draft tokens are proposed by the draft model.
- Up to \(K\) tokens can be accepted based on the rejection sampling scheme in the second inner for loop.
- If only \(K-j\) tokens are accepted, token \(x_{n+(K-j+1)}\) is resampled and the second inner for loop is exited. The new token sequence is \(x_{1}, \cdots ,x_{n}, x_{n+1}, \cdots, x_{n+(K-j+1)}\); these are relabelled to \(x_{1}, \cdots, x_{n}\) for the next iteration of the outer while loop. If \(n<T\), a new set of \(K\) draft tokens is generated and the inner loops run again.
- The initial prompt sequence is labelled from \(x_{0}\); this seems like a misnomer given that the sequence is 1-indexed everywhere else. It can be reconciled by setting \(x_{0}\) to a start-of-sequence (BOS) token, in which case it makes little difference when computing logits.
- The \(K+1\) sets of logits are said to be computed in parallel; in practice they can all be read off a single forward pass of the target model over \(x_{1}, \cdots, x_{n}, \tilde{x}_{1}, \cdots, \tilde{x}_{K}\), i.e. from the logits behind \(q(x| x_{1}, \cdots, x_{n}, \tilde{x}_{1}, \cdots, \tilde{x}_{K})\), as sketched below.
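For instance, with a transformer that returns one logit row per input position, the \(K+1\) distributions can be sliced out of a single forward pass. This is a sketch under that assumption; `target_distributions` is our own helper name, not a library function.

```python
import numpy as np

def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def target_distributions(logits, n):
    """logits: [n + K, |V|] array from one target-model forward pass over
    x_1..x_n, x~_1..x~_K. Row i predicts the token after position i + 1,
    so rows n-1 .. n+K-1 give the K + 1 next-token distributions
    q(x | x_1..x_n), ..., q(x | x_1..x_n, x~_1..x~_K)."""
    return softmax(logits[n - 1:])
```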