mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-08 19:04:52 +00:00
fix bachuan issue and update part of the doc
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
This commit is contained in:
parent
cf564ec0e2
commit
bbc1562014
@ -140,12 +140,6 @@ class FlashLlamaAttention(torch.nn.Module):
|
||||
self.hidden_size = config.hidden_size
|
||||
self.head_size = self.hidden_size // self.num_heads
|
||||
|
||||
# Setting defaults for baichuan custom config which doesn't apply them.
|
||||
config.rope_theta = getattr(config, "rope_theta", 10000)
|
||||
config.num_key_value_heads = getattr(
|
||||
config, "num_key_value_heads", config.num_attention_heads
|
||||
)
|
||||
|
||||
self.rotary_emb = rotary_emb
|
||||
|
||||
# `config.attention_multiplier` is used in Granite
|
||||
@ -476,7 +470,11 @@ class FlashLlamaModel(torch.nn.Module):
|
||||
# Skip fp8 quant for first and last layers
|
||||
self.layers = nn.ModuleList()
|
||||
self.cross_attention_layers = getattr(config, "cross_attention_layers", [])
|
||||
|
||||
# Setting defaults for baichuan custom config which doesn't apply them.
|
||||
config.rope_theta = getattr(config, "rope_theta", 10000)
|
||||
config.num_key_value_heads = getattr(
|
||||
config, "num_key_value_heads", config.num_attention_heads
|
||||
)
|
||||
rotary_emb = PositionRotaryEmbedding.static(
|
||||
config=config,
|
||||
dim=config.hidden_size // config.num_attention_heads,
|
||||
|
@ -86,42 +86,9 @@ We recommend always using sharding when running on a multi-card machine.
|
||||
By default, all models run with BF16 precision on Gaudi hardware.
|
||||
|
||||
#### FP8 Precision
|
||||
TGI-Gaudi supports FP8 precision inference, which can significantly reduce memory usage and improve performance for large models. We support model like W8A8 FP compressed-tensors parameters such as [RedHatAI/Mixtral-8x7B-Instruct-v0.1-FP8](https://huggingface.co/RedHatAI/Mixtral-8x7B-Instruct-v0.1-FP8) and AutoFP8 generated model[RedHatAI/Meta-Llama-3-8B-Instruct-FP8](https://huggingface.co/RedHatAI/Meta-Llama-3-8B-Instruct-FP8) .
|
||||
TGI-Gaudi supports FP8 precision inference with [Intel Neural Compressor (INC)](https://docs.habana.ai/en/latest/PyTorch/Inference_on_PyTorch/Inference_Using_FP8.html).
|
||||
|
||||
To run FP8 Inference:
|
||||
|
||||
1. Measure statistics using [Optimum Habana measurement script](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation#running-with-fp8)
|
||||
2. Run the model in TGI with QUANT_CONFIG setting - e.g. `-e QUANT_CONFIG=./quantization_config/maxabs_quant.json`.
|
||||
|
||||
The following commmand example for FP8 inference is based on the assumption that measurement is done via the first step above.
|
||||
|
||||
Example for Llama3.1-70B on 8 cards with FP8 precision:
|
||||
|
||||
```bash
|
||||
model=meta-llama/Meta-Llama-3.1-70B-Instruct
|
||||
hf_token=YOUR_ACCESS_TOKEN
|
||||
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
|
||||
|
||||
docker run -p 8080:80 \
|
||||
--runtime=habana \
|
||||
--cap-add=sys_nice \
|
||||
--ipc=host \
|
||||
-v $volume:/data \
|
||||
-v $PWD/quantization_config:/usr/src/quantization_config \
|
||||
-v $PWD/hqt_output:/usr/src/hqt_output \
|
||||
-e QUANT_CONFIG=./quantization_config/maxabs_quant.json \
|
||||
-e HF_TOKEN=$hf_token \
|
||||
-e MAX_TOTAL_TOKENS=2048 \
|
||||
-e BATCH_BUCKET_SIZE=256 \
|
||||
-e PREFILL_BATCH_BUCKET_SIZE=4 \
|
||||
-e PAD_SEQUENCE_TO_MULTIPLE_OF=64 \
|
||||
ghcr.io/huggingface/text-generation-inference:3.3.4-gaudi \
|
||||
--model-id $model \
|
||||
--sharded true --num-shard 8 \
|
||||
--max-input-tokens 1024 --max-total-tokens 2048 \
|
||||
--max-batch-prefill-tokens 4096 --max-batch-size 256 \
|
||||
--max-waiting-tokens 7 --waiting-served-ratio 1.2 --max-concurrent-requests 512
|
||||
```
|
||||
|
||||
### How to Run Vision-Language Models (VLMs)
|
||||
|
||||
@ -139,8 +106,6 @@ docker run -p 8080:80 \
|
||||
--cap-add=sys_nice \
|
||||
--ipc=host \
|
||||
-v $volume:/data \
|
||||
-e PREFILL_BATCH_BUCKET_SIZE=1 \
|
||||
-e BATCH_BUCKET_SIZE=1 \
|
||||
ghcr.io/huggingface/text-generation-inference:3.3.4-gaudi \
|
||||
--model-id $model \
|
||||
--max-input-tokens 4096 --max-batch-prefill-tokens 16384 \
|
||||
@ -155,7 +120,7 @@ curl -N 127.0.0.1:8080/generate \
|
||||
-H 'Content-Type: application/json'
|
||||
```
|
||||
|
||||
> Note: In Llava-v1.6-Mistral-7B, an image usually accounts for 2000 input tokens. For example, an image of size 512x512 is represented by 2800 tokens. Thus, `max-input-tokens` must be larger than the number of tokens associated with the image. Otherwise the image may be truncated. We set `BASE_IMAGE_TOKENS=2048` as the default image token value. This is the minimum value of `max-input-tokens`. You can override the environment variable `BASE_IMAGE_TOKENS` to change this value. The warmup will generate graphs with input length from `BASE_IMAGE_TOKENS` to `max-input-tokens`. For Llava-v1.6-Mistral-7B, the value of `max-batch-prefill-tokens` is 16384, which is calcualted as follows: `prefill_batch_size` = `max-batch-prefill-tokens` / `max-input-tokens`.
|
||||
> Note: In Llava-v1.6-Mistral-7B, an image usually accounts for 2000 input tokens. For example, an image of size 512x512 is represented by 2800 tokens. Thus, `max-input-tokens` must be larger than the number of tokens associated with the image. Otherwise the image may be truncated. The value of `max-batch-prefill-tokens` is 16384, which is calcualted as follows: `prefill_batch_size` = `max-batch-prefill-tokens` / `max-input-tokens`.
|
||||
|
||||
### How to Benchmark Performance
|
||||
|
||||
@ -184,39 +149,16 @@ docker run \
|
||||
|
||||
Please refer to the [inference-benchmarker README](https://github.com/huggingface/inference-benchmarker) for more details.
|
||||
|
||||
### How to Profile Performance
|
||||
|
||||
To collect performance profiling, you need to set the following environment variables:
|
||||
|
||||
| Name | Value(s) | Default | Description |
|
||||
|--------------------| :--------- | :--------------- | :------------------------------------------------------- |
|
||||
| PROF_WAITSTEP | integer | 0 | Control profile wait steps |
|
||||
| PROF_WARMUPSTEP | integer | 0 | Control profile warmup steps |
|
||||
| PROF_STEP | integer | 0 | Enable/disable profile, control profile active steps |
|
||||
| PROF_PATH | string | /tmp/hpu_profile | Define profile folder |
|
||||
| PROF_RANKS | string | 0 | Comma-separated list of ranks to profile |
|
||||
| PROF_RECORD_SHAPES | True/False | False | Control record_shapes option in the profiler |
|
||||
|
||||
To use these environment variables, add them to your docker run command with the -e flag. For example:
|
||||
|
||||
```bash
|
||||
docker run --runtime=habana --ipc=host --cap-add=sys_nice \
|
||||
-p 8080:80 -v $volume:/data -e HF_TOKEN=$hf_token \
|
||||
-e PROF_WAITSTEP=10 \
|
||||
-e PROF_WARMUPSTEP=10 \
|
||||
-e PROF_STEP=1 \
|
||||
-e PROF_PATH=/tmp/hpu_profile \
|
||||
-e PROF_RANKS=0 \
|
||||
-e PROF_RECORD_SHAPES=True \
|
||||
ghcr.io/huggingface/text-generation-inference:3.3.4-gaudi \
|
||||
--model-id $model
|
||||
```
|
||||
|
||||
## Explanation: Understanding TGI on Gaudi
|
||||
|
||||
### The Warmup Process
|
||||
|
||||
To ensure optimal performance, warmup is performed at the beginning of each server run. This process creates queries with various input shapes based on provided parameters and runs basic TGI operations (prefill, decode, concatenate).
|
||||
Intel Gaudi accelerators perform best when operating on models with fixed tensor shapes. [Intel Gaudi Graph Compiler](https://docs.habana.ai/en/latest/Gaudi_Overview/Intel_Gaudi_Software_Suite.html#graph-compiler-and-runtime)
|
||||
generates optimized binary code that implements the given model topology on Gaudi. In its default configuration, the produced binary code may be highly dependent on input and output tensor shapes, requiring graph recompilation
|
||||
when encountering tensors with different shapes within the same topology. While these binaries efficiently utilize Gaudi, the compilation process itself can introduce noticeable overhead in end-to-end execution.
|
||||
In dynamic inference serving scenarios, minimizing the number of graph compilations and reducing the risk of graph compilation occurring during server runtime is important.
|
||||
|
||||
To ensure optimal performance, warmup is performed at the beginning of each server run. This process creates queries with various input shapes based on provided parameters and runs basic TGI operations (prefill, decode).
|
||||
|
||||
Note: Model warmup can take several minutes, especially for FP8 inference. For faster subsequent runs, refer to [Disk Caching Eviction Policy](https://docs.habana.ai/en/latest/PyTorch/Model_Optimization_PyTorch/Optimization_in_PyTorch_Models.html#disk-caching-eviction-policy).
|
||||
|
||||
@ -229,20 +171,8 @@ Note: Model warmup can take several minutes, especially for FP8 inference. For f
|
||||
#### Batch Size Parameters
|
||||
- For prefill operation, please set `--max-batch-prefill-tokens` as `bs * max-input-tokens`, where `bs` is your expected maximum prefill batch size.
|
||||
- For decode operation, please set `--max-batch-size` as `bs`, where `bs` is your expected maximum decode batch size.
|
||||
- Please note that batch size will be always padded to the nearest multiplication of `BATCH_BUCKET_SIZE` and `PREFILL_BATCH_BUCKET_SIZE`.
|
||||
- Please note that batch size will be always padded to the nearest shapes what has been warmed up. This is done to avoid out of memory issues and to ensure that the graphs are reused efficiently.
|
||||
|
||||
#### Performance and Memory Parameters
|
||||
- `PAD_SEQUENCE_TO_MULTIPLE_OF` determines sizes of input length buckets. Since warmup creates several graphs for each bucket, it's important to adjust that value proportionally to input sequence length. Otherwise, some out of memory issues can be observed.
|
||||
- `ENABLE_HPU_GRAPH` enables HPU graphs usage, which is crucial for performance results. Recommended value to keep is `true`.
|
||||
|
||||
#### Sequence Length Parameters
|
||||
- `--max-input-tokens`: Maximum possible input prompt length (default: 4095)
|
||||
- `--max-total-tokens`: Maximum possible total sequence length (input + output) (default: 4096)
|
||||
|
||||
#### Batch Size Parameters
|
||||
- `--max-batch-prefill-tokens`: Set as `bs * max-input-tokens` where `bs` is your expected maximum prefill batch size
|
||||
- `--max-batch-size`: Set as `bs` where `bs` is your expected maximum decode batch size
|
||||
- Note: Batch sizes are padded to the nearest multiple of `BATCH_BUCKET_SIZE` and `PREFILL_BATCH_BUCKET_SIZE`
|
||||
|
||||
## Reference
|
||||
|
||||
@ -253,39 +183,44 @@ This section contains reference information about the Gaudi backend.
|
||||
Text Generation Inference enables serving optimized models on Gaudi hardware. The following sections list which models (VLMs & LLMs) are supported on Gaudi.
|
||||
|
||||
**Large Language Models (LLMs)**
|
||||
- [Llama2-7B](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf)
|
||||
- [Llama2-70B](https://huggingface.co/meta-llama/Llama-2-70b-chat-hf)
|
||||
- [Llama3-8B](https://huggingface.co/meta-llama/Meta-Llama-3.1-8B-Instruct)
|
||||
- [Llama3-70B](https://huggingface.co/meta-llama/Meta-Llama-3-70B-Instruct)
|
||||
- [LLama3.1-8B](https://huggingface.co/meta-llama/Meta-Llama-3.1-8B-Instruct)
|
||||
- [LLama3.1-70B](https://huggingface.co/meta-llama/Meta-Llama-3.1-70B-Instruct)
|
||||
- [CodeLlama-13B](https://huggingface.co/codellama/CodeLlama-13b-hf)
|
||||
- [Opt-125m](https://huggingface.co/facebook/opt-125m)
|
||||
- [OpenAI-gpt2](https://huggingface.co/openai-community/gpt2)
|
||||
- [Mixtral-8x7B](https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1)
|
||||
- [Mistral-7B](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.3)
|
||||
- [Qwen2-72B](https://huggingface.co/Qwen/Qwen2-72B-Instruct)
|
||||
- [Qwen2-7B](https://huggingface.co/Qwen/Qwen2-7B-Instruct)
|
||||
- [deepseek-R1](https://huggingface.co/deepseek-ai/DeepSeek-R1)
|
||||
- [idefics2](https://huggingface.co/meta-llama/Llama-2-70b-chat-hf)
|
||||
- [Llama](https://huggingface.co/collections/meta-llama/llama-31-669fc079a0c406a149a5738f)
|
||||
- [CodeLlama](https://huggingface.co/codellama/CodeLlama-13b-hf)
|
||||
- [Mixtral](https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1)
|
||||
- [Mistral](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.3)
|
||||
- [Qwen 2]https://huggingface.co/collections/Qwen/qwen2-6659360b33528ced941e557f)
|
||||
- [Qwen 3](https://huggingface.co/collections/Qwen/qwen3-67dd247413f0e2e4f653967f)
|
||||
- [Qwen 3 Moe](https://huggingface.co/collections/Qwen/qwen3-67dd247413f0e2e4f653967f)
|
||||
- [Phi-1.5](https://huggingface.co/microsoft/phi-1_5)
|
||||
- [Gemma-7b](https://huggingface.co/google/gemma-7b-it)
|
||||
- [Starcoder2-3b](https://huggingface.co/bigcode/starcoder2-3b)
|
||||
- [Starcoder2-15b](https://huggingface.co/bigcode/starcoder2-15b)
|
||||
- [Starcoder](https://huggingface.co/bigcode/starcoder)
|
||||
- [falcon-7b-instruct](https://huggingface.co/tiiuae/falcon-7b-instruct)
|
||||
- [Falcon-180B](https://huggingface.co/tiiuae/falcon-180B-chat)
|
||||
- [Phi-3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct)
|
||||
- [PhiMoe](https://huggingface.co/microsoft/Phi-3.5-MoE-instruct)
|
||||
- [Gemma](https://huggingface.co/google/gemma-7b-it)
|
||||
- [Gemma2](https://huggingface.co/collections/google/gemma-2-release-667d6600fd5220e7b967f315)
|
||||
- [Gemma3 Text](https://huggingface.co/collections/google/gemma-3-release-67c6c6f89c4f76621268bb6d)
|
||||
- [Granite](https://huggingface.co/ibm-granite/granite-3.0-8b-instruct)
|
||||
- [Cohere](https://huggingface.co/CohereForAI/c4ai-command-r-plus)
|
||||
- [dbrx](https://huggingface.co/databricks/dbrx-instruct)
|
||||
- [Starcoder2](https://huggingface.co/bigcode/starcoder2-3b)
|
||||
- [Falcon](https://huggingface.co/tiiuae/falcon-7b-instruct)
|
||||
- [GPT-2](https://huggingface.co/openai-community/gpt2)
|
||||
- [gpt-j-6b](https://huggingface.co/EleutherAI/gpt-j-6b)
|
||||
- [gpt-bigcode](https://huggingface.co/bigcode/gpt_bigcode-santacoder)
|
||||
- [Baichuan](https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat)
|
||||
|
||||
|
||||
**Vision-Language Models (VLMs)**
|
||||
- [LLaVA-v1.6-Mistral-7B](https://huggingface.co/llava-hf/llava-v1.6-mistral-7b-hf)
|
||||
- [Llava Next (1.6)](https://huggingface.co/llava-hf/llava-v1.6-vicuna-13b-hf)
|
||||
- [Mllama (Multimodal Llama from Meta)](https://huggingface.co/meta-llama/Llama-3.2-11B-Vision-Instruct)
|
||||
- [Idefics](https://huggingface.co/HuggingFaceM4/idefics-9b)
|
||||
- [Idefics 2](https://huggingface.co/HuggingFaceM4/idefics2-8b)
|
||||
- [Idefics 2.5](https://huggingface.co/HuggingFaceM4/Idefics3-8B-Llama3)
|
||||
- [Qwen2-VL-2B-Instruct](https://huggingface.co/Qwen/Qwen2-VL-2B-Instruct)
|
||||
- [Qwen/Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct)
|
||||
- [idefics 2](https://huggingface.co/HuggingFaceM4/idefics2-8b)
|
||||
- [idefics 3](https://huggingface.co/HuggingFaceM4/Idefics3-8B-Llama3)
|
||||
- [PaliGemma](https://huggingface.co/google/paligemma-3b-pt-224)
|
||||
- [Llama4](https://huggingface.co/collections/meta-llama/llama-4-67f0c30d9fe03840bc9d0164)
|
||||
- [Gemma3](https://huggingface.co/collections/google/gemma-3-release-67c6c6f89c4f76621268bb6d)
|
||||
- [Qwen 2.5 VL](https://huggingface.co/collections/Qwen/qwen25-vl-6795ffac22b334a837c0f9a5)
|
||||
- [Qwen 2 VL](https://huggingface.co/collections/Qwen/qwen2-vl-66cee7455501d7126940800d)
|
||||
|
||||
We also support on a best effort basis models with different parameters count that use the same model architecture but those models were not tested. For example, the gaudi backend supports `meta-llama/Llama-3.2-1B` as the architecture is the standard llama3 architecture. If you have an issue with a model, please open an issue on the [Gaudi backend repository](https://github.com/huggingface/text-generation-inference/issues).
|
||||
If you have an issue with a model, please open an issue on the [Gaudi backend repository](https://github.com/huggingface/text-generation-inference/issues).
|
||||
|
||||
### Environment Variables
|
||||
|
||||
@ -293,16 +228,10 @@ The following table contains the environment variables that can be used to confi
|
||||
|
||||
| Name | Value(s) | Default | Description | Usage |
|
||||
|-----------------------------| :--------- | :--------------- | :------------------------------------------------------------------------------------------------------------------------------- | :--------------------------- |
|
||||
| ENABLE_HPU_GRAPH | True/False | True | Enable hpu graph or not | add -e in docker run command |
|
||||
| LIMIT_HPU_GRAPH | True/False | True | Skip HPU graph usage for prefill to save memory, set to `True` for large sequence/decoding lengths(e.g. 300/212) | add -e in docker run command |
|
||||
| BATCH_BUCKET_SIZE | integer | 8 | Batch size for decode operation will be rounded to the nearest multiple of this number. This limits the number of cached graphs | add -e in docker run command |
|
||||
| PREFILL_BATCH_BUCKET_SIZE | integer | 4 | Batch size for prefill operation will be rounded to the nearest multiple of this number. This limits the number of cached graphs | add -e in docker run command |
|
||||
| PAD_SEQUENCE_TO_MULTIPLE_OF | integer | 128 | For prefill operation, sequences will be padded to a multiple of provided value. | add -e in docker run command |
|
||||
| SKIP_TOKENIZER_IN_TGI | True/False | False | Skip tokenizer for input/output processing | add -e in docker run command |
|
||||
| WARMUP_ENABLED | True/False | True | Enable warmup during server initialization to recompile all graphs. This can increase TGI setup time. | add -e in docker run command |
|
||||
| QUEUE_THRESHOLD_MS | integer | 120 | Controls the threshold beyond which the request are considered overdue and handled with priority. Shorter requests are prioritized otherwise. | add -e in docker run command |
|
||||
| USE_FLASH_ATTENTION | True/False | True | Whether to enable Habana Flash Attention, provided that the model supports it. Please refer to https://docs.habana.ai/en/latest/PyTorch/Model_Optimization_PyTorch/Optimization_in_PyTorch_Models.html?highlight=fusedsdpa#using-fused-scaled-dot-product-attention-fusedsdpa | add -e in docker run command |
|
||||
| FLASH_ATTENTION_RECOMPUTE | True/False | True | Whether to enable Habana Flash Attention in recompute mode on first token generation. | add -e in docker run command |
|
||||
| VLLM_SKIP_WARMUP | True/False | False | Skip graph warmup during server initialization which is not recommended, but could be used for debug. | add -e in docker run command |
|
||||
|
||||
|
||||
## Contributing
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user