Speculative decoding is a technique implemented by GPT-3.5 that significantly speeds up inference (2-4x). Since open source models don’t generally support it yet, there’s an opportunity to improve their performance.
The content below is drawn primarily from these two sources (link and link); refer to them for more technical explanations.
Part 1 - Setup
Let’s say we are trying to predict the next 3 tokens of the input sequence “I saw a dog ride _ _ _”, where each underscore marks a token to be predicted.
To get started, you need three things:
1. An input sequence to be predicted: “I saw a dog ride _ _ _”
2. A larger, more powerful, but slower LLM. Say this outputs 1 token / second.
3. A smaller, less accurate, but faster LLM, chosen to match the larger model’s predictions as closely as possible across different input sequences. Say this outputs 100 tokens / second.
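One way to make “match the larger model’s predictions as closely as possible” concrete is to measure how often the two models pick the same next token for the same context. The Python sketch below is a simple proxy under that assumption (greedy decoding, exact token match); `agreement_rate` is a hypothetical helper, not something taken from either source, and the example predictions are the ones from the walkthrough that follows.

```python
from typing import Sequence


def agreement_rate(small_preds: Sequence[str], large_preds: Sequence[str]) -> float:
    """Fraction of positions where the two models' greedy next tokens agree.

    Hypothetical helper: both lists must be predictions for the same contexts
    (e.g. the same text, scored position by position by each model).
    """
    if len(small_preds) != len(large_preds) or not small_preds:
        raise ValueError("need two non-empty prediction lists of equal length")
    matches = sum(small_tok == large_tok for small_tok, large_tok in zip(small_preds, large_preds))
    return matches / len(small_preds)


# Predictions for the three positions in the running example: 2 of 3 agree.
print(agreement_rate(["in", "the", "bus"], ["in", "the", "car"]))
```

A higher agreement rate on representative inputs means more suggestions survive verification, which is exactly the property the smaller model is chosen for.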
Part 2 - Step-by-step walkthrough
We can run the following operations, all within around the time it takes the larger model to predict just a single token:
1. Use the smaller, faster model to predict the next few tokens; call these the small model’s “suggestions” [time: 0.03 seconds, i.e. 3 tokens at 100 tokens / second].
“I saw a dog ride in the bus”
2. Construct 3 new input sequences that the larger model will use to verify the suggestions one step at a time [time: negligible].
“I saw a dog ride _”
“I saw a dog ride in _”
“I saw a dog ride in the _”
3. Submit the 3 new sequences to the larger model all at once, as a single batch. The three predictions take only ~1 second in total (Part 3 explains why) [time: 1 second].
4. Evaluate the results in order, accepting suggestions until we reach the first incorrect guess by the smaller model [time: negligible].
“I saw a dog ride in” CORRECT
“I saw a dog ride in the” CORRECT
“I saw a dog ride in the car” INCORRECT (the small model suggested “bus”)
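The four steps above can be sketched as one round of greedy speculative decoding in Python. This is a minimal sketch under stated assumptions, not the sources’ implementation: both models are hypothetical toy lookup tables built by `table_model` (keyed on the last token only), a suggestion is accepted only when it exactly matches the large model’s greedy token (the sampling-based formulations use a probabilistic accept/reject rule instead), and the large model is called in a plain loop rather than in a real batch.

```python
from typing import Callable, Dict, List, Optional, Tuple

# A "model" here is any function mapping a token list to its greedy next token.
NextToken = Callable[[List[str]], str]


def draft(small_model: NextToken, prefix: List[str], k: int) -> List[str]:
    """Step 1: autoregressively collect k suggestions from the small model."""
    context = list(prefix)
    for _ in range(k):
        context.append(small_model(context))
    return context[len(prefix):]


def verification_prefixes(prefix: List[str], suggestions: List[str]) -> List[List[str]]:
    """Step 2: prefix, prefix + s1, prefix + s1 + s2, ... (one per suggestion)."""
    return [prefix + suggestions[:i] for i in range(len(suggestions))]


def verify(
    large_model: NextToken, prefix: List[str], suggestions: List[str]
) -> Tuple[List[str], Optional[str]]:
    """Steps 3-4: score every prefix with the large model, then accept
    suggestions left to right until the first disagreement.

    Returns the accepted suggestions and the large model's own token at the
    first mismatch (None if every suggestion was accepted).
    """
    # A real implementation would run these prefixes as one batched step;
    # this sketch simply loops over them.
    predictions = [large_model(p) for p in verification_prefixes(prefix, suggestions)]
    accepted: List[str] = []
    for suggested, predicted in zip(suggestions, predictions):
        if suggested != predicted:
            return accepted, predicted
        accepted.append(suggested)
    return accepted, None


def table_model(table: Dict[str, str]) -> NextToken:
    """Hypothetical toy model: predicts the next token from the last token only."""
    return lambda context: table.get(context[-1], "<unk>")


small = table_model({"ride": "in", "in": "the", "the": "bus"})
large = table_model({"ride": "in", "in": "the", "the": "car"})

prefix = "I saw a dog ride".split()
suggestions = draft(small, prefix, 3)
accepted, correction = verify(large, prefix, suggestions)
print(suggestions)           # ['in', 'the', 'bus']
print(accepted, correction)  # ['in', 'the'] car
```

In practice, implementations usually don’t build separate prefixes at all: because attention is causal, a single forward pass of the large model over the prefix plus all suggestions yields its next-token prediction at every position at once.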
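Two of the small model’s suggestions (“in” and “the”) have now been verified by the large model in around 1 second, a 2x speedup over the 1 token / second baseline of running the large model one token at a time. Implementations of speculative decoding typically also keep the large model’s own prediction at the first mismatch (“car”), since it was already computed in the same batch, so this round actually yields three tokens.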
Part 3 - Why it works
It works for 2 main reasons (borrowed from the original tweet):
1. Many continuations in LLM output are easy to predict, even for smaller models. The speedup from speculative decoding therefore depends on how well the smaller model predicts the larger model’s output.
2. You can run K predictions in a batch almost as quickly as a single prediction. Each round of predictions is dominated by the time spent reading the model weights from memory (say 0.97 seconds), not by the computation for each individual sequence (say 0.01 seconds). Because the weights are read once per round regardless of batch size, 1 sequence (0.97 + 0.01 = 0.98 seconds) costs almost as much as 3 sequences (0.97 + 3 x 0.01 = 1 second).
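The two reasons combine into a rough back-of-the-envelope model of the speedup. The Python sketch below reuses the illustrative timings from this post (0.97 seconds of weight reading, 0.01 seconds per sequence, 100 tokens / second for the small model) and adds one assumption of its own: each suggestion independently matches the large model with probability `alpha`, a simplification rather than something either source measures. A round counts the accepted suggestions plus the large model’s own token at the first mismatch; all names are hypothetical.

```python
# Illustrative numbers from the text; none of these are measurements.
WEIGHT_LOAD_SEC = 0.97       # fixed cost of reading the large model's weights once
PER_SEQUENCE_SEC = 0.01      # marginal cost of each extra sequence in the batch
SMALL_TOKENS_PER_SEC = 100   # small model throughput from the setup


def large_batch_time(num_sequences: int) -> float:
    """Seconds for one batched large-model step over num_sequences prefixes."""
    return WEIGHT_LOAD_SEC + num_sequences * PER_SEQUENCE_SEC


def expected_tokens(alpha: float, k: int) -> float:
    """Expected tokens gained per round with k suggestions.

    Assumes each suggestion independently matches with probability alpha, so
    the first i suggestions are all accepted with probability alpha**i. Adds
    the large model's own token when some suggestion is rejected.
    """
    accepted = sum(alpha ** i for i in range(1, k + 1))
    p_mismatch = 1 - alpha ** k
    return accepted + p_mismatch


def expected_speedup(alpha: float, k: int) -> float:
    """Throughput of speculative rounds relative to one large-model step per token."""
    round_time = k / SMALL_TOKENS_PER_SEC + large_batch_time(k)
    baseline_time_per_token = large_batch_time(1)
    return expected_tokens(alpha, k) / round_time * baseline_time_per_token


for alpha in (0.0, 0.8):
    print(f"alpha={alpha}: {expected_speedup(alpha, 3):.2f}x")
# alpha=0.0: 0.95x
# alpha=0.8: 2.32x
```

With these made-up numbers, a small model that never agrees makes decoding slightly slower than the baseline, while one that agrees 80% of the time gives roughly a 2.3x speedup, which is why the choice of small model matters so much.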
Series
This is part I of a series exploring open source LLMs and how to build generative AI applications. More to come!