mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-04-26 12:32:10 +00:00
Merge branch 'habana-main' into 2.3.0
commit 8686a0fc6d
@@ -73,6 +73,7 @@ RUN cd server && \
 pip install -r requirements.txt && \
 bash ./dill-0.3.8-patch.sh && \
 pip install git+https://github.com/HabanaAI/DeepSpeed.git@1.17.0 && \
+BUILD_CUDA_EXT=0 pip install git+https://github.com/AutoGPTQ/AutoGPTQ.git@097dd04e --no-build-isolation && \
 pip install . --no-cache-dir

 # Install benchmarker
README.md (54 changed lines)
@@ -40,36 +40,40 @@ To use [🤗 text-generation-inference](https://github.com/huggingface/text-gene
 > ```bash
 > docker build -t tgi_gaudi .
 > ```
-2. Launch a local server instance:
+2. Use one of the following snippets to launch a local server instance:
+> [!NOTE]
+> For gated models such as [meta-llama/Llama-2-7b-hf](https://huggingface.co/meta-llama/Llama-2-7b-hf), you will have to pass `-e HF_TOKEN=<token>` to the `docker run` commands below with a valid Hugging Face Hub read token.
 
 i. On 1 Gaudi card
 ```bash
 model=meta-llama/Llama-2-7b-hf
 hf_token=YOUR_ACCESS_TOKEN
 volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
 
-docker run -p 8080:80 -v $volume:/data --runtime=habana -e HABANA_VISIBLE_DEVICES=all -e OMPI_MCA_btl_vader_single_copy_mechanism=none -e HUGGING_FACE_HUB_TOKEN=$hf_token -e ENABLE_HPU_GRAPH=true -e LIMIT_HPU_GRAPH=true -e USE_FLASH_ATTENTION=true -e FLASH_ATTENTION_RECOMPUTE=true --cap-add=sys_nice --ipc=host ghcr.io/huggingface/tgi-gaudi:2.0.5 --model-id $model --max-input-tokens 1024 --max-total-tokens 2048
-```
-> For gated models such as [StarCoder](https://huggingface.co/bigcode/starcoder), you will have to pass `-e HUGGING_FACE_HUB_TOKEN=<token>` to the `docker run` command above with a valid Hugging Face Hub read token.
-
-ii. On 1 Gaudi card using PyTorch eager mode with torch compile:
-```bash
-model=meta-llama/Llama-2-7b-hf
-hf_token=YOUR_ACCESS_TOKEN
-volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
-
-docker run -p 8080:80 -v $volume:/data --runtime=habana -e HABANA_VISIBLE_DEVICES=all -e PT_HPU_LAZY_MODE=0 -e OMPI_MCA_btl_vader_single_copy_mechanism=none -e HUGGING_FACE_HUB_TOKEN=$hf_token --cap-add=sys_nice --ipc=host ghcr.io/huggingface/tgi-gaudi:2.0.5 --model-id $model --max-input-tokens 1024 --max-total-tokens 2048
+docker run -p 8080:80 -v $volume:/data --runtime=habana -e HABANA_VISIBLE_DEVICES=all \
+-e OMPI_MCA_btl_vader_single_copy_mechanism=none -e HF_TOKEN=$hf_token \
+-e ENABLE_HPU_GRAPH=true -e LIMIT_HPU_GRAPH=true -e USE_FLASH_ATTENTION=true \
+-e FLASH_ATTENTION_RECOMPUTE=true --cap-add=sys_nice --ipc=host \
+ghcr.io/huggingface/tgi-gaudi:2.0.5 --model-id $model --max-input-tokens 1024 \
+--max-total-tokens 2048
 ```
 
-iii. On 8 Gaudi cards:
+ii. On 8 Gaudi cards:
 ```bash
 model=meta-llama/Llama-2-70b-hf
 hf_token=YOUR_ACCESS_TOKEN
 volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
 
-docker run -p 8080:80 -v $volume:/data --runtime=habana -e PT_HPU_ENABLE_LAZY_COLLECTIVES=true -e HABANA_VISIBLE_DEVICES=all -e OMPI_MCA_btl_vader_single_copy_mechanism=none -e HUGGING_FACE_HUB_TOKEN=$hf_token -e ENABLE_HPU_GRAPH=true -e LIMIT_HPU_GRAPH=true -e USE_FLASH_ATTENTION=true -e FLASH_ATTENTION_RECOMPUTE=true --cap-add=sys_nice --ipc=host ghcr.io/huggingface/tgi-gaudi:2.0.5 --model-id $model --sharded true --num-shard 8 --max-input-tokens 1024 --max-total-tokens 2048
+docker run -p 8080:80 -v $volume:/data --runtime=habana -e PT_HPU_ENABLE_LAZY_COLLECTIVES=true \
+-e HABANA_VISIBLE_DEVICES=all -e OMPI_MCA_btl_vader_single_copy_mechanism=none \
+-e HF_TOKEN=$hf_token -e ENABLE_HPU_GRAPH=true -e LIMIT_HPU_GRAPH=true \
+-e USE_FLASH_ATTENTION=true -e FLASH_ATTENTION_RECOMPUTE=true --cap-add=sys_nice \
+--ipc=host ghcr.io/huggingface/tgi-gaudi:2.0.5 --model-id $model --sharded true \
+--num-shard 8 --max-input-tokens 1024 --max-total-tokens 2048
 ```
 
-3. You can then send a simple request:
+3. Wait for the TGI-Gaudi server to come online. You will see something like so:
+> 2024-05-22T19:31:48.302239Z INFO text_generation_router: router/src/main.rs:378: Connected
+You can then send a simple request to the server from a separate terminal:
 ```bash
 curl 127.0.0.1:8080/generate \
 -X POST \
@@ -124,7 +128,7 @@ docker run -p 8080:80 \
 --runtime=habana \
 -v $volume:/data \
 -e HABANA_VISIBLE_DEVICES=all \
--e HUGGING_FACE_HUB_TOKEN=$hf_token \
+-e HF_TOKEN=$hf_token \
 -e OMPI_MCA_btl_vader_single_copy_mechanism=none \
 -e TEXT_GENERATION_SERVER_IGNORE_EOS_TOKEN=true \
 -e MAX_TOTAL_TOKENS=2048 \
@@ -155,7 +159,7 @@ docker run -p 8080:80 \
 --runtime=habana \
 -v $volume:/data \
 -e HABANA_VISIBLE_DEVICES=all \
--e HUGGING_FACE_HUB_TOKEN=$hf_token \
+-e HF_TOKEN=$hf_token \
 -e OMPI_MCA_btl_vader_single_copy_mechanism=none \
 -e TEXT_GENERATION_SERVER_IGNORE_EOS_TOKEN=true \
 -e PT_HPU_ENABLE_LAZY_COLLECTIVES=true \
@@ -188,7 +192,7 @@ docker run -p 8080:80 \
 --runtime=habana \
 -v $volume:/data \
 -e HABANA_VISIBLE_DEVICES=all \
--e HUGGING_FACE_HUB_TOKEN=$hf_token \
+-e HF_TOKEN=$hf_token \
 -e OMPI_MCA_btl_vader_single_copy_mechanism=none \
 -e TEXT_GENERATION_SERVER_IGNORE_EOS_TOKEN=true \
 -e MAX_TOTAL_TOKENS=2048 \
@@ -219,7 +223,7 @@ docker run -p 8080:80 \
 --runtime=habana \
 -v $volume:/data \
 -e HABANA_VISIBLE_DEVICES=all \
--e HUGGING_FACE_HUB_TOKEN=$hf_token \
+-e HF_TOKEN=$hf_token \
 -e OMPI_MCA_btl_vader_single_copy_mechanism=none \
 -e TEXT_GENERATION_SERVER_IGNORE_EOS_TOKEN=true \
 -e PT_HPU_ENABLE_LAZY_COLLECTIVES=true \
@@ -281,7 +285,7 @@ curl -N 127.0.0.1:8080/generate_stream \
 
 ## Running TGI with FP8 Precision
 
-TGI-Gaudi supports FP8 precision inference with INC (Intel Neural Compressor) and HQT (Habana Quantization Toolkit). FP8 inference can be run by setting QUANT_CONFIG environment variable in the docker command. From TGI-Gaudi 2.0.4 release, INC is used by default for quantization. HQT will be removed in future releases. To use HQT, disable INC by setting `-e USE_INC=0` in docker command.
+TGI-Gaudi supports FP8 precision inference with [Intel Neural Compressor (INC)](https://docs.habana.ai/en/latest/PyTorch/Inference_on_PyTorch/Inference_Using_FP8.html). FP8 inference can be run by setting QUANT_CONFIG environment variable in the docker command.
 
 To run FP8 Inference:
 
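For orientation, a complete single-card FP8 launch assembled from the flags shown in the surrounding hunks would look roughly like the sketch below. It is illustrative only: it assumes the `maxabs_quant.json` quantization config has already been generated and is visible inside the container at the path referenced by `QUANT_CONFIG`.

```bash
# Hedged sketch of a 1-card FP8 launch; all flags are taken from the README's
# docker commands in this diff, the config path is assumed to exist in the container.
model=meta-llama/Llama-2-7b-hf
hf_token=YOUR_ACCESS_TOKEN
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
docker run -p 8080:80 -v $volume:/data --runtime=habana \
  -v $PWD/hqt_output:/usr/src/hqt_output \
  -e QUANT_CONFIG=./quantization_config/maxabs_quant.json \
  -e HABANA_VISIBLE_DEVICES=all -e HF_TOKEN=$hf_token \
  -e OMPI_MCA_btl_vader_single_copy_mechanism=none \
  --cap-add=sys_nice --ipc=host \
  ghcr.io/huggingface/tgi-gaudi:2.0.5 --model-id $model \
  --max-input-tokens 1024 --max-total-tokens 2048
```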
@@ -303,7 +307,7 @@ docker run -p 8080:80 \
 -v $PWD/hqt_output:/usr/src/hqt_output \
 -e QUANT_CONFIG=./quantization_config/maxabs_quant.json \
 -e HABANA_VISIBLE_DEVICES=all \
--e HUGGING_FACE_HUB_TOKEN=$hf_token \
+-e HF_TOKEN=$hf_token \
 -e OMPI_MCA_btl_vader_single_copy_mechanism=none \
 -e TEXT_GENERATION_SERVER_IGNORE_EOS_TOKEN=true \
 -e MAX_TOTAL_TOKENS=2048 \
@@ -337,7 +341,7 @@ docker run -p 8080:80 \
 -v $PWD/hqt_output:/usr/src/hqt_output \
 -e QUANT_CONFIG=./quantization_config/maxabs_quant.json \
 -e HABANA_VISIBLE_DEVICES=all \
--e HUGGING_FACE_HUB_TOKEN=$hf_token \
+-e HF_TOKEN=$hf_token \
 -e OMPI_MCA_btl_vader_single_copy_mechanism=none \
 -e TEXT_GENERATION_SERVER_IGNORE_EOS_TOKEN=true \
 -e PT_HPU_ENABLE_LAZY_COLLECTIVES=true \
@@ -374,7 +378,7 @@ docker run -p 8080:80 \
 -v $PWD/hqt_output:/usr/src/hqt_output \
 -e QUANT_CONFIG=./quantization_config/maxabs_quant.json \
 -e HABANA_VISIBLE_DEVICES=all \
--e HUGGING_FACE_HUB_TOKEN=$hf_token \
+-e HF_TOKEN=$hf_token \
 -e OMPI_MCA_btl_vader_single_copy_mechanism=none \
 -e TEXT_GENERATION_SERVER_IGNORE_EOS_TOKEN=true \
 -e MAX_TOTAL_TOKENS=2048 \
@@ -408,7 +412,7 @@ docker run -p 8080:80 \
 -v $PWD/hqt_output:/usr/src/hqt_output \
 -e QUANT_CONFIG=./quantization_config/maxabs_quant.json \
 -e HABANA_VISIBLE_DEVICES=all \
--e HUGGING_FACE_HUB_TOKEN=$hf_token \
+-e HF_TOKEN=$hf_token \
 -e OMPI_MCA_btl_vader_single_copy_mechanism=none \
 -e TEXT_GENERATION_SERVER_IGNORE_EOS_TOKEN=true \
 -e PT_HPU_ENABLE_LAZY_COLLECTIVES=true \
@@ -7,7 +7,7 @@ This example provide a simple way of usage of `tgi-gaudi` with continuous batchi
 ### Install
 
 ```
-pip install -r requirements
+pip install -r requirements.txt
 ```
 
 ### Setup TGI server
@@ -1,6 +1,7 @@
 # Copyright (C) 2024 Habana Labs, Ltd. an Intel Company.
 
 import os
+import habana_frameworks.torch as htorch
 
 quant_config = os.getenv("QUANT_CONFIG", "")
 is_quantization_enabled = quant_config != ""
@@ -10,18 +11,35 @@ if is_quantization_enabled:
     os.environ.setdefault("USE_DEFAULT_QUANT_PARAM", "true")
     os.environ.setdefault("UPDATE_GRAPH_OUTPUT_MME", "false")
     os.environ.setdefault("ENABLE_CALC_DYNAMIC_RANGE", "false")
-    os.environ.setdefault(
-        "UPDATE_MME_OUTPUT_PRECISION_FILTER", "v_proj,matmul_av")
+    os.environ.setdefault("UPDATE_MME_OUTPUT_PRECISION_FILTER", "v_proj,matmul_av")
     os.environ.setdefault("EXPERIMENTAL_WEIGHT_SHARING", "FALSE")
 
 
+def patch_scoped_linear_all_reduce(model):
+    from deepspeed.module_inject.layers import LinearAllreduce
+    from optimum.habana.transformers.models.modeling_all_models import ScopedLinearAllReduce
+
+    for name, module in model.named_children():
+        if type(module) is LinearAllreduce:
+            SL = ScopedLinearAllReduce(mod=module)
+            setattr(model, name, SL)
+        patch_scoped_linear_all_reduce(module)
+
+
+def setup_quantization(model):
+    if is_quantization_enabled:
+        htorch.core.quantization._mark_params_as_const(model)
+        htorch.core.quantization._check_params_as_const(model)
+        htorch.core.hpu_initialize(model)
+    return model
+
+
 def prepare_model_for_quantization(model):
     if is_quantization_enabled:
-        if os.getenv("USE_INC", "1") != "0":
-            from neural_compressor.torch.quantization import FP8Config, convert
-            config = FP8Config.from_json_file(quant_config)
-            model = convert(model, config)
-        else:
-            import habana_quantization_toolkit
-            habana_quantization_toolkit.prep_model(model)
-        return model
+        if model.config.model_type in ["llama", "falcon", "qwen2", "starcoder2", "gemma"]:
+            patch_scoped_linear_all_reduce(model)
+        from neural_compressor.torch.quantization import FP8Config, convert
+        config = FP8Config.from_json_file(quant_config)
+        model = convert(model, config)
+    return model
@@ -665,7 +665,7 @@ class CausalLM(Model):
             model = self.get_deepspeed_model(
                 model_id, dtype, revision
             )
-            model = self.prepare_model_for_quantization(model)
+            model = hq_env.prepare_model_for_quantization(model)
         else:
             get_repo_root(model_id)
 
@@ -684,12 +684,15 @@ class CausalLM(Model):
                 trust_remote_code=trust_remote_code,
                 **model_kwargs
             )
-            model = self.prepare_model_for_quantization(model)
+            model = hq_env.prepare_model_for_quantization(model)
             model = model.eval().to(device)
 
         self.enable_hpu_graph = os.getenv("ENABLE_HPU_GRAPH", "true").lower() == "true" and LAZY_MODE == 1
         self.limit_hpu_graph = os.getenv("LIMIT_HPU_GRAPH", "false").lower() == "true"
-        model = remove_kv_cache_from_output(model)
+        if model.config.model_type not in ["gpt_bigcode"]: # gpt_bigcode/starcoderbase-3b skips remove_kv_cache_from_output()
+            model = remove_kv_cache_from_output(model)
+
         if self.enable_hpu_graph:
             from habana_frameworks.torch.hpu import wrap_in_hpu_graph
             model = wrap_in_hpu_graph(model, disable_tensor_cache=True)
@@ -700,7 +703,7 @@ class CausalLM(Model):
                 "TORCH COMPILE", f'Torch compiling of model')
             model.model = torch.compile(model.model, backend="hpu_backend", options={"keep_input_mutations": True})
 
-        model = self.setup_quantization(model)
+        model = hq_env.setup_quantization(model)
 
         if model.config.model_type not in MODELS_OPTIMIZED_WITH_STATIC_SHAPES:
             raise ValueError(f"Model type {model.config.model_type} is not supported!")
@@ -727,10 +730,13 @@ class CausalLM(Model):
             "return_dict": True,
         }
 
-        if model.config.model_type in ["llama", "mistral", "starcoder2", "qwen2"]:
-
-        if model.config.model_type in ["llama", "mistral", "qwen2"]:
+        if model.config.model_type in ["llama", "mistral", "starcoder2", "qwen2", "falcon", "gemma"]:
+
+            if model.config.model_type not in ["falcon"]:
                 self.kwargs["attn_softmax_bf16"] = True
+
+            if model.config.model_type not in ["gemma"]:
                 self.kwargs["trim_logits"] = True
 
         if os.getenv("USE_FLASH_ATTENTION", "false").lower() == "true":
@@ -832,29 +838,6 @@ class CausalLM(Model):
                 'type': rope_scaling, 'factor': float(rope_factor)
             }
 
-    def setup_quantization(self, model):
-        if hq_env.is_quantization_enabled:
-            htorch.core.quantization._mark_params_as_const(model)
-            htorch.core.quantization._check_params_as_const(model)
-            htorch.core.hpu_initialize(model)
-        return model
-
-    def prepare_model_for_quantization(self, model):
-        if hq_env.is_quantization_enabled:
-            if model.config.model_type == "llama":
-                self.patch_scoped_linear_all_reduce(model)
-            model = hq_env.prepare_model_for_quantization(model)
-        return model
-
-    def patch_scoped_linear_all_reduce(self, model):
-        from deepspeed.module_inject.layers import LinearAllreduce
-        from optimum.habana.transformers.models.modeling_all_models import ScopedLinearAllReduce
-        for name, module in model.named_children():
-            if type(module) is LinearAllreduce:
-                SL = ScopedLinearAllReduce(mod=module)
-                setattr(model, name, SL)
-            self.patch_scoped_linear_all_reduce(module)
-
     @property
     def batch_type(self) -> Type[CausalLMBatch]:
         return CausalLMBatch
@@ -903,7 +886,7 @@ class CausalLM(Model):
 
         kwargs.update(self.kwargs)
 
-        if past_key_values is not None:
+        if past_key_values is not None and self.model.config.model_type not in ["gpt_bigcode"]:
             return self.model.forward(**kwargs)
         else:
             outputs = self.model.forward(**kwargs)
@@ -988,7 +971,7 @@ class CausalLM(Model):
         else:
             batch.position_ids += 1
         # Update past key values
-        if prefill:
+        if prefill or self.model.config.model_type in ["gpt_bigcode"]:
             batch.past_key_values = past
 
         htorch.core.mark_step()
@@ -1032,7 +1015,7 @@ class CausalLM(Model):
             else:
                 token_idx = torch.tensor(batch.attention_mask.shape[-1] - batch.right_padding).to(self.device)
                 input_ids = torch.index_select(batch.input_ids, 1, token_idx - 1)
-            batch.logits = self.forward(
+            logits = self.forward(
                 input_ids,
                 batch.attention_mask,
                 batch.position_ids,
@@ -1040,6 +1023,10 @@ class CausalLM(Model):
                 batch.past_key_values,
                 bypass_hpu_graph=prefill and self.limit_hpu_graph if self.enable_hpu_graph else None,
             )
+            if self.model.config.model_type in ["gpt_bigcode"]:
+                batch.logits, batch.past = logits
+            else:
+                batch.logits = logits
 
             htorch.core.mark_step()
 
@@ -576,7 +576,7 @@ class VlmCausalLM(Model):
             model = self.get_deepspeed_model(
                 model_class, model_id, dtype, revision
             )
-            model = self.prepare_model_for_quantization(model)
+            model = hq_env.prepare_model_for_quantization(model)
         else:
             get_repo_root(model_id)
 
@@ -595,7 +595,7 @@ class VlmCausalLM(Model):
                 trust_remote_code=trust_remote_code,
                 **model_kwargs
             )
-            model = self.prepare_model_for_quantization(model)
+            model = hq_env.prepare_model_for_quantization(model)
             model = model.eval().to(device)
 
         self.enable_hpu_graph = os.getenv("ENABLE_HPU_GRAPH", "true").lower() == "true" and LAZY_MODE == 1
@@ -611,7 +611,7 @@ class VlmCausalLM(Model):
                 "TORCH COMPILE", f'Torch compiling of model')
             model.model = torch.compile(model.model, backend="hpu_backend", options={"keep_input_mutations": True})
 
-        model = self.setup_quantization(model)
+        model = hq_env.setup_quantization(model)
 
         if model.config.model_type not in MODELS_OPTIMIZED_WITH_STATIC_SHAPES:
             raise ValueError(f"Model type {model.config.model_type} is not supported!")
@@ -750,29 +750,6 @@ class VlmCausalLM(Model):
                 'type': rope_scaling, 'factor': float(rope_factor)
             }
 
-    def setup_quantization(self, model):
-        if hq_env.is_quantization_enabled:
-            htorch.core.quantization._mark_params_as_const(model)
-            htorch.core.quantization._check_params_as_const(model)
-            htorch.core.hpu_initialize(model)
-        return model
-
-    def prepare_model_for_quantization(self, model):
-        if hq_env.is_quantization_enabled:
-            if model.config.model_type == "llama":
-                self.patch_scoped_linear_all_reduce(model)
-            model = hq_env.prepare_model_for_quantization(model)
-        return model
-
-    def patch_scoped_linear_all_reduce(self, model):
-        from deepspeed.module_inject.layers import LinearAllreduce
-        from optimum.habana.transformers.models.modeling_all_models import ScopedLinearAllReduce
-        for name, module in model.named_children():
-            if type(module) is LinearAllreduce:
-                SL = ScopedLinearAllReduce(mod=module)
-                setattr(model, name, SL)
-            self.patch_scoped_linear_all_reduce(module)
-
     def decode(self, generated_ids: List[int]) -> str:
         return self.tokenizer.decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
 
@@ -1,3 +1,6 @@
+# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company.
+
+import os
 import re
 from typing import List, Optional, Tuple, Set, Union
 
@@ -162,7 +165,11 @@ class StoppingCriteria:
         self.max_new_tokens = max_new_tokens
         self.current_tokens = 0
         self.current_output = ""
-        self.ignore_eos_token = ignore_eos_token
+        if os.getenv("TEXT_GENERATION_SERVER_IGNORE_EOS_TOKEN", "false") == "true":
+            self.ignore_eos_token = True
+        else:
+            self.ignore_eos_token = ignore_eos_token
 
     def __call__(self, last_token: int, last_output: str) -> Tuple[bool, Optional[str]]:
         self.current_tokens += 1
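The environment variable introduced above is the switch behind the `-e TEXT_GENERATION_SERVER_IGNORE_EOS_TOKEN=true` flag used by the benchmark `docker run` commands earlier in this diff. As a hedged sketch, the same behaviour for a server process started outside Docker is obtained simply by exporting the variable before launch.

```bash
# Hedged sketch: force generation to ignore the EOS token so every request runs
# to max_new_tokens, mirroring the -e TEXT_GENERATION_SERVER_IGNORE_EOS_TOKEN=true
# docker flag shown in the README hunks above.
export TEXT_GENERATION_SERVER_IGNORE_EOS_TOKEN=true
```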