diff --git a/Dockerfile b/Dockerfile
index c7967bea..b64c6079 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -73,6 +73,7 @@ RUN cd server && \
     pip install -r requirements.txt && \
     bash ./dill-0.3.8-patch.sh && \
     pip install git+https://github.com/HabanaAI/DeepSpeed.git@1.17.0 && \
+    BUILD_CUDA_EXT=0 pip install git+https://github.com/AutoGPTQ/AutoGPTQ.git@097dd04e --no-build-isolation && \
     pip install . --no-cache-dir
 
 # Install benchmarker
diff --git a/README.md b/README.md
index 000aa524..2fa8836b 100644
--- a/README.md
+++ b/README.md
@@ -40,36 +40,40 @@ To use [🤗 text-generation-inference](https://github.com/huggingface/text-gene
 > ```bash
 > docker build -t tgi_gaudi .
 > ```
-2. Launch a local server instance:
+2. Use one of the following snippets to launch a local server instance:
 
+> [!NOTE]
+> For gated models such as [meta-llama/Llama-2-7b-hf](https://huggingface.co/meta-llama/Llama-2-7b-hf), you will have to pass `-e HF_TOKEN=` to the `docker run` commands below with a valid Hugging Face Hub read token.
-   i. On 1 Gaudi card
+   i. On 1 Gaudi card
    ```bash
    model=meta-llama/Llama-2-7b-hf
    hf_token=YOUR_ACCESS_TOKEN
    volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
 
-   docker run -p 8080:80 -v $volume:/data --runtime=habana -e HABANA_VISIBLE_DEVICES=all -e OMPI_MCA_btl_vader_single_copy_mechanism=none -e HUGGING_FACE_HUB_TOKEN=$hf_token -e ENABLE_HPU_GRAPH=true -e LIMIT_HPU_GRAPH=true -e USE_FLASH_ATTENTION=true -e FLASH_ATTENTION_RECOMPUTE=true --cap-add=sys_nice --ipc=host ghcr.io/huggingface/tgi-gaudi:2.0.5 --model-id $model --max-input-tokens 1024 --max-total-tokens 2048
-   ```
-   > For gated models such as [StarCoder](https://huggingface.co/bigcode/starcoder), you will have to pass `-e HUGGING_FACE_HUB_TOKEN=` to the `docker run` command above with a valid Hugging Face Hub read token.
-
-   ii. On 1 Gaudi card using PyTorch eager mode with torch compile:
-   ```bash
-   model=meta-llama/Llama-2-7b-hf
-   hf_token=YOUR_ACCESS_TOKEN
-   volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
-
-   docker run -p 8080:80 -v $volume:/data --runtime=habana -e HABANA_VISIBLE_DEVICES=all -e PT_HPU_LAZY_MODE=0 -e OMPI_MCA_btl_vader_single_copy_mechanism=none -e HUGGING_FACE_HUB_TOKEN=$hf_token --cap-add=sys_nice --ipc=host ghcr.io/huggingface/tgi-gaudi:2.0.5 --model-id $model --max-input-tokens 1024 --max-total-tokens 2048
+   docker run -p 8080:80 -v $volume:/data --runtime=habana -e HABANA_VISIBLE_DEVICES=all \
+      -e OMPI_MCA_btl_vader_single_copy_mechanism=none -e HF_TOKEN=$hf_token \
+      -e ENABLE_HPU_GRAPH=true -e LIMIT_HPU_GRAPH=true -e USE_FLASH_ATTENTION=true \
+      -e FLASH_ATTENTION_RECOMPUTE=true --cap-add=sys_nice --ipc=host \
+      ghcr.io/huggingface/tgi-gaudi:2.0.5 --model-id $model --max-input-tokens 1024 \
+      --max-total-tokens 2048
    ```
 
-   iii. On 8 Gaudi cards:
+   ii. On 8 Gaudi cards:
    ```bash
    model=meta-llama/Llama-2-70b-hf
    hf_token=YOUR_ACCESS_TOKEN
    volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
 
-   docker run -p 8080:80 -v $volume:/data --runtime=habana -e PT_HPU_ENABLE_LAZY_COLLECTIVES=true -e HABANA_VISIBLE_DEVICES=all -e OMPI_MCA_btl_vader_single_copy_mechanism=none -e HUGGING_FACE_HUB_TOKEN=$hf_token -e ENABLE_HPU_GRAPH=true -e LIMIT_HPU_GRAPH=true -e USE_FLASH_ATTENTION=true -e FLASH_ATTENTION_RECOMPUTE=true --cap-add=sys_nice --ipc=host ghcr.io/huggingface/tgi-gaudi:2.0.5 --model-id $model --sharded true --num-shard 8 --max-input-tokens 1024 --max-total-tokens 2048
+   docker run -p 8080:80 -v $volume:/data --runtime=habana -e PT_HPU_ENABLE_LAZY_COLLECTIVES=true \
+      -e HABANA_VISIBLE_DEVICES=all -e OMPI_MCA_btl_vader_single_copy_mechanism=none \
+      -e HF_TOKEN=$hf_token -e ENABLE_HPU_GRAPH=true -e LIMIT_HPU_GRAPH=true \
+      -e USE_FLASH_ATTENTION=true -e FLASH_ATTENTION_RECOMPUTE=true --cap-add=sys_nice \
+      --ipc=host ghcr.io/huggingface/tgi-gaudi:2.0.5 --model-id $model --sharded true \
+      --num-shard 8 --max-input-tokens 1024 --max-total-tokens 2048
    ```
 
-3. You can then send a simple request:
+3. Wait for the TGI-Gaudi server to come online. You will see a log line like the following:
+   > 2024-05-22T19:31:48.302239Z INFO text_generation_router: router/src/main.rs:378: Connected
+   You can then send a simple request to the server from a separate terminal:
    ```bash
    curl 127.0.0.1:8080/generate \
      -X POST \
@@ -124,7 +128,7 @@ docker run -p 8080:80 \
    --runtime=habana \
    -v $volume:/data \
    -e HABANA_VISIBLE_DEVICES=all \
-   -e HUGGING_FACE_HUB_TOKEN=$hf_token \
+   -e HF_TOKEN=$hf_token \
    -e OMPI_MCA_btl_vader_single_copy_mechanism=none \
    -e TEXT_GENERATION_SERVER_IGNORE_EOS_TOKEN=true \
    -e MAX_TOTAL_TOKENS=2048 \
@@ -155,7 +159,7 @@ docker run -p 8080:80 \
    --runtime=habana \
    -v $volume:/data \
    -e HABANA_VISIBLE_DEVICES=all \
-   -e HUGGING_FACE_HUB_TOKEN=$hf_token \
+   -e HF_TOKEN=$hf_token \
    -e OMPI_MCA_btl_vader_single_copy_mechanism=none \
    -e TEXT_GENERATION_SERVER_IGNORE_EOS_TOKEN=true \
    -e PT_HPU_ENABLE_LAZY_COLLECTIVES=true \
@@ -188,7 +192,7 @@ docker run -p 8080:80 \
    --runtime=habana \
    -v $volume:/data \
    -e HABANA_VISIBLE_DEVICES=all \
-   -e HUGGING_FACE_HUB_TOKEN=$hf_token \
+   -e HF_TOKEN=$hf_token \
    -e OMPI_MCA_btl_vader_single_copy_mechanism=none \
    -e TEXT_GENERATION_SERVER_IGNORE_EOS_TOKEN=true \
    -e MAX_TOTAL_TOKENS=2048 \
@@ -219,7 +223,7 @@ docker run -p 8080:80 \
    --runtime=habana \
    -v $volume:/data \
    -e HABANA_VISIBLE_DEVICES=all \
-   -e HUGGING_FACE_HUB_TOKEN=$hf_token \
+   -e HF_TOKEN=$hf_token \
    -e OMPI_MCA_btl_vader_single_copy_mechanism=none \
    -e TEXT_GENERATION_SERVER_IGNORE_EOS_TOKEN=true \
    -e PT_HPU_ENABLE_LAZY_COLLECTIVES=true \
@@ -281,7 +285,7 @@ curl -N 127.0.0.1:8080/generate_stream \
 
 ## Running TGI with FP8 Precision
 
-TGI-Gaudi supports FP8 precision inference with INC (Intel Neural Compressor) and HQT (Habana Quantization Toolkit). FP8 inference can be run by setting QUANT_CONFIG environment variable in the docker command. From TGI-Gaudi 2.0.4 release, INC is used by default for quantization. HQT will be removed in future releases. To use HQT, disable INC by setting `-e USE_INC=0` in docker command.
+TGI-Gaudi supports FP8 precision inference with [Intel Neural Compressor (INC)](https://docs.habana.ai/en/latest/PyTorch/Inference_on_PyTorch/Inference_Using_FP8.html). FP8 inference can be run by setting the `QUANT_CONFIG` environment variable in the `docker run` command.
 
 To run FP8 Inference:
 
@@ -303,7 +307,7 @@ docker run -p 8080:80 \
    -v $PWD/hqt_output:/usr/src/hqt_output \
    -e QUANT_CONFIG=./quantization_config/maxabs_quant.json \
    -e HABANA_VISIBLE_DEVICES=all \
-   -e HUGGING_FACE_HUB_TOKEN=$hf_token \
+   -e HF_TOKEN=$hf_token \
    -e OMPI_MCA_btl_vader_single_copy_mechanism=none \
    -e TEXT_GENERATION_SERVER_IGNORE_EOS_TOKEN=true \
    -e MAX_TOTAL_TOKENS=2048 \
@@ -337,7 +341,7 @@ docker run -p 8080:80 \
    -v $PWD/hqt_output:/usr/src/hqt_output \
    -e QUANT_CONFIG=./quantization_config/maxabs_quant.json \
    -e HABANA_VISIBLE_DEVICES=all \
-   -e HUGGING_FACE_HUB_TOKEN=$hf_token \
+   -e HF_TOKEN=$hf_token \
    -e OMPI_MCA_btl_vader_single_copy_mechanism=none \
    -e TEXT_GENERATION_SERVER_IGNORE_EOS_TOKEN=true \
    -e PT_HPU_ENABLE_LAZY_COLLECTIVES=true \
@@ -374,7 +378,7 @@ docker run -p 8080:80 \
    -v $PWD/hqt_output:/usr/src/hqt_output \
    -e QUANT_CONFIG=./quantization_config/maxabs_quant.json \
    -e HABANA_VISIBLE_DEVICES=all \
-   -e HUGGING_FACE_HUB_TOKEN=$hf_token \
+   -e HF_TOKEN=$hf_token \
    -e OMPI_MCA_btl_vader_single_copy_mechanism=none \
    -e TEXT_GENERATION_SERVER_IGNORE_EOS_TOKEN=true \
    -e MAX_TOTAL_TOKENS=2048 \
@@ -408,7 +412,7 @@ docker run -p 8080:80 \
    -v $PWD/hqt_output:/usr/src/hqt_output \
    -e QUANT_CONFIG=./quantization_config/maxabs_quant.json \
    -e HABANA_VISIBLE_DEVICES=all \
-   -e HUGGING_FACE_HUB_TOKEN=$hf_token \
+   -e HF_TOKEN=$hf_token \
    -e OMPI_MCA_btl_vader_single_copy_mechanism=none \
    -e TEXT_GENERATION_SERVER_IGNORE_EOS_TOKEN=true \
    -e PT_HPU_ENABLE_LAZY_COLLECTIVES=true \
diff --git a/examples/README.md b/examples/README.md
index 93f391ec..e605364e 100644
--- a/examples/README.md
+++ b/examples/README.md
@@ -7,7 +7,7 @@ This example provide a simple way of usage of `tgi-gaudi` with continuous batchi
 ### Install
 
 ```
-pip install -r requirements
+pip install -r requirements.txt
 ```
 
 ### Setup TGI server
@@ -36,4 +36,4 @@ All possible parameters are described in the below table:
 | TOTAL_SAMPLE_COUNT | 2048 | Number of samples to run. |
 | MAX_CONCURRENT_REQUESTS | 256 | The number of requests sent simultaneously to the TGI server. |
 
-
\ No newline at end of file
+
diff --git a/server/text_generation_server/habana_quantization_env.py b/server/text_generation_server/habana_quantization_env.py
index 3c06fd09..e942fdcf 100644
--- a/server/text_generation_server/habana_quantization_env.py
+++ b/server/text_generation_server/habana_quantization_env.py
@@ -1,6 +1,7 @@
 # Copyright (C) 2024 Habana Labs, Ltd. an Intel Company.
 
 import os
+import habana_frameworks.torch as htorch
 
 quant_config = os.getenv("QUANT_CONFIG", "")
 is_quantization_enabled = quant_config != ""
@@ -10,18 +11,35 @@ if is_quantization_enabled:
     os.environ.setdefault("USE_DEFAULT_QUANT_PARAM", "true")
     os.environ.setdefault("UPDATE_GRAPH_OUTPUT_MME", "false")
     os.environ.setdefault("ENABLE_CALC_DYNAMIC_RANGE", "false")
-    os.environ.setdefault(
-        "UPDATE_MME_OUTPUT_PRECISION_FILTER", "v_proj,matmul_av")
+    os.environ.setdefault("UPDATE_MME_OUTPUT_PRECISION_FILTER", "v_proj,matmul_av")
     os.environ.setdefault("EXPERIMENTAL_WEIGHT_SHARING", "FALSE")
 
 
+def patch_scoped_linear_all_reduce(model):
+    from deepspeed.module_inject.layers import LinearAllreduce
+    from optimum.habana.transformers.models.modeling_all_models import ScopedLinearAllReduce
+
+    for name, module in model.named_children():
+        if type(module) is LinearAllreduce:
+            SL = ScopedLinearAllReduce(mod=module)
+            setattr(model, name, SL)
+        patch_scoped_linear_all_reduce(module)
+
+
+def setup_quantization(model):
+    if is_quantization_enabled:
+        htorch.core.quantization._mark_params_as_const(model)
+        htorch.core.quantization._check_params_as_const(model)
+        htorch.core.hpu_initialize(model)
+    return model
+
+
 def prepare_model_for_quantization(model):
     if is_quantization_enabled:
-        if os.getenv("USE_INC", "1") != "0":
-            from neural_compressor.torch.quantization import FP8Config, convert
-            config = FP8Config.from_json_file(quant_config)
-            model = convert(model, config)
-        else:
-            import habana_quantization_toolkit
-            habana_quantization_toolkit.prep_model(model)
-    return model
+        if model.config.model_type in ["llama", "falcon", "qwen2", "starcoder2", "gemma"]:
+            patch_scoped_linear_all_reduce(model)
+        from neural_compressor.torch.quantization import FP8Config, convert
+
+        config = FP8Config.from_json_file(quant_config)
+        model = convert(model, config)
+    return model
diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py
index 6d8092b9..88c5debf 100644
--- a/server/text_generation_server/models/causal_lm.py
+++ b/server/text_generation_server/models/causal_lm.py
@@ -665,7 +665,7 @@ class CausalLM(Model):
             model = self.get_deepspeed_model(
                 model_id, dtype, revision
             )
-            model = self.prepare_model_for_quantization(model)
+            model = hq_env.prepare_model_for_quantization(model)
         else:
             get_repo_root(model_id)
 
@@ -684,12 +684,15 @@ class CausalLM(Model):
                 trust_remote_code=trust_remote_code,
                 **model_kwargs
             )
-            model = self.prepare_model_for_quantization(model)
+            model = hq_env.prepare_model_for_quantization(model)
             model = model.eval().to(device)
 
         self.enable_hpu_graph = os.getenv("ENABLE_HPU_GRAPH", "true").lower() == "true" and LAZY_MODE == 1
         self.limit_hpu_graph = os.getenv("LIMIT_HPU_GRAPH", "false").lower() == "true"
-        model = remove_kv_cache_from_output(model)
+
+        if model.config.model_type not in ["gpt_bigcode"]:  # gpt_bigcode/starcoderbase-3b skips remove_kv_cache_from_output()
+            model = remove_kv_cache_from_output(model)
+
         if self.enable_hpu_graph:
             from habana_frameworks.torch.hpu import wrap_in_hpu_graph
             model = wrap_in_hpu_graph(model, disable_tensor_cache=True)
@@ -700,7 +703,7 @@ class CausalLM(Model):
                     "TORCH COMPILE", f'Torch compiling of model')
                 model.model = torch.compile(model.model, backend="hpu_backend", options={"keep_input_mutations": True})
 
-        model = self.setup_quantization(model)
+        model = hq_env.setup_quantization(model)
 
         if model.config.model_type not in MODELS_OPTIMIZED_WITH_STATIC_SHAPES:
             raise ValueError(f"Model type {model.config.model_type} is not supported!")
@@ -727,10 +730,13 @@ class CausalLM(Model):
             "return_dict": True,
         }
 
-        if model.config.model_type in ["llama", "mistral", "starcoder2", "qwen2"]:
-
-            if model.config.model_type in ["llama", "mistral", "qwen2"]:
+
+        if model.config.model_type in ["llama", "mistral", "starcoder2", "qwen2", "falcon", "gemma"]:
+
+            if model.config.model_type not in ["falcon"]:
                 self.kwargs["attn_softmax_bf16"] = True
+
+            if model.config.model_type not in ["gemma"]:
                 self.kwargs["trim_logits"] = True
 
             if os.getenv("USE_FLASH_ATTENTION", "false").lower() == "true":
@@ -832,29 +838,6 @@ class CausalLM(Model):
             'type': rope_scaling, 'factor': float(rope_factor)
         }
 
-    def setup_quantization(self, model):
-        if hq_env.is_quantization_enabled:
-            htorch.core.quantization._mark_params_as_const(model)
-            htorch.core.quantization._check_params_as_const(model)
-            htorch.core.hpu_initialize(model)
-        return model
-
-    def prepare_model_for_quantization(self, model):
-        if hq_env.is_quantization_enabled:
-            if model.config.model_type == "llama":
-                self.patch_scoped_linear_all_reduce(model)
-            model = hq_env.prepare_model_for_quantization(model)
-        return model
-
-    def patch_scoped_linear_all_reduce(self, model):
-        from deepspeed.module_inject.layers import LinearAllreduce
-        from optimum.habana.transformers.models.modeling_all_models import ScopedLinearAllReduce
-        for name, module in model.named_children():
-            if type(module) is LinearAllreduce:
-                SL = ScopedLinearAllReduce(mod=module)
-                setattr(model, name, SL)
-            self.patch_scoped_linear_all_reduce(module)
-
     @property
     def batch_type(self) -> Type[CausalLMBatch]:
         return CausalLMBatch
@@ -903,7 +886,7 @@ class CausalLM(Model):
 
         kwargs.update(self.kwargs)
 
-        if past_key_values is not None:
+        if past_key_values is not None and self.model.config.model_type not in ["gpt_bigcode"]:
             return self.model.forward(**kwargs)
         else:
             outputs = self.model.forward(**kwargs)
@@ -988,7 +971,7 @@ class CausalLM(Model):
             else:
                 batch.position_ids += 1
             # Update past key values
-            if prefill:
+            if prefill or self.model.config.model_type in ["gpt_bigcode"]:
                 batch.past_key_values = past
 
         htorch.core.mark_step()
@@ -1032,7 +1015,7 @@ class CausalLM(Model):
         else:
             token_idx = torch.tensor(batch.attention_mask.shape[-1] - batch.right_padding).to(self.device)
             input_ids = torch.index_select(batch.input_ids, 1, token_idx - 1)
-        batch.logits = self.forward(
+        logits = self.forward(
             input_ids,
             batch.attention_mask,
             batch.position_ids,
@@ -1040,6 +1023,10 @@ class CausalLM(Model):
             batch.past_key_values,
             bypass_hpu_graph=prefill and self.limit_hpu_graph if self.enable_hpu_graph else None,
         )
+        if self.model.config.model_type in ["gpt_bigcode"]:
+            batch.logits, batch.past = logits
+        else:
+            batch.logits = logits
 
         htorch.core.mark_step()
diff --git a/server/text_generation_server/models/vlm_causal_lm.py b/server/text_generation_server/models/vlm_causal_lm.py
index a07dafd5..88734bdc 100644
--- a/server/text_generation_server/models/vlm_causal_lm.py
+++ b/server/text_generation_server/models/vlm_causal_lm.py
@@ -207,7 +207,7 @@ class VlmCausalLMBatch(CausalLMBatch):
         device: torch.device,
         is_warmup: bool = False,
     ) -> "VlmCausalLMBatch":
-
+        dbg_trace('FROM_PB', f'num_reqs:{len(pb.requests)}')
         requests = [CausalLMRequest.from_pb(idx, req, tokenizer) for idx, req in enumerate(pb.requests)]
 
@@ -576,7 +576,7 @@ class VlmCausalLM(Model):
             model = self.get_deepspeed_model(
                 model_class, model_id, dtype, revision
             )
-            model = self.prepare_model_for_quantization(model)
+            model = hq_env.prepare_model_for_quantization(model)
         else:
             get_repo_root(model_id)
 
@@ -595,7 +595,7 @@ class VlmCausalLM(Model):
                 trust_remote_code=trust_remote_code,
                 **model_kwargs
             )
-            model = self.prepare_model_for_quantization(model)
+            model = hq_env.prepare_model_for_quantization(model)
             model = model.eval().to(device)
 
         self.enable_hpu_graph = os.getenv("ENABLE_HPU_GRAPH", "true").lower() == "true" and LAZY_MODE == 1
@@ -611,7 +611,7 @@ class VlmCausalLM(Model):
                     "TORCH COMPILE", f'Torch compiling of model')
                 model.model = torch.compile(model.model, backend="hpu_backend", options={"keep_input_mutations": True})
 
-        model = self.setup_quantization(model)
+        model = hq_env.setup_quantization(model)
 
         if model.config.model_type not in MODELS_OPTIMIZED_WITH_STATIC_SHAPES:
             raise ValueError(f"Model type {model.config.model_type} is not supported!")
@@ -750,29 +750,6 @@ class VlmCausalLM(Model):
             'type': rope_scaling, 'factor': float(rope_factor)
         }
 
-    def setup_quantization(self, model):
-        if hq_env.is_quantization_enabled:
-            htorch.core.quantization._mark_params_as_const(model)
-            htorch.core.quantization._check_params_as_const(model)
-            htorch.core.hpu_initialize(model)
-        return model
-
-    def prepare_model_for_quantization(self, model):
-        if hq_env.is_quantization_enabled:
-            if model.config.model_type == "llama":
-                self.patch_scoped_linear_all_reduce(model)
-            model = hq_env.prepare_model_for_quantization(model)
-        return model
-
-    def patch_scoped_linear_all_reduce(self, model):
-        from deepspeed.module_inject.layers import LinearAllreduce
-        from optimum.habana.transformers.models.modeling_all_models import ScopedLinearAllReduce
-        for name, module in model.named_children():
-            if type(module) is LinearAllreduce:
-                SL = ScopedLinearAllReduce(mod=module)
-                setattr(model, name, SL)
-            self.patch_scoped_linear_all_reduce(module)
-
     def decode(self, generated_ids: List[int]) -> str:
         return self.tokenizer.decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
 
@@ -946,7 +923,7 @@ class VlmCausalLM(Model):
                 bypass_hpu_graph=prefill and self.limit_hpu_graph if self.enable_hpu_graph else None,
             )
         elif all([req.stopping_criteria.max_new_tokens == 1 for req in batch.requests]):
-            # Don't schedule next forward if max_new_tokens for all requests equals 1 
+            # Don't schedule next forward if max_new_tokens for all requests equals 1
             # - we've already generated the first and only needed token in the prefill phase
             pass
         else:
diff --git a/server/text_generation_server/utils/tokens.py b/server/text_generation_server/utils/tokens.py
index 1136fa96..aa4d1fdb 100644
--- a/server/text_generation_server/utils/tokens.py
+++ b/server/text_generation_server/utils/tokens.py
@@ -1,3 +1,6 @@
+# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company.
+
+import os
 import re
 from typing import List, Optional, Tuple, Set, Union
 
@@ -162,7 +165,11 @@ class StoppingCriteria:
         self.max_new_tokens = max_new_tokens
         self.current_tokens = 0
         self.current_output = ""
-        self.ignore_eos_token = ignore_eos_token
+
+        if os.getenv("TEXT_GENERATION_SERVER_IGNORE_EOS_TOKEN", "false") == "true":
+            self.ignore_eos_token = True
+        else:
+            self.ignore_eos_token = ignore_eos_token
 
     def __call__(self, last_token: int, last_output: str) -> Tuple[bool, Optional[str]]:
         self.current_tokens += 1
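# ---------------------------------------------------------------------------
# A minimal sketch (not part of the patch): the precedence that the new
# TEXT_GENERATION_SERVER_IGNORE_EOS_TOKEN check in StoppingCriteria.__init__
# above introduces. The standalone helper below only mirrors the added branch;
# its name and form are illustrative assumptions, not code from the repository.
import os


def resolve_ignore_eos(request_ignore_eos: bool) -> bool:
    # When the server-wide env var is "true" (the benchmark docker commands in
    # the README set it), EOS is ignored for every request; otherwise the
    # per-request ignore_eos_token flag is respected.
    if os.getenv("TEXT_GENERATION_SERVER_IGNORE_EOS_TOKEN", "false") == "true":
        return True
    return request_ignore_eos


os.environ["TEXT_GENERATION_SERVER_IGNORE_EOS_TOKEN"] = "true"
assert resolve_ignore_eos(False) is True   # env override wins over the request flag
del os.environ["TEXT_GENERATION_SERVER_IGNORE_EOS_TOKEN"]
assert resolve_ignore_eos(False) is False  # falls back to the per-request value
# ---------------------------------------------------------------------------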