Commit 33043f8255 (parent a13a850542): fix: improve text and typos

@@ -1,14 +1,14 @@
# Train Medusa

-This tutorial will show you how to train a Medusa model on a dataset of your choice.
+This tutorial will show you how to train a Medusa model on a dataset of your choice. Please check out the [speculation documentation](../conceptual/speculation.md) for more information on how Medusa works and speculation in general.

-Training a Medusa heads can greatly improve the generation performance. Since the model is able to predict multiple tokens at once it can generate text much faster than the original model.
+## What are the benefits of training a Medusa model?

-> Medusa adds extra "heads" to LLMs to predict multiple future tokens simultaneously. When augmenting a model with Medusa, the original model stays untouched, and only the new heads are fine-tuned during training
+Training Medusa heads can greatly improve the speed of generation. Medusa adds extra "heads" to LLMs to predict multiple future tokens simultaneously. When augmenting a model with Medusa, the original model stays untouched, and only the new heads are fine-tuned during training.
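As a concrete sketch of the payoff (not part of the original tutorial): once a Medusa-augmented model is served with TGI, the launcher's `--speculate` flag in recent TGI versions controls how many speculative tokens are attempted per step; the model id below is a placeholder.

```bash
# Hypothetical example: serve a Medusa-augmented model and attempt
# up to 3 speculative tokens per decoding step.
# Requires a TGI version with speculation support; model id is a placeholder.
docker run --gpus all --shm-size 1g -p 8080:80 -v $PWD/data:/data \
    ghcr.io/huggingface/text-generation-inference:latest \
    --model-id <your-medusa-model> --speculate 3
```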

-One of the most important things is to have a good dataset (with similar data to what will be used in production) because Medusa has a much higher hitrate when the generation is in-domain.
+One of the most important things is to have a good dataset (with similar data to what will be used in production) because Medusa has a much higher hit-rate when the generation is in-domain.

-> If you train Medusa on a dataset that is very different from the one you will use in production, the performance of the model will be much worse since very few of the predictions will be correct.
+If you train Medusa on a dataset that is very different from the one you will use in production, then the model will not be able to predict future tokens accurately, and the speedup will be minimal or non-existent.

## Self-distillation (Generating data for training)

@@ -22,24 +22,14 @@ We'll use this output to help train the medusa heads to predict the `n+1`, `n+2`

The original implementation of Medusa is available at [https://github.com/FasterDecoding/Medusa](https://github.com/FasterDecoding/Medusa) and we'll follow a very similar process to train the model, as described in the original repository.

-### Installation
-
-First, you need to install the Medusa package. You can do this by cloning the repository and installing it with pip.
-
-There are helpful training scripts available in the `scripts` directory of the repository, that we'll use to train the model.
-
-```bash
-git clone https://github.com/FasterDecoding/Medusa.git
-cd Medusa
-pip install -e .
-```
-
-### Training Tools
+### Getting Started

There are two methods for training the model:

-- a forked version of `axlotl` that supports Medusa
- `torchrun` that is a wrapper around `torch.distributed.launch`
+- a forked version of `axolotl` that supports Medusa

In this tutorial we'll use `torchrun` to train the model, as it is the most straightforward approach, but similar steps can be followed with the `axolotl` fork if you prefer.
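Since the full training command only appears further down in the diff, here is a generic single-node `torchrun` launch for orientation; the script name is a placeholder, not the Medusa entry point.

```bash
# torchrun spawns one worker process per GPU on this node.
# train.py stands in for whatever training script you launch.
torchrun --nproc_per_node=4 train.py
```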

### Training with `torchrun`

@@ -54,7 +44,7 @@ uv venv -p 3.10
source .venv/bin/activate
```

-Now lets clone the orignal `Medusa` repository and install the library.
+Now let's clone the original `Medusa` repository and install the library.

```bash
git clone https://github.com/FasterDecoding/Medusa.git
@@ -62,7 +52,7 @@ cd Medusa
pip install -e .
```
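A quick optional check that the editable install worked, assuming the repository exposes an importable `medusa` package:

```bash
python -c "import medusa; print('Medusa import OK')"
```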

-Next we'll need some data to train on. We can use the `create_data.py` script to generate the data.
+Next we'll need some data to train on. We can use the `ShareGPT_Vicuna_unfiltered` dataset that is available on the Hugging Face Hub.

```bash
apt install git-lfs
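# The diff truncates this block; presumably git-lfs is then initialized and
# the dataset repository cloned. The exact repo path below is an assumption,
# not taken from the original tutorial.
git lfs install
git clone https://huggingface.co/datasets/Aeala/ShareGPT_Vicuna_unfiltered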

@@ -109,7 +99,7 @@ First make sure you have an instance of TGI running with the model you want to use
model=HuggingFaceH4/zephyr-7b-beta
volume=/home/ubuntu/.cache/huggingface/hub/

-docker run --gpus all --shm-size 1g -e HUGGING_FACE_HUB_TOKEN=$token -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:latest --model-id $model
+docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:latest --model-id $model
```
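Before generating data, it's worth confirming the server responds; a minimal request against TGI's standard `/generate` route (the prompt text is arbitrary):

```bash
curl 127.0.0.1:8080/generate \
    -X POST \
    -d '{"inputs":"What is Deep Learning?","parameters":{"max_new_tokens":20}}' \
    -H 'Content-Type: application/json'
```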

Now we can generate the data using the `create_data.py` script.

@@ -125,7 +115,7 @@ At this point our terminal should look like this:
<div class="flex justify-center">
    <img
    src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/tgi/medusa-train-large.gif"
-    width="400"
+    width="550"
    />
</div>

@@ -165,7 +155,7 @@ WANDB_MODE=offline torchrun --nproc_per_node=4 medusa/train/train_legacy.py \
<div class="flex justify-center">
    <img
    src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/tgi/medusa-train-heads-large.gif"
-    width="400"
+    width="550"
    />
</div>
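Once training finishes, the resulting Medusa model can be served the same way as the base model earlier in this tutorial. As a sketch, with a hypothetical checkpoint directory mounted under `/data`:

```bash
# Hypothetical: the checkpoint path depends on where your training run
# writes its output; adjust --model-id accordingly.
docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data \
    ghcr.io/huggingface/text-generation-inference:latest \
    --model-id /data/zephyr-7b-beta-medusa
```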