diff --git a/docs/source/conceptual/speculative-decoding/gpt2/encoder.py b/docs/source/conceptual/speculative-decoding/gpt2/encoder.py new file mode 100644 index 00000000..15959651 --- /dev/null +++ b/docs/source/conceptual/speculative-decoding/gpt2/encoder.py @@ -0,0 +1,133 @@ +"""Byte pair encoding utilities. + +Copied from: https://github.com/openai/gpt-2/blob/master/src/encoder.py. +""" + +import json +import os +from functools import lru_cache + +import regex as re + + +@lru_cache() +def bytes_to_unicode(): + """ + Returns list of utf-8 byte and a corresponding list of unicode strings. + The reversible bpe codes work on unicode strings. + This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. + When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. + This is a signficant percentage of your normal, say, 32K bpe vocab. + To avoid that, we want lookup tables between utf-8 bytes and unicode strings. + And avoids mapping to whitespace/control characters the bpe code barfs on. + """ + bs = ( + list(range(ord("!"), ord("~") + 1)) + + list(range(ord("ยก"), ord("ยฌ") + 1)) + + list(range(ord("ยฎ"), ord("รฟ") + 1)) + ) + cs = bs[:] + n = 0 + for b in range(2**8): + if b not in bs: + bs.append(b) + cs.append(2**8 + n) + n += 1 + cs = [chr(n) for n in cs] + return dict(zip(bs, cs)) + + +def get_pairs(word): + """Return set of symbol pairs in a word. + Word is represented as tuple of symbols (symbols being variable-length strings). + """ + pairs = set() + prev_char = word[0] + for char in word[1:]: + pairs.add((prev_char, char)) + prev_char = char + return pairs + + +class Encoder: + def __init__(self, encoder, bpe_merges, errors="replace"): + self.encoder = encoder + self.decoder = {v: k for k, v in self.encoder.items()} + self.errors = errors # how to handle errors in decoding + self.byte_encoder = bytes_to_unicode() + self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} + self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges)))) + self.cache = {} + + # Should haved added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions + self.pat = re.compile( + r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""" + ) + + def bpe(self, token): + if token in self.cache: + return self.cache[token] + word = tuple(token) + pairs = get_pairs(word) + + if not pairs: + return token + + while True: + bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf"))) + if bigram not in self.bpe_ranks: + break + first, second = bigram + new_word = [] + i = 0 + while i < len(word): + try: + j = word.index(first, i) + new_word.extend(word[i:j]) + i = j + except Exception: + new_word.extend(word[i:]) + break + + if word[i] == first and i < len(word) - 1 and word[i + 1] == second: + new_word.append(first + second) + i += 2 + else: + new_word.append(word[i]) + i += 1 + new_word = tuple(new_word) + word = new_word + if len(word) == 1: + break + else: + pairs = get_pairs(word) + word = " ".join(word) + self.cache[token] = word + return word + + def encode(self, text): + bpe_tokens = [] + for token in re.findall(self.pat, text): + token = "".join(self.byte_encoder[b] for b in token.encode("utf-8")) + bpe_tokens.extend( + self.encoder[bpe_token] for bpe_token in self.bpe(token).split(" ") + ) + return bpe_tokens + + def decode(self, tokens): + text = "".join([self.decoder[token] for token in tokens]) + text = bytearray([self.byte_decoder[c] for c in 
text]).decode( + "utf-8", errors=self.errors + ) + return text + + +def get_encoder(model_name, models_dir): + with open(os.path.join(models_dir, model_name, "encoder.json"), "r") as f: + encoder = json.load(f) + with open( + os.path.join(models_dir, model_name, "vocab.bpe"), "r", encoding="utf-8" + ) as f: + bpe_data = f.read() + bpe_merges = [tuple(merge_str.split()) for merge_str in bpe_data.split("\n")[1:-1]] + return Encoder(encoder=encoder, bpe_merges=bpe_merges) diff --git a/docs/source/conceptual/speculative-decoding/gpt2/gpt.py b/docs/source/conceptual/speculative-decoding/gpt2/gpt.py new file mode 100644 index 00000000..fa90b650 --- /dev/null +++ b/docs/source/conceptual/speculative-decoding/gpt2/gpt.py @@ -0,0 +1,142 @@ +import numpy as np + + +def gelu(x): + return 0.5 * x * (1 + np.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * x**3))) + + +def softmax(x): + exp_x = np.exp(x - np.max(x, axis=-1, keepdims=True)) + return exp_x / np.sum(exp_x, axis=-1, keepdims=True) + + +def layer_norm(x, g, b, eps: float = 1e-5): + mean = np.mean(x, axis=-1, keepdims=True) + variance = np.var(x, axis=-1, keepdims=True) + x = (x - mean) / np.sqrt( + variance + eps + ) # normalize x to have mean=0 and var=1 over last axis + return g * x + b # scale and offset with gamma/beta params + + +def linear(x, w, b): # [m, in], [in, out], [out] -> [m, out] + return x @ w + b + + +def ffn(x, c_fc, c_proj): # [n_seq, n_embd] -> [n_seq, n_embd] + # project up + a = gelu(linear(x, **c_fc)) # [n_seq, n_embd] -> [n_seq, 4*n_embd] + + # project back down + x = linear(a, **c_proj) # [n_seq, 4*n_embd] -> [n_seq, n_embd] + + return x + + +def attention( + q, k, v, mask +): # [n_q, d_k], [n_k, d_k], [n_k, d_v], [n_q, n_k] -> [n_q, d_v] + return softmax(q @ k.T / np.sqrt(q.shape[-1]) + mask) @ v + + +def mha(x, c_attn, c_proj, n_head): # [n_seq, n_embd] -> [n_seq, n_embd] + # qkv projection + x = linear(x, **c_attn) # [n_seq, n_embd] -> [n_seq, 3*n_embd] + + # split into qkv + qkv = np.split(x, 3, axis=-1) # [n_seq, 3*n_embd] -> [3, n_seq, n_embd] + + # split into heads + qkv_heads = list( + map(lambda x: np.split(x, n_head, axis=-1), qkv) + ) # [3, n_seq, n_embd] -> [3, n_head, n_seq, n_embd/n_head] + + # causal mask to hide future inputs from being attended to + causal_mask = (1 - np.tri(x.shape[0], dtype=x.dtype)) * -1e10 # [n_seq, n_seq] + + # perform attention over each head + out_heads = [ + attention(q, k, v, causal_mask) for q, k, v in zip(*qkv_heads) + ] # [3, n_head, n_seq, n_embd/n_head] -> [n_head, n_seq, n_embd/n_head] + + # merge heads + x = np.hstack(out_heads) # [n_head, n_seq, n_embd/n_head] -> [n_seq, n_embd] + + # out projection + x = linear(x, **c_proj) # [n_seq, n_embd] -> [n_seq, n_embd] + + return x + + +def transformer_block( + x, mlp, attn, ln_1, ln_2, n_head +): # [n_seq, n_embd] -> [n_seq, n_embd] + # multi-head causal self attention + x = x + mha( + layer_norm(x, **ln_1), **attn, n_head=n_head + ) # [n_seq, n_embd] -> [n_seq, n_embd] + + # position-wise feed forward network + x = x + ffn(layer_norm(x, **ln_2), **mlp) # [n_seq, n_embd] -> [n_seq, n_embd] + + return x + + +def gpt2(inputs, wte, wpe, blocks, ln_f, n_head): # [n_seq] -> [n_seq, n_vocab] + # token + positional embeddings + x = wte[inputs] + wpe[range(len(inputs))] # [n_seq] -> [n_seq, n_embd] + + # forward pass through n_layer transformer blocks + for block in blocks: + x = transformer_block( + x, **block, n_head=n_head + ) # [n_seq, n_embd] -> [n_seq, n_embd] + + # projection to vocab + x = layer_norm(x, **ln_f) # [n_seq, n_embd] 
-> [n_seq, n_embd] + return x @ wte.T # [n_seq, n_embd] -> [n_seq, n_vocab] + + +def generate(inputs, params, n_head, n_tokens_to_generate): + from tqdm import tqdm + + for _ in tqdm( + range(n_tokens_to_generate), "generating" + ): # auto-regressive decode loop + logits = gpt2(inputs, **params, n_head=n_head) # model forward pass + next_id = np.argmax(logits[-1]) # greedy sampling + inputs.append(int(next_id)) # append prediction to input + + return inputs[len(inputs) - n_tokens_to_generate :] # only return generated ids + + +def main( + prompt: str, + n_tokens_to_generate: int = 40, + model_size: str = "124M", + models_dir: str = "models", +): + from gpt2.utils import load_encoder_hparams_and_params + + # load encoder, hparams, and params from the released open-ai gpt-2 files + encoder, hparams, params = load_encoder_hparams_and_params(model_size, models_dir) + + # encode the input string using the BPE tokenizer + input_ids = encoder.encode(prompt) + + # make sure we are not surpassing the max sequence length of our model + assert len(input_ids) + n_tokens_to_generate < hparams["n_ctx"] + + # generate output ids + output_ids = generate(input_ids, params, hparams["n_head"], n_tokens_to_generate) + + # decode the ids back into a string + output_text = encoder.decode(output_ids) + + return output_text + + +if __name__ == "__main__": + import fire + + fire.Fire(main) diff --git a/docs/source/conceptual/speculative-decoding/gpt2/utils.py b/docs/source/conceptual/speculative-decoding/gpt2/utils.py new file mode 100644 index 00000000..d82685c2 --- /dev/null +++ b/docs/source/conceptual/speculative-decoding/gpt2/utils.py @@ -0,0 +1,81 @@ +import json +import os +import re + +import numpy as np +import requests +import tensorflow as tf +from tqdm import tqdm + +from gpt2.encoder import get_encoder + + +def download_gpt2_files(model_size, model_dir): + assert model_size in ["124M", "355M", "774M", "1558M"] + for filename in [ + "checkpoint", + "encoder.json", + "hparams.json", + "model.ckpt.data-00000-of-00001", + "model.ckpt.index", + "model.ckpt.meta", + "vocab.bpe", + ]: + url = "https://openaipublic.blob.core.windows.net/gpt-2/models" + r = requests.get(f"{url}/{model_size}/{filename}", stream=True) + r.raise_for_status() + + with open(os.path.join(model_dir, filename), "wb") as f: + file_size = int(r.headers["content-length"]) + chunk_size = 1000 + with tqdm( + ncols=100, + desc="Fetching " + filename, + total=file_size, + unit_scale=True, + ) as pbar: + # 1k for chunk_size, since Ethernet packet size is around 1500 bytes + for chunk in r.iter_content(chunk_size=chunk_size): + f.write(chunk) + pbar.update(chunk_size) + + +def load_gpt2_params_from_tf_ckpt(tf_ckpt_path, hparams): + def set_in_nested_dict(d, keys, val): + if not keys: + return val + if keys[0] not in d: + d[keys[0]] = {} + d[keys[0]] = set_in_nested_dict(d[keys[0]], keys[1:], val) + return d + + params = {"blocks": [{} for _ in range(hparams["n_layer"])]} + for name, _ in tf.train.list_variables(tf_ckpt_path): + array = np.squeeze(tf.train.load_variable(tf_ckpt_path, name)) + name = name[len("model/") :] + if name.startswith("h"): + m = re.match(r"h([0-9]+)/(.*)", name) + n = int(m[1]) + sub_name = m[2] + set_in_nested_dict(params["blocks"][n], sub_name.split("/"), array) + else: + set_in_nested_dict(params, name.split("/"), array) + + return params + + +def load_encoder_hparams_and_params(model_size, models_dir): + assert model_size in ["124M", "355M", "774M", "1558M"] + + model_dir = os.path.join(models_dir, model_size) + 
    tf_ckpt_path = tf.train.latest_checkpoint(model_dir)
+    if not tf_ckpt_path:  # download files if necessary
+        os.makedirs(model_dir, exist_ok=True)
+        download_gpt2_files(model_size, model_dir)
+        tf_ckpt_path = tf.train.latest_checkpoint(model_dir)
+
+    encoder = get_encoder(model_size, models_dir)
+    hparams = json.load(open(os.path.join(model_dir, "hparams.json")))
+    params = load_gpt2_params_from_tf_ckpt(tf_ckpt_path, hparams)
+
+    return encoder, hparams, params
diff --git a/docs/source/conceptual/speculative-decoding/helper.py b/docs/source/conceptual/speculative-decoding/helper.py
new file mode 100644
index 00000000..6b1b6dff
--- /dev/null
+++ b/docs/source/conceptual/speculative-decoding/helper.py
@@ -0,0 +1,19 @@
+import numpy as np
+
+
+def max_fn(x):
+    # normalized positive part of x: used to resample a token after a rejection
+    x_max = np.where(x > 0, x, 0)
+    return x_max / np.sum(x_max)
+
+
+def get_sample(p):
+    # sample a token id according to the categorical distribution p:
+    # if p = [.25, .25, .25, .25] every token is equally likely, otherwise
+    # tokens with higher probability are chosen more often
+    return np.random.choice(np.arange(p.shape[-1]), p=p)
+
+
+# np.array inputs are used so numpy broadcasting / elementwise operations work, e.g.:
+# print(max_fn(np.array([1, 2, -2, 5, 5])))
+# print(get_sample(np.array([0.1, 0.3, 0.6])))
diff --git a/docs/source/conceptual/speculative-decoding/img/image.png b/docs/source/conceptual/speculative-decoding/img/image.png
new file mode 100644
index 00000000..87dbd992
Binary files /dev/null and b/docs/source/conceptual/speculative-decoding/img/image.png differ
diff --git a/docs/source/conceptual/speculative-decoding/main.py b/docs/source/conceptual/speculative-decoding/main.py
new file mode 100644
index 00000000..6d421984
--- /dev/null
+++ b/docs/source/conceptual/speculative-decoding/main.py
@@ -0,0 +1,143 @@
+import numpy as np
+import time
+from tqdm import tqdm
+from gpt2.utils import load_encoder_hparams_and_params
+from gpt2.gpt import gpt2, softmax
+from helper import get_sample, max_fn
+import functools
+
+
+def auto_reg_sampling(input_seq, model, N_future):
+    n = len(input_seq)
+    T = len(input_seq) + N_future
+
+    with tqdm(total=N_future, desc="autoreg sampling") as pbar:
+        while n < T:
+            input_seq = np.append(input_seq, get_sample(model(input_seq)[-1]))
+            n += 1
+            pbar.update(1)
+    return input_seq
+
+
+def spec_sampling(input_seq, draft_model, target_model, N_future, k):
+    n = len(input_seq)
+    T = len(input_seq) + N_future
+
+    with tqdm(total=N_future, desc="Spec Sampling") as pbar:
+        while n < T:
+            prev_n = n
+            # step 1: autoregressively draft k tokens with the draft model,
+            # keeping its probabilities p
+            input_draft = input_seq
+            for _ in range(k):
+                p = draft_model(input_draft)  # draft probabilities
+                input_draft = np.append(input_draft, get_sample(p[-1]))
+
+            # step 2: score the whole drafted sequence with the target model in one pass
+            q = target_model(input_draft)
+
+            # step 3: accept/reject each draft token based on the q/p ratio
+            all_generated_tokens_accepted = True
+            for _ in range(k):
+                i = n - 1
+                j = input_draft[i + 1]
+
+                if np.random.random() < min(1, q[i][j] / p[i][j]):  # accepted
+                    input_seq = np.append(input_seq, j)
+                    n += 1
+                else:  # rejected -> resample from the normalized positive part of (q - p)
+                    input_seq = np.append(input_seq, get_sample(max_fn(q[i] - p[i])))
+                    n += 1
+                    all_generated_tokens_accepted = False
+                    break
+
+            # step 4: if every draft token was accepted, take one bonus token from the target
+            if all_generated_tokens_accepted:
+                input_seq = np.append(input_seq, get_sample(q[-1]))
+                n += 1
+
+            # update the progress bar
+            pbar.update(n - prev_n)
+            assert n == len(input_seq), f"{n} {len(input_seq)}"
+    return input_seq
+
+
+def create_model_fn(params, hparams, temperature, eps=1e-10):
+    f = functools.partial(gpt2, **params, n_head=hparams["n_head"])
+
+    def model_fn(inputs):
+        logits = f(inputs)
+        logits = logits / (temperature + eps)  # eps to avoid division by zero
+        probs = softmax(logits)
+        return probs
+
+    return model_fn
+
+
+def main(
+    prompt: str = "Quantization also improves latency and throughput but suffer from perf",
+    n_tokens_to_generate: int = 40,
+    draft_model_size: str = "124M",
+    target_model_size: str = "355M",
+    models_dir: str = "models",
+    K: int = 4,
+    temperature: float = 0.0,
+    seed: int = 123,
+):
+    # seed numpy rng
+    np.random.seed(seed)
+
+    # load encoder, hparams, and params from the released OpenAI GPT-2 files
+    encoder, draft_hparams, draft_params = load_encoder_hparams_and_params(
+        draft_model_size, models_dir
+    )
+    _, target_hparams, target_params = load_encoder_hparams_and_params(
+        target_model_size, models_dir
+    )
+    draft_model = create_model_fn(draft_params, draft_hparams, temperature)
+    target_model = create_model_fn(target_params, target_hparams, temperature)
+
+    # encode inputs
+    input_ids = encoder.encode(prompt)
+
+    def run_sampling_fn(decode_fn, input_seq, **kwargs):
+        start = time.perf_counter()
+        output_ids = decode_fn(input_seq=input_seq, **kwargs)
+        text = encoder.decode(output_ids)
+        elapsed_time = time.perf_counter() - start
+        return text, elapsed_time
+
+    # autoregressive sampling
+    autoregressive_text, autoregressive_time = run_sampling_fn(
+        auto_reg_sampling,
+        input_seq=input_ids,  # prompt token ids
+        model=target_model,
+        N_future=n_tokens_to_generate,  # number of new tokens to generate
+    )
+
+    # speculative sampling
+    speculative_text, speculative_time = run_sampling_fn(
+        spec_sampling,
+        input_seq=input_ids,  # prompt token ids
+        target_model=target_model,
+        draft_model=draft_model,
+        N_future=n_tokens_to_generate,  # number of new tokens to generate
+        k=K,
+    )
+
+    # print results
+    print()
+    print("Autoregressive Decoding:")
+    print("-" * 50)
+    print(f"Time = {autoregressive_time:.2f}s")
+    print(f"Text = {autoregressive_text}")
+    print()
+    print("Speculative Decoding:")
+    print("-" * 50)
+    print(f"Time = {speculative_time:.2f}s")
+    print(f"Text = {speculative_text}")
+
+
+if __name__ == "__main__":
+    import fire
+
+    fire.Fire(main)
diff --git a/docs/source/conceptual/speculative-decoding/readme.md b/docs/source/conceptual/speculative-decoding/readme.md
new file mode 100644
index 00000000..60520f89
--- /dev/null
+++ b/docs/source/conceptual/speculative-decoding/readme.md
@@ -0,0 +1,169 @@
+# 🔄🔍 Speculative Decoding
+This project provides a simple implementation of the [Accelerating Large Language Model Decoding with Speculative Sampling](https://arxiv.org/abs/2302.01318) paper by Chen et al. The implementation uses pure NumPy for a basic GPT-2 model, demonstrating the concept of speculative decoding in a straightforward manner.
+
+Key features of this implementation:
+- Uses NumPy for all computations, making it easy to understand and modify
+- Implements speculative decoding for a GPT-2 model
+- Compares performance between standard autoregressive sampling and speculative sampling
+- Provides a clear example of how speculative decoding can accelerate language model inference
+
+This simple implementation serves as an educational tool to understand the core concepts of speculative decoding and its potential benefits in accelerating large language model inference.
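+
+At its core, the repository runs one short draft-then-verify loop. Below is a condensed, illustrative sketch of what `spec_sampling` in `main.py` does (helper names and bookkeeping are simplified for readability; the full version lives in `main.py`). The sections below explain why this is both fast and exact.
+
+```python
+import numpy as np
+
+def speculative_step(x, draft_model, target_model, k):
+    """One draft-then-verify round. draft_model/target_model map a token
+    sequence to per-position probability distributions (see create_model_fn)."""
+    n = len(x)
+
+    # 1) draft k tokens autoregressively with the cheap model
+    for _ in range(k):
+        p = draft_model(x)                                   # draft probabilities
+        x = np.append(x, np.random.choice(len(p[-1]), p=p[-1]))
+
+    # 2) score the whole drafted sequence with the big model in a single pass
+    q = target_model(x)
+
+    # 3) accept each draft token with probability min(1, q/p); on the first
+    #    rejection, resample from the normalized positive part of (q - p)
+    out = list(x[:n])
+    for t in range(k):
+        i = n - 1 + t
+        j = int(x[i + 1])
+        if np.random.random() < min(1, q[i][j] / p[i][j]):
+            out.append(j)                                    # draft token accepted
+        else:
+            residual = np.maximum(q[i] - p[i], 0)
+            out.append(np.random.choice(len(residual), p=residual / residual.sum()))
+            return out                                       # discard the rest of the draft
+
+    # 4) every draft token accepted: take one bonus token from the target
+    out.append(np.random.choice(len(q[-1]), p=q[-1]))
+    return out
+```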
+
+# Speculative Decoding in a nutshell
+Speculative decoding is a technique designed to accelerate the inference of large language models. Here's a brief overview of how it works:
+
+1. Draft Model: A smaller, faster "draft" model quickly generates a candidate sequence of K tokens.
+
+2. Target Model: The larger, more accurate "target" model processes the entire sequence (input + draft) in a single parallel forward pass.
+
+3. Verification: The target model's output is compared with the draft model's predictions.
+
+4. Accept or Reject:
+   - If the target model agrees with a draft token, it's accepted.
+   - If there's a disagreement, the remaining draft is discarded and a corrected token is sampled from the target model instead.
+
+5. Efficiency Gain: This approach lets the target model verify several tokens per forward pass, reducing the number of expensive sequential decoding steps.
+
+The key advantage is that when the draft model's predictions are mostly correct, the process can be significantly faster than traditional autoregressive decoding. Even when draft tokens are rejected, the output distribution is exactly the one the target model alone would have produced; the only price is some wasted draft work.
+
+This method combines the speed of smaller models with the output quality of larger ones.
+
+
+# 🚀 How to Use
+
+To run the speculative decoding implementation, use the following command:
+
+```bash
+python main.py \
+    --prompt "Quantization also improves latency and throughput but suffer from perf" \
+    --n_tokens_to_generate 60 \
+    --draft_model_size "124M" \
+    --target_model_size "355M" \
+    --K 4 \
+    --temperature 0 # 0 for greedy sampling
+```
+Sample Output:
+
+```
+Autoregressive Decoding
+--------------------------------------------------
+Time = 112.19s
+Text = Quantization also improves latency and throughput but suffer from perfomance issues.
+
+The problem is that the performance of the GPU is not the only thing that matters. The CPU is also important. The CPU is the main bottleneck in the GPU. The CPU is the main bottleneck in the GPU.
+
+The CPU is the main bottleneck in the GPU
+
+Speculative Decoding
+--------------------------------------------------
+Time = 74.12s
+Text = Quantization also improves latency and throughput but suffer from perfomance issues.
+
+The problem is that the performance of the GPU is not the only thing that matters. The CPU is also important. The CPU is the main bottleneck in the GPU. The CPU is the main bottleneck in the GPU.
+
+The CPU is the main bottleneck in the GPU. The CPU
+
+```
+
+
+
+# 🤔💭 Why does this work?
+Most of the time in autoregressive decoding is **not** spent on computation but on memory traffic.
+For every generated token, the model weights and activations have to travel from GPU memory through the caches and registers to the compute units and back. This round trip is slow and expensive.
+![alt text](img/image.png)
+
+Each decoding step repeats this round trip. The idea is therefore to make every trip count: once the weights have been loaded for a forward pass, do as much useful work as possible (score a whole block of candidate tokens) before the results are written back to memory.
+
+> "Now the clever idea is to use a small and cheap draft model to first generate a candidate sequence of K tokens - a 'draft'. Then we feed all of these together through the big model in a batch. This is almost as fast as feeding in just one token, per the above. Then we go from left to right over the logits predicted by the model and sample tokens. Any sample that agrees with the draft allows us to immediately skip forward to the next token. If there is a disagreement then we throw the draft away and eat the cost of doing some throwaway work (sampling the draft and the forward passing for all the later tokens).
+>
+> The reason this works in practice is that most of the time the draft tokens get accepted, because they are easy, so even a much smaller draft model gets them. As these easy tokens get accepted, we skip through those parts in leaps. The hard tokens where the big model disagrees 'fall back' to original speed, but actually a bit slower because of all the extra work."
+>
+> — Andrej Karpathy
+
+
+
+# 🧮💡 Why does this work mathematically?
+
+Speculative decoding's correctness rests on a modified form of rejection sampling: tokens are proposed by the cheap draft distribution and then accepted or rejected so that the final output is distributed exactly according to the target model.
+
+## Mathematical Foundation: [Rejection Sampling](https://en.wikipedia.org/wiki/Rejection_sampling)
+
+Rejection sampling is a Monte Carlo method for generating samples from a target distribution when direct sampling is difficult. It uses a proposal distribution that is easier to sample from (here, the draft model) and accepts or rejects each proposed sample by comparing it against the target distribution (the large model). The rejection sampling theorem guarantees that if we sample from the proposal and accept with probability proportional to the ratio of target to proposal probabilities, the accepted samples follow the target distribution exactly. Applied position by position to the conditional next-token distributions, this is what lets speculative sampling keep the target model's output distribution unchanged even though most tokens are proposed by the draft model.
+
+## ❌🎯 Rejection Sampling Theorem
+
+The theorem states that if we have a target distribution p and a proposal distribution q, sample from q, and accept samples with probability proportional to the ratio of p to q, the accepted samples will follow the target distribution p.
+
+Mathematically, this can be expressed as:
+
+1. Sample y from q(y)
+2. Accept y with probability min(1, p(y) / (M * q(y)))
+
+Where:
+- p(y) is the target distribution
+- q(y) is the proposal distribution
+- M is a constant such that M ≥ max(p(y) / q(y)) for all y
+
+If we follow this procedure, the accepted samples will be distributed according to p(y).
+
+Speculative sampling uses a modified acceptance rule that needs no constant M: a draft token is accepted with probability min(1, target/draft), and on rejection a replacement token is resampled from the normalized positive part of (target - draft). Note that the naming in `main.py` follows Chen et al. and is the reverse of this section: there `p` is the draft distribution and `q` is the target.
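+
+As a concrete illustration, here is a toy example (the numbers are made up, not taken from the models), using the code's convention where `p` is the draft distribution and `q` is the target:
+
+```python
+import numpy as np
+
+p = np.array([0.5, 0.3, 0.2])   # draft (proposal) probabilities over 3 tokens
+q = np.array([0.2, 0.5, 0.3])   # target probabilities over the same tokens
+
+token = 0                        # suppose the draft sampled token 0
+accept_prob = min(1, q[token] / p[token])        # = 0.4
+
+# if rejected, resample from the normalized positive part of (q - p),
+# which is exactly what max_fn in helper.py computes
+residual = np.maximum(q - p, 0)                  # [0. , 0.2, 0.1]
+residual = residual / residual.sum()             # [0. , 2/3, 1/3]
+```
+
+The combined procedure reproduces the target distribution: token 0 is emitted with probability 0.5 * 0.4 = 0.2 = q[0], and the probability mass the draft over-allocated to it is redistributed to the other tokens by the residual distribution.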
+
+## Question: What if we don't have a smaller model from the same family to use as the draft?
+
+Alternative methods include:
+
+- Medusa
+- N-gram
+
+
+### Medusa
+
+
+Medusa is a [simple method](https://arxiv.org/abs/2401.10774) to create many tokens in a single pass using fine-tuned LM heads in addition to your existing model.
+
+
+You can check a few existing fine-tunes for popular models:
+
+- [text-generation-inference/gemma-7b-it-medusa](https://huggingface.co/text-generation-inference/gemma-7b-it-medusa)
+- [text-generation-inference/Mixtral-8x7B-Instruct-v0.1-medusa](https://huggingface.co/text-generation-inference/Mixtral-8x7B-Instruct-v0.1-medusa)
+- [text-generation-inference/Mistral-7B-Instruct-v0.2-medusa](https://huggingface.co/text-generation-inference/Mistral-7B-Instruct-v0.2-medusa)
+
+
+In order to create your own Medusa heads for your own fine-tune, check out the original Medusa repo and the training guide: [../basic_tutorials/train_medusa.md](../basic_tutorials/train_medusa.md)
+
+
+In order to use Medusa models in TGI, simply point to a Medusa-enabled model, and everything will load automatically.
+
+
+### N-gram
+
+
+If you don't have a Medusa model, or don't have the resources to fine-tune, you can try to use `n-gram`.
+N-gram speculation works by looking for matching tokens earlier in the sequence and using them as the speculation for new tokens. For example, if the tokens "np.mean" appear multiple times in the sequence, the model can speculate that the next continuation of the tokens "np." is probably also "mean".
+
+This is an extremely simple method, which works best for code or highly repetitive text. It might not be beneficial if the speculation misses too often.
+
+
+In order to enable n-gram speculation, simply pass
+
+`--speculate 2` in your launcher flags. [Details about the flag](https://huggingface.co/docs/text-generation-inference/basic_tutorials/launcher#speculate)
+
+
+Please refer to [Speculation](https://huggingface.co/docs/text-generation-inference/conceptual/speculation) for more details.
+
+
+# ⚡🚀 Summary of the most common speed-up techniques
+## 🧠💻 Faster Training
+- Device: move computation to the GPU
+- Mixed precision
+- Gradient accumulation
+- Distributed training
+
+## ⚡🤖 Faster Inference
+- Quantization
+- Speculative decoding (this repo 💖)
+- Pruning
+- Caching
+  - attention at inference time: KV cache
+  - in production: prompt cache / exact cache / semantic cache
+- Knowledge distillation
+
+
diff --git a/docs/source/conceptual/speculative-decoding/requirements.txt b/docs/source/conceptual/speculative-decoding/requirements.txt
new file mode 100644
index 00000000..a65dc8f7
--- /dev/null
+++ b/docs/source/conceptual/speculative-decoding/requirements.txt
@@ -0,0 +1,6 @@
+fire==0.6.0
+numpy==2.1.1
+regex==2024.5.15
+Requests==2.32.3
+tensorflow==2.17.0
+tqdm==4.66.4