speculative decoding complete guide added

This commit is contained in:
Shirin Yamani 2024-09-16 18:17:37 -06:00
parent 38fcafcf96
commit a18e071690
8 changed files with 693 additions and 0 deletions


@@ -0,0 +1,133 @@
"""Byte pair encoding utilities.
Copied from: https://github.com/openai/gpt-2/blob/master/src/encoder.py.
"""
import json
import os
from functools import lru_cache
import regex as re
@lru_cache()
def bytes_to_unicode():
"""
Returns list of utf-8 byte and a corresponding list of unicode strings.
The reversible bpe codes work on unicode strings.
This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
This is a significant percentage of your normal, say, 32K bpe vocab.
To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
And avoids mapping to whitespace/control characters the bpe code barfs on.
"""
bs = (
list(range(ord("!"), ord("~") + 1))
+ list(range(ord("¡"), ord("¬") + 1))
+ list(range(ord("®"), ord("ÿ") + 1))
)
cs = bs[:]
n = 0
for b in range(2**8):
if b not in bs:
bs.append(b)
cs.append(2**8 + n)
n += 1
cs = [chr(n) for n in cs]
return dict(zip(bs, cs))
def get_pairs(word):
"""Return set of symbol pairs in a word.
Word is represented as tuple of symbols (symbols being variable-length strings).
"""
pairs = set()
prev_char = word[0]
for char in word[1:]:
pairs.add((prev_char, char))
prev_char = char
return pairs
class Encoder:
def __init__(self, encoder, bpe_merges, errors="replace"):
self.encoder = encoder
self.decoder = {v: k for k, v in self.encoder.items()}
self.errors = errors # how to handle errors in decoding
self.byte_encoder = bytes_to_unicode()
self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
self.cache = {}
        # Should have added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions
self.pat = re.compile(
r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
)
def bpe(self, token):
if token in self.cache:
return self.cache[token]
word = tuple(token)
pairs = get_pairs(word)
if not pairs:
return token
while True:
bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
if bigram not in self.bpe_ranks:
break
first, second = bigram
new_word = []
i = 0
while i < len(word):
try:
j = word.index(first, i)
new_word.extend(word[i:j])
i = j
except Exception:
new_word.extend(word[i:])
break
if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
new_word.append(first + second)
i += 2
else:
new_word.append(word[i])
i += 1
new_word = tuple(new_word)
word = new_word
if len(word) == 1:
break
else:
pairs = get_pairs(word)
word = " ".join(word)
self.cache[token] = word
return word
def encode(self, text):
bpe_tokens = []
for token in re.findall(self.pat, text):
token = "".join(self.byte_encoder[b] for b in token.encode("utf-8"))
bpe_tokens.extend(
self.encoder[bpe_token] for bpe_token in self.bpe(token).split(" ")
)
return bpe_tokens
def decode(self, tokens):
text = "".join([self.decoder[token] for token in tokens])
text = bytearray([self.byte_decoder[c] for c in text]).decode(
"utf-8", errors=self.errors
)
return text
def get_encoder(model_name, models_dir):
with open(os.path.join(models_dir, model_name, "encoder.json"), "r") as f:
encoder = json.load(f)
with open(
os.path.join(models_dir, model_name, "vocab.bpe"), "r", encoding="utf-8"
) as f:
bpe_data = f.read()
bpe_merges = [tuple(merge_str.split()) for merge_str in bpe_data.split("\n")[1:-1]]
return Encoder(encoder=encoder, bpe_merges=bpe_merges)


@@ -0,0 +1,142 @@
import numpy as np
def gelu(x):
return 0.5 * x * (1 + np.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * x**3)))
def softmax(x):
exp_x = np.exp(x - np.max(x, axis=-1, keepdims=True))
return exp_x / np.sum(exp_x, axis=-1, keepdims=True)
def layer_norm(x, g, b, eps: float = 1e-5):
mean = np.mean(x, axis=-1, keepdims=True)
variance = np.var(x, axis=-1, keepdims=True)
x = (x - mean) / np.sqrt(
variance + eps
) # normalize x to have mean=0 and var=1 over last axis
return g * x + b # scale and offset with gamma/beta params
def linear(x, w, b): # [m, in], [in, out], [out] -> [m, out]
return x @ w + b
def ffn(x, c_fc, c_proj): # [n_seq, n_embd] -> [n_seq, n_embd]
# project up
a = gelu(linear(x, **c_fc)) # [n_seq, n_embd] -> [n_seq, 4*n_embd]
# project back down
x = linear(a, **c_proj) # [n_seq, 4*n_embd] -> [n_seq, n_embd]
return x
def attention(
q, k, v, mask
): # [n_q, d_k], [n_k, d_k], [n_k, d_v], [n_q, n_k] -> [n_q, d_v]
return softmax(q @ k.T / np.sqrt(q.shape[-1]) + mask) @ v
def mha(x, c_attn, c_proj, n_head): # [n_seq, n_embd] -> [n_seq, n_embd]
# qkv projection
x = linear(x, **c_attn) # [n_seq, n_embd] -> [n_seq, 3*n_embd]
# split into qkv
qkv = np.split(x, 3, axis=-1) # [n_seq, 3*n_embd] -> [3, n_seq, n_embd]
# split into heads
qkv_heads = list(
map(lambda x: np.split(x, n_head, axis=-1), qkv)
) # [3, n_seq, n_embd] -> [3, n_head, n_seq, n_embd/n_head]
# causal mask to hide future inputs from being attended to
causal_mask = (1 - np.tri(x.shape[0], dtype=x.dtype)) * -1e10 # [n_seq, n_seq]
# perform attention over each head
out_heads = [
attention(q, k, v, causal_mask) for q, k, v in zip(*qkv_heads)
] # [3, n_head, n_seq, n_embd/n_head] -> [n_head, n_seq, n_embd/n_head]
# merge heads
x = np.hstack(out_heads) # [n_head, n_seq, n_embd/n_head] -> [n_seq, n_embd]
# out projection
x = linear(x, **c_proj) # [n_seq, n_embd] -> [n_seq, n_embd]
return x
def transformer_block(
x, mlp, attn, ln_1, ln_2, n_head
): # [n_seq, n_embd] -> [n_seq, n_embd]
# multi-head causal self attention
x = x + mha(
layer_norm(x, **ln_1), **attn, n_head=n_head
) # [n_seq, n_embd] -> [n_seq, n_embd]
# position-wise feed forward network
x = x + ffn(layer_norm(x, **ln_2), **mlp) # [n_seq, n_embd] -> [n_seq, n_embd]
return x
def gpt2(inputs, wte, wpe, blocks, ln_f, n_head): # [n_seq] -> [n_seq, n_vocab]
# token + positional embeddings
x = wte[inputs] + wpe[range(len(inputs))] # [n_seq] -> [n_seq, n_embd]
# forward pass through n_layer transformer blocks
for block in blocks:
x = transformer_block(
x, **block, n_head=n_head
) # [n_seq, n_embd] -> [n_seq, n_embd]
# projection to vocab
x = layer_norm(x, **ln_f) # [n_seq, n_embd] -> [n_seq, n_embd]
return x @ wte.T # [n_seq, n_embd] -> [n_seq, n_vocab]
def generate(inputs, params, n_head, n_tokens_to_generate):
from tqdm import tqdm
for _ in tqdm(
range(n_tokens_to_generate), "generating"
): # auto-regressive decode loop
logits = gpt2(inputs, **params, n_head=n_head) # model forward pass
next_id = np.argmax(logits[-1]) # greedy sampling
inputs.append(int(next_id)) # append prediction to input
return inputs[len(inputs) - n_tokens_to_generate :] # only return generated ids
def main(
prompt: str,
n_tokens_to_generate: int = 40,
model_size: str = "124M",
models_dir: str = "models",
):
from gpt2.utils import load_encoder_hparams_and_params
# load encoder, hparams, and params from the released open-ai gpt-2 files
encoder, hparams, params = load_encoder_hparams_and_params(model_size, models_dir)
# encode the input string using the BPE tokenizer
input_ids = encoder.encode(prompt)
# make sure we are not surpassing the max sequence length of our model
assert len(input_ids) + n_tokens_to_generate < hparams["n_ctx"]
# generate output ids
output_ids = generate(input_ids, params, hparams["n_head"], n_tokens_to_generate)
# decode the ids back into a string
output_text = encoder.decode(output_ids)
return output_text
if __name__ == "__main__":
import fire
fire.Fire(main)


@@ -0,0 +1,81 @@
import json
import os
import re
import numpy as np
import requests
import tensorflow as tf
from tqdm import tqdm
from gpt2.encoder import get_encoder
def download_gpt2_files(model_size, model_dir):
assert model_size in ["124M", "355M", "774M", "1558M"]
for filename in [
"checkpoint",
"encoder.json",
"hparams.json",
"model.ckpt.data-00000-of-00001",
"model.ckpt.index",
"model.ckpt.meta",
"vocab.bpe",
]:
url = "https://openaipublic.blob.core.windows.net/gpt-2/models"
r = requests.get(f"{url}/{model_size}/{filename}", stream=True)
r.raise_for_status()
with open(os.path.join(model_dir, filename), "wb") as f:
file_size = int(r.headers["content-length"])
chunk_size = 1000
with tqdm(
ncols=100,
desc="Fetching " + filename,
total=file_size,
unit_scale=True,
) as pbar:
# 1k for chunk_size, since Ethernet packet size is around 1500 bytes
for chunk in r.iter_content(chunk_size=chunk_size):
f.write(chunk)
pbar.update(chunk_size)
def load_gpt2_params_from_tf_ckpt(tf_ckpt_path, hparams):
def set_in_nested_dict(d, keys, val):
if not keys:
return val
if keys[0] not in d:
d[keys[0]] = {}
d[keys[0]] = set_in_nested_dict(d[keys[0]], keys[1:], val)
return d
params = {"blocks": [{} for _ in range(hparams["n_layer"])]}
for name, _ in tf.train.list_variables(tf_ckpt_path):
array = np.squeeze(tf.train.load_variable(tf_ckpt_path, name))
name = name[len("model/") :]
if name.startswith("h"):
m = re.match(r"h([0-9]+)/(.*)", name)
n = int(m[1])
sub_name = m[2]
set_in_nested_dict(params["blocks"][n], sub_name.split("/"), array)
else:
set_in_nested_dict(params, name.split("/"), array)
return params
def load_encoder_hparams_and_params(model_size, models_dir):
assert model_size in ["124M", "355M", "774M", "1558M"]
model_dir = os.path.join(models_dir, model_size)
tf_ckpt_path = tf.train.latest_checkpoint(model_dir)
if not tf_ckpt_path: # download files if necessary
os.makedirs(model_dir, exist_ok=True)
download_gpt2_files(model_size, model_dir)
tf_ckpt_path = tf.train.latest_checkpoint(model_dir)
encoder = get_encoder(model_size, models_dir)
hparams = json.load(open(os.path.join(model_dir, "hparams.json")))
params = load_gpt2_params_from_tf_ckpt(tf_ckpt_path, hparams)
return encoder, hparams, params


@@ -0,0 +1,19 @@
import numpy as np
def max_fn(x):
    # norm(max(x, 0)): clip negative entries to zero and renormalize into a probability distribution
    x_max = np.where(x > 0, x, 0)
    return x_max / np.sum(x_max)
def get_sample(p):
    # p is a probability distribution over the vocabulary: if p = [.25, .25, .25, .25] the draw is
    # uniform, otherwise higher-probability tokens are drawn more often
    return np.random.choice(np.arange(p.shape[-1]), p=p)
# np.array inputs are used so NumPy broadcasting / elementwise operations apply
# Example usage:
# print(max_fn(np.array([1, 2, -2, 5, 5])))     # -> [0.077, 0.154, 0., 0.385, 0.385]
# print(get_sample(np.array([0.1, 0.3, 0.6])))  # -> an index drawn with those probabilities

Binary file not shown (new image, 60 KiB).


@@ -0,0 +1,143 @@
import numpy as np
import time
from tqdm import tqdm
from gpt2.utils import load_encoder_hparams_and_params
from gpt2.gpt import gpt2, softmax
from helper import get_sample, max_fn
import functools
def auto_reg_sampling(input_seq, model, N_future):
n = len(input_seq)
T = len(input_seq) + N_future
with tqdm(total=N_future, desc="autoreg sampling") as pbar:
while n < T:
input_seq = np.append(input_seq, get_sample(model(input_seq)[-1]))
n += 1
pbar.update(1)
return input_seq
def spec_sampling(input_seq, draft_model, target_model, N_future, k):
n = len(input_seq)
T = len(input_seq) + N_future
with tqdm(total=N_future, desc="Spec Sampling") as pbar:
while n < T:
prev_n = n
            # step 1: autoregressively sample k draft tokens from the draft model, keeping its probs p
            input_draft = input_seq
            for _ in range(k):
                p = draft_model(input_draft)  # draft probs for every position so far
                input_draft = np.append(input_draft, get_sample(p[-1]))
            # step 2: score the whole draft sequence with the target model in a single forward pass
            q = target_model(input_draft)
            # step 3: accept or reject each draft token based on the q/p ratio
all_generated_tokens_accepted = True
for _ in range(k):
i = n - 1
j = input_draft[i + 1]
if np.random.random() < min(1, q[i][j] / p[i][j]): # accepted
input_seq = np.append(input_seq, j)
n += 1
                else:  # rejected: resample once from norm(max(q - p, 0))
input_seq = np.append(input_seq, get_sample(max_fn(q[i] - p[i])))
n += 1
all_generated_tokens_accepted = False
break
            # step 4: if every draft token was accepted, sample one extra token from the target model
if all_generated_tokens_accepted:
input_seq = np.append(input_seq, get_sample(q[-1]))
n += 1
            # update the progress bar
pbar.update(n - prev_n)
assert n == len(input_seq), f"{n} {len(input_seq)}"
return input_seq
def create_model_fn(params, hparams, temperature, eps=1e-10):
f = functools.partial(gpt2, **params, n_head=hparams["n_head"])
def model_fn(inputs):
logits = f(inputs)
        logits = logits / (temperature + eps)  # eps avoids division by zero; temperature=0 becomes effectively greedy
probs = softmax(logits)
return probs
return model_fn
def main(
prompt: str = "Quantization also improves latency and throughput but suffer from perf",
n_tokens_to_generate: int = 40,
draft_model_size: str = "124M",
target_model_size: str = "355M",
models_dir: str = "models",
K: int = 4,
temperature: float = 0.0,
seed: int = 123,
):
# seed numpy rng
np.random.seed(seed)
# load encoder, hparams, and params from the released open-ai gpt-2 files
encoder, draft_hparams, draft_params = load_encoder_hparams_and_params(
draft_model_size, models_dir
)
_, target_hparams, target_params = load_encoder_hparams_and_params(
target_model_size, models_dir
)
draft_model = create_model_fn(draft_params, draft_hparams, temperature)
target_model = create_model_fn(target_params, target_hparams, temperature)
# encode inputs
input_ids = encoder.encode(prompt)
def run_sampling_fn(decode_fn, input_seq, **kwargs):
start = time.perf_counter()
output_ids = decode_fn(input_seq=input_seq, **kwargs)
text = encoder.decode(output_ids)
elapsed_time = time.perf_counter() - start
return text, elapsed_time
# autoregressive sampling
    autoregressive_text, autoregressive_time = run_sampling_fn(
        auto_reg_sampling,
        input_seq=input_ids,
        model=target_model,
        N_future=n_tokens_to_generate,
    )
# speculative sampling
    speculative_text, speculative_time = run_sampling_fn(
        spec_sampling,
        input_seq=input_ids,
        target_model=target_model,
        draft_model=draft_model,
        N_future=n_tokens_to_generate,
        k=K,
    )
# print results
print()
print("Autoregressive Decoding:")
print("-" * 50)
print(f"Time = {autoregressive_time:.2f}s")
print(f"Text = {autoregressive_text}")
print()
print("Speculative Decoding:")
print("-" * 50)
print(f"Time = {speculative_time:.2f}s")
print(f"Text = {speculative_text}")
if __name__ == "__main__":
import fire
fire.Fire(main)


@@ -0,0 +1,169 @@
# 🔄🔍 Speculative Decoding
This project provides a simple implementation of the [Accelerating Large Language Model Decoding with Speculative Sampling](https://arxiv.org/abs/2302.01318) paper by Chen et al. (DeepMind). The implementation uses pure NumPy with a basic GPT-2 model, demonstrating the concept of speculative decoding in a straightforward manner.
Key features of this implementation:
- Uses NumPy for all computations, making it easy to understand and modify
- Implements speculative decoding for a GPT-2 model
- Compares performance between standard autoregressive sampling and speculative sampling
- Provides a clear example of how speculative decoding can accelerate language model inference
This simple implementation serves as an educational tool to understand the core concepts of speculative decoding and its potential benefits in accelerating large language model inference.
# Speculative Decoding in a nutshell
Speculative decoding is an innovative technique designed to accelerate the inference process of large language models. Here's a brief overview of how it works:
1. Draft Model: A smaller, faster "draft" model generates a sequence of K tokens quickly.
2. Target Model: The larger, more accurate "target" model processes the entire sequence (input + draft) in parallel.
3. Verification: The target model's output is compared with the draft model's predictions.
4. Accept or Reject:
- If the target model agrees with a draft token, it's accepted.
- If there's a disagreement, that draft token (and everything after it) is rejected, and a replacement token is resampled from the adjusted target distribution.
5. Efficiency Gain: This approach lets the target model verify multiple draft tokens in a single forward pass, cutting down the number of expensive target-model calls.
The key advantage is that when most draft tokens are accepted, decoding is significantly faster than traditional autoregressive decoding. Even when the draft model makes mistakes, the output is distributed exactly as if it had been sampled from the target model alone, and the speed only falls back to roughly that of standard autoregressive sampling plus the small cost of the wasted draft work.
This method leverages the speed of smaller models and the accuracy of larger ones, offering a balance between inference speed and output quality.
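Condensing these five steps, here is a minimal NumPy sketch of one speculative round, mirroring the `spec_sampling` loop in this repo (it assumes `draft_model` and `target_model` return a row of next-token probabilities per position, as `create_model_fn` builds them, and it reuses `get_sample`/`max_fn` from `helper.py`):
```python
import numpy as np
from helper import get_sample, max_fn  # helpers defined in this repo

def speculative_step(seq, draft_model, target_model, k):
    """One accept/reject round, condensed from spec_sampling."""
    # 1. the draft model proposes k tokens autoregressively (cheap)
    draft = np.array(seq)
    for _ in range(k):
        p = draft_model(draft)                       # draft probs for every position so far
        draft = np.append(draft, get_sample(p[-1]))
    # 2. the target model scores the whole draft in one forward pass (expensive, but done once)
    q = target_model(draft)
    # 3. accept each draft token with prob min(1, q/p); stop at the first rejection
    out = list(seq)
    for _ in range(k):
        i = len(out) - 1
        tok = int(draft[i + 1])
        if np.random.random() < min(1, q[i][tok] / p[i][tok]):
            out.append(tok)                                    # accepted: keep the draft token
        else:
            out.append(int(get_sample(max_fn(q[i] - p[i]))))   # rejected: resample from norm(max(q - p, 0))
            return out
    out.append(int(get_sample(q[-1])))                         # 4. all accepted: one bonus token from the target
    return out
```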
# 🚀 How to Use
To run the speculative decoding implementation, use the following command:
```bash
python main.py \
--prompt "Quantization also improves latency and throughput but suffer from perf" \
--n_tokens_to_generate 60 \
--draft_model_size "124M" \
--target_model_size "355M" \
--K 4 \
--temperature 0 # 0 for greedy sampling
```
Sample Output:
```
Autoregressive Decoding
--------------------------------------------------
Time = 112.19s
Text = Quantization also improves latency and throughput but suffer from perfomance issues.
The problem is that the performance of the GPU is not the only thing that matters. The CPU is also important. The CPU is the main bottleneck in the GPU. The CPU is the main bottleneck in the GPU.
The CPU is the main bottleneck in the GPU
Speculative Decoding
--------------------------------------------------
Time = 74.12s
Text = Quantization also improves latency and throughput but suffer from perfomance issues.
The problem is that the performance of the GPU is not the only thing that matters. The CPU is also important. The CPU is the main bottleneck in the GPU. The CPU is the main bottleneck in the GPU.
The CPU is the main bottleneck in the GPU. The CPU
```
# 🤔💭 Why does this work?
Most of the time in LLM inference is **NOT** spent on computation; it is spent on all the reads/writes needed to move data in and out of memory.
What happens is that the weights and activations live in memory, and for every computation they have to travel to the GPU's caches and registers and then back to memory, which is a very slow process.
![alt text](img/image.png)
Each decoding step repeats these round trips, which is slow and expensive. So the idea is to make a single trip: while the weights (or at least a chunk of them) are on the GPU, do as much computation as possible, and only then write the results back to memory.
> "Now the clever idea is to use a small and cheap draft model to first generate a candidate sequence of K tokens - a 'draft'. Then we feed all of these together through the big model in a batch. This is almost as fast as feeding in just one token, per the above. Then we go from left to right over the logits predicted by the model and sample tokens. Any sample that agrees with the draft allows us to immediately skip forward to the next token. If there is a disagreement then we throw the draft away and eat the cost of doing some throwaway work (sampling the draft and the forward passing for all the later tokens).
>
> The reason this works in practice is that most of the time the draft tokens get accepted, because they are easy, so even a much smaller draft model gets them. As these easy tokens get accepted, we skip through those parts in leaps. The hard tokens where the big model disagrees 'fall back' to original speed, but actually a bit slower because of all the extra work."
>
> — Andrej Karpathy
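A rough way to see this with the model functions this repo already builds is to time one batched verification pass against k single-token passes (a sketch only: `compare_verification_cost` is a hypothetical helper, this NumPy implementation has no KV cache so every call recomputes the full sequence, and the exact numbers depend on your hardware):
```python
import time
import numpy as np

def compare_verification_cost(target_model, input_ids, draft_tokens):
    """Hypothetical helper: one pass over prompt + draft vs. one pass per draft token."""
    seq = np.append(np.array(input_ids), draft_tokens)

    t0 = time.perf_counter()
    target_model(seq)                               # single forward pass scores every draft position at once
    one_pass = time.perf_counter() - t0

    t0 = time.perf_counter()
    for i in range(len(draft_tokens)):              # k separate forward passes, one per draft token
        target_model(seq[: len(input_ids) + i + 1])
    k_passes = time.perf_counter() - t0

    print(f"one batched pass: {one_pass:.2f}s vs {len(draft_tokens)} sequential passes: {k_passes:.2f}s")
```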
# 🧮💡 Why does this work mathematically?
Speculative decoding's guarantee that the output distribution matches the target model comes from rejection sampling, a Monte Carlo method for drawing samples from a target distribution by using an easier-to-sample proposal (here, the draft model).
## Mathematical Foundation: [Rejection Sampling](https://en.wikipedia.org/wiki/Rejection_sampling)
Rejection sampling generates samples from a target distribution when direct sampling is difficult. The process uses a proposal distribution (the draft model) that is easy to sample from, then accepts or rejects those samples by comparing them against the target distribution (the large model). The rejection sampling theorem guarantees that if we sample from the proposal and accept samples with probability proportional to the ratio of target to proposal densities, the accepted samples follow the target distribution exactly. Applied at every decoding step, this preserves the target model's conditional probability of the next token given the previous context.
## ❌🎯 Rejection Sampling Theorem
The theorem states that if we have a target distribution \( p \) and a proposal distribution \( q \), and we sample from \( q \) and accept samples with probability proportional to the ratio of \( p \) to \( q \), the accepted samples will follow the target distribution \( p \).
Mathematically, this can be expressed as:
1. Sample y from q(y)
2. Accept y with probability min(1, p(y) / (M * q(y)))
Where:
- p(y) is the target distribution
- q(y) is the proposal distribution
- M is a constant such that M ≥ p(y) / q(y) for all y
If we follow this procedure, the accepted samples will be distributed according to p(y).
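Speculative sampling actually uses a modified version of this procedure that needs no constant M: sample y from q, accept it with probability min(1, p(y) / q(y)), and on rejection resample once from the residual distribution norm(max(p - q, 0)); the paper proves the result is still an exact sample from p. Here is a small Monte Carlo check of that claim with made-up 5-token distributions (note that p is the target and q the draft, matching the notation above, while the code in this repo uses p for the draft and q for the target):
```python
import numpy as np

rng = np.random.default_rng(0)
p = np.array([0.40, 0.25, 0.15, 0.15, 0.05])   # target distribution (made up)
q = np.array([0.20, 0.20, 0.20, 0.20, 0.20])   # draft / proposal distribution (made up)

def modified_rejection_sample(p, q):
    y = rng.choice(len(q), p=q)                 # sample a token from the draft distribution
    if rng.random() < min(1, p[y] / q[y]):      # accept with probability min(1, p/q)
        return y
    residual = np.maximum(p - q, 0)             # on rejection, resample once from norm(max(p - q, 0))
    return rng.choice(len(p), p=residual / residual.sum())

samples = [modified_rejection_sample(p, q) for _ in range(200_000)]
empirical = np.bincount(samples, minlength=len(p)) / len(samples)
print(np.round(empirical, 3))                   # ≈ [0.4, 0.25, 0.15, 0.15, 0.05], i.e. p itself
```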
## Question: What if we don't have access to a smaller model from the same family to use as the draft model?
Alternative methods include:
- Medusa
- N-gram
### Medusa
Medusa is a [simple method](https://arxiv.org/abs/2401.10774) that generates several tokens in a single pass by adding extra fine-tuned LM heads on top of your existing model.
You can check out a few existing fine-tunes for popular models:
- [text-generation-inference/gemma-7b-it-medusa](https://huggingface.co/text-generation-inference/gemma-7b-it-medusa)
- [text-generation-inference/Mixtral-8x7B-Instruct-v0.1-medusa](https://huggingface.co/text-generation-inference/Mixtral-8x7B-Instruct-v0.1-medusa)
- [text-generation-inference/Mistral-7B-Instruct-v0.2-medusa](https://huggingface.co/text-generation-inference/Mistral-7B-Instruct-v0.2-medusa)
To create your own Medusa heads for your own fine-tune, check out the original Medusa repo: [../basic_tutorials/train_medusa.md](../basic_tutorials/train_medusa.md)
To use Medusa models in TGI (Text Generation Inference), simply point to a Medusa-enabled model and everything will load automatically.
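Conceptually, Medusa keeps the normal LM head and adds a few extra heads that read the same final hidden state and guess one, two, three, ... tokens ahead; the guesses are then verified in a single pass, much like the draft tokens above. A toy sketch of the idea (shapes and names are illustrative only and omit details of the real Medusa heads, which include a residual block and tree-based verification):
```python
import numpy as np

def medusa_propose(h_last, lm_head, medusa_heads):
    """Toy sketch: h_last is the final hidden state [n_embd], lm_head is [n_embd, n_vocab],
    and each Medusa head is an extra [n_embd, n_vocab] matrix predicting one step further ahead."""
    draft = [int(np.argmax(h_last @ lm_head))]     # token t+1 from the normal LM head
    for head in medusa_heads:                      # token t+2, t+3, ... from the fine-tuned extra heads
        draft.append(int(np.argmax(h_last @ head)))
    return draft                                   # candidate tokens to be verified by the same model
```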
### N-gram
If you don't have a Medusa model, or don't have the resources to fine-tune one, you can try `n-gram` speculation.
N-gram speculation works by looking for matching tokens in the previous sequence and using the tokens that followed them as the speculation for new tokens. For example, if the tokens "np.mean" appear multiple times in the sequence, the model can speculate that the next continuation of "np." is probably also "mean".
This is an extremely simple method that works best for code or other highly repetitive text. It may not be beneficial if the speculation misses too often.
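A minimal sketch of this lookup (sometimes called prompt-lookup decoding); the function name and window sizes are illustrative:
```python
def ngram_speculate(tokens, n=2, k=3):
    """Toy prompt lookup: find the most recent earlier occurrence of the last n tokens
    and propose the k tokens that followed it as the draft."""
    tokens = list(tokens)
    pattern = tokens[-n:]
    for start in range(len(tokens) - n - 1, -1, -1):      # scan backwards through the context
        if tokens[start:start + n] == pattern:
            return tokens[start + n:start + n + k]        # the tokens that followed the match become the draft
    return []                                             # no match: fall back to normal decoding
```
The draft tokens it returns are then verified by the target model exactly as in speculative sampling, so correctness is preserved even when the lookup guesses wrong.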
To enable n-gram speculation in TGI, simply pass
`--speculate 2` in your flags. [Details about the flag](https://huggingface.co/docs/text-generation-inference/basic_tutorials/launcher#speculate)
Please refer to [Speculation](https://huggingface.co/docs/text-generation-inference/conceptual/speculation) for more details.
# ⚡🚀 Summary of the most common speed-up techniques
## 🧠💻 Faster Training
- Device: move training onto GPU(s)
- Mixed precision
- Gradient accumulation
- Distributed training
## ⚡🤖 Faster Inference
- Quantization
- Speculative decoding (this repo 💖)
- Pruning
- Caching
  - Inference-time attention: KV cache
  - In production: prompt cache / exact cache / semantic cache
- Knowledge distillation


@@ -0,0 +1,6 @@
fire==0.6.0
numpy==2.1.1
regex==2024.5.15
Requests==2.32.3
tensorflow==2.17.0
tqdm==4.66.4