From c08005a4cd1f761b75dea788d46f133dd1070056 Mon Sep 17 00:00:00 2001 From: baptiste Date: Mon, 24 Feb 2025 09:48:44 +0000 Subject: [PATCH] feat(gaudi): new gaudi backend working --- .gitignore | 3 + Dockerfile_gaudi | 22 +++-- backends/gaudi/Makefile | 50 +++++++++++ backends/gaudi/README.md | 84 +++++++++++++++++++ backends/gaudi/server/Makefile | 6 +- .../models/causal_lm.py | 28 ++++--- .../text_generation_server/models/model.py | 7 +- .../server/text_generation_server/server.py | 6 +- backends/gaudi/tgi-entrypoint.sh | 5 ++ launcher/src/env_runtime.rs | 17 +++- launcher/src/main.rs | 9 ++ 11 files changed, 213 insertions(+), 24 deletions(-) create mode 100644 backends/gaudi/Makefile create mode 100644 backends/gaudi/README.md create mode 100644 backends/gaudi/tgi-entrypoint.sh diff --git a/.gitignore b/.gitignore index 9434d75c..248001fd 100644 --- a/.gitignore +++ b/.gitignore @@ -23,3 +23,6 @@ server/fbgemmm .direnv/ .venv/ + +# Gaudi auto-generated files +hl-smi_log*.txt \ No newline at end of file diff --git a/Dockerfile_gaudi b/Dockerfile_gaudi index e59814c3..0af46248 100644 --- a/Dockerfile_gaudi +++ b/Dockerfile_gaudi @@ -17,8 +17,15 @@ RUN cargo chef prepare --recipe-path recipe.json FROM chef AS builder -RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \ - python3.11-dev +ENV PYO3_PYTHON="/root/.local/bin/python" \ + PYTHON_SYS_EXECUTABLE="/root/.local/bin/python" \ + PYO3_PYTHON_VERSION="3.10" + +RUN curl -LsSf https://astral.sh/uv/install.sh | sh \ + && . $HOME/.local/bin/env \ + && uv python install 3.10 --default --preview \ + && test -f /root/.local/bin/python || (echo "Python 3.10 not found at /root/.local/bin/python" && exit 1) + RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \ curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v21.12/$PROTOC_ZIP && \ unzip -o $PROTOC_ZIP -d /usr/local bin/protoc && \ @@ -52,6 +59,9 @@ ENV HF_HOME=/data \ HF_HUB_ENABLE_HF_TRANSFER=1 \ PORT=80 +# Assert that Python 3.10 is installed as the launcher is compiled with Python 3.10 +RUN python3.10 --version || (echo "Python 3.10 is not installed" && exit 1) + # libssl.so.1.1 is not installed on Ubuntu 22.04 by default, install it RUN wget http://nz2.archive.ubuntu.com/ubuntu/pool/main/o/openssl/libssl1.1_1.1.1f-1ubuntu2_amd64.deb && \ dpkg -i ./libssl1.1_1.1.1f-1ubuntu2_amd64.deb @@ -64,17 +74,17 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-ins make \ curl \ git \ - python3.11-dev \ && rm -rf /var/lib/apt/lists/* # Install server COPY proto proto -COPY server server -COPY server/Makefile server/Makefile +COPY backends/gaudi/server server +COPY backends/gaudi/server/Makefile server/Makefile RUN cd server && \ make gen-server && \ pip install --no-deps -r requirements.txt && \ bash ./dill-0.3.8-patch.sh && \ + pip install outlines~=0.0.34 && \ pip install git+https://github.com/HabanaAI/DeepSpeed.git@1.19.0 && \ BUILD_CUDA_EXT=0 pip install git+https://github.com/AutoGPTQ/AutoGPTQ.git@097dd04e --no-build-isolation && \ pip install . 
--no-cache-dir @@ -98,7 +108,7 @@ ENTRYPOINT ["./entrypoint.sh"] # Final image FROM base -COPY ./tgi-entrypoint.sh /tgi-entrypoint.sh +COPY backends/gaudi/tgi-entrypoint.sh /tgi-entrypoint.sh RUN chmod +x /tgi-entrypoint.sh ENTRYPOINT ["/tgi-entrypoint.sh"] diff --git a/backends/gaudi/Makefile b/backends/gaudi/Makefile new file mode 100644 index 00000000..507dfedb --- /dev/null +++ b/backends/gaudi/Makefile @@ -0,0 +1,50 @@ +mkfile_path := $(abspath $(lastword $(MAKEFILE_LIST))) +mkfile_dir := $(dir $(mkfile_path)) +root_dir := "${mkfile_dir}/../.." + +.PHONY: image run-local-dev-container install-dependencies install-server install-router install-launcher local-dev-install + +image: + docker build -t tgi-gaudi -f ${root_dir}/Dockerfile_gaudi ${root_dir} + +run-local-dev-container: + docker run -it \ + --runtime=habana \ + -e HABANA_VISIBLE_DEVICES=all \ + -e PT_HPU_ENABLE_LAZY_COLLECTIVES=true \ + -e LOG_LEVEL=debug \ + -e OMPI_MCA_btl_vader_single_copy_mechanism=none \ + -e HF_TOKEN=`cat /home/ubuntu/.cache/huggingface/token` \ + -e ENABLE_HPU_GRAPH=true \ + -e LIMIT_HPU_GRAPH=true \ + -e USE_FLASH_ATTENTION=true \ + -e FLASH_ATTENTION_RECOMPUTE=true \ + -e PORT=8080 \ + --cap-add=sys_nice \ + --net=host \ + --ipc=host \ + -v /home/ubuntu/.cache/huggingface:/data \ + -v $(PWD):/text-generation-inference \ + -w /text-generation-inference \ + vault.habana.ai/gaudi-docker/1.19.0/ubuntu22.04/habanalabs/pytorch-installer-2.5.1:latest + +install-dependencies: + pip install git+https://github.com/HabanaAI/DeepSpeed.git@1.19.0 + pip install outlines~=0.0.34 + curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y + +install-server: + make -C ${root_dir}/backends/gaudi/server install PROTO_PATH=../../../proto/v3 + +install-router: + make -C ${root_dir} install-router + +install-launcher: + make -C ${root_dir} install-launcher + +# use source to load the rust in path +local-dev-install: install-dependencies + bash -c 'source "$$HOME/.cargo/env" && \ + make install-server && \ + make install-router && \ + make install-launcher' \ No newline at end of file diff --git a/backends/gaudi/README.md b/backends/gaudi/README.md new file mode 100644 index 00000000..695fa41a --- /dev/null +++ b/backends/gaudi/README.md @@ -0,0 +1,84 @@ +# Text-generation-inference - Gaudi backend + +## Description + +This is the TGI backend for Intel Gaudi. This backend is composed of the tgi server optimized for Gaudi hardware. 
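+
+As a quick sanity check before building or running anything, you can verify that the Gaudi devices are visible on the host. This is only a suggestion and assumes the Habana driver stack (which provides the `hl-smi` tool) is already installed on the host; `hl-smi` is the same utility the launcher shells out to when detecting HPU devices.
+```bash
+# List the Gaudi devices visible on the host, with driver and firmware info.
+# Assumes the Habana driver stack is installed (assumption); this is the same
+# tool the TGI launcher invokes to detect HPUs.
+hl-smi
+```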
+
+## Build your own image
+
+The simplest way to build TGI with the Gaudi backend is to use the provided `Makefile`:
+
+Option 1: From the project root directory:
+```bash
+make -C backends/gaudi image
+```
+
+Option 2: From the Gaudi backend directory:
+```bash
+cd backends/gaudi
+make image
+```
+
+You can now run the server with the following command:
+```bash
+model=meta-llama/Llama-3.1-8B-Instruct
+hf_token=$(cat ${HOME}/.cache/huggingface/token)
+volume=${HOME}/.cache/huggingface
+
+docker run -p 8080:80 -v $volume:/data --runtime=habana -e PT_HPU_ENABLE_LAZY_COLLECTIVES=true \
+-e LOG_LEVEL=debug \
+-e HABANA_VISIBLE_DEVICES=all -e OMPI_MCA_btl_vader_single_copy_mechanism=none \
+-e HF_TOKEN=$hf_token -e ENABLE_HPU_GRAPH=true -e LIMIT_HPU_GRAPH=true \
+-e USE_FLASH_ATTENTION=true -e FLASH_ATTENTION_RECOMPUTE=true --cap-add=sys_nice \
+--ipc=host tgi-gaudi --model-id $model --sharded true \
+--num-shard 8 --max-input-tokens 512 --max-total-tokens 1024 --max-batch-size 8 --max-batch-prefill-tokens 2048 --max-batch-total-tokens 8192
+```
+
+## Contributing
+
+### Local Development
+
+This is useful if you want to run the server locally for better debugging.
+```bash
+make -C backends/gaudi run-local-dev-container
+```
+
+Then run the following command inside the container to install TGI for Gaudi:
+```bash
+make -C backends/gaudi local-dev-install
+```
+
+Add Rust to the path:
+```bash
+. "$HOME/.cargo/env"
+```
+
+Option 1: Run the server (sharded model):
+```bash
+LOG_LEVEL=debug text-generation-launcher \
+    --model-id meta-llama/Llama-3.1-8B-Instruct \
+    --sharded true \
+    --num-shard 8 \
+    --max-input-tokens 512 \
+    --max-total-tokens 1024 \
+    --max-batch-size 8 \
+    --max-batch-prefill-tokens 2048
+```
+
+Option 2: Run the server (non-sharded model):
+```bash
+LOG_LEVEL=debug text-generation-launcher \
+    --model-id meta-llama/Llama-3.1-8B-Instruct \
+    --max-input-tokens 512 \
+    --max-total-tokens 1024 \
+    --max-batch-size 4 \
+    --max-batch-prefill-tokens 2048
+```
+
+You can then test the server with the following curl command from another terminal (it can be outside the container):
+```bash
+curl 127.0.0.1:8080/generate \
+    -X POST \
+    -d '{"inputs":"What is Deep Learning?","parameters":{"max_new_tokens":20}}' \
+    -H 'Content-Type: application/json'
+```
diff --git a/backends/gaudi/server/Makefile b/backends/gaudi/server/Makefile
index f01897e5..b5b84338 100644
--- a/backends/gaudi/server/Makefile
+++ b/backends/gaudi/server/Makefile
@@ -5,6 +5,8 @@ include Makefile-awq
 include Makefile-eetq
 include Makefile-selective-scan
 
+PROTO_PATH ?= ../proto/v3
+
 unit-tests:
 	pytest -s -vv -m "not private" tests
 
@@ -12,8 +14,8 @@ gen-server:
 	# Compile protos
 	pip install grpcio-tools==1.62.2 mypy-protobuf==3.6.0 'types-protobuf' --no-cache-dir
 	mkdir text_generation_server/pb || true
-	python -m grpc_tools.protoc -I../proto/v3 --python_out=text_generation_server/pb \
-		--grpc_python_out=text_generation_server/pb --mypy_out=text_generation_server/pb ../proto/v3/generate.proto
+	python -m grpc_tools.protoc -I$(PROTO_PATH) --python_out=text_generation_server/pb \
+		--grpc_python_out=text_generation_server/pb --mypy_out=text_generation_server/pb $(PROTO_PATH)/generate.proto
 	find text_generation_server/pb/ -type f -name "*.py" -print0 -exec sed -i -e 's/^\(import.*pb2\)/from . 
\1/g' {} \; touch text_generation_server/pb/__init__.py diff --git a/backends/gaudi/server/text_generation_server/models/causal_lm.py b/backends/gaudi/server/text_generation_server/models/causal_lm.py index 5b105224..92f7a806 100644 --- a/backends/gaudi/server/text_generation_server/models/causal_lm.py +++ b/backends/gaudi/server/text_generation_server/models/causal_lm.py @@ -59,7 +59,7 @@ CHUNK_SIZES = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048] LAZY_MODE = int(os.environ.get("PT_HPU_LAZY_MODE", 1)) BATCH_BUCKET_SIZE = int(os.environ.get("BATCH_BUCKET_SIZE", 8)) PREFILL_BATCH_BUCKET_SIZE = int(os.environ.get("PREFILL_BATCH_BUCKET_SIZE", 2)) - +MAX_BATCH_SIZE = int(os.environ.get('MAX_BATCH_SIZE')) if os.environ.get('MAX_BATCH_SIZE') is not None else None def torch_compile_for_eager(func): if LAZY_MODE == 1: @@ -1289,9 +1289,13 @@ class CausalLM(Model): return self.batch_type.from_pb(batch, self.tokenizer, self.dtype, self.device) - def warmup(self, request) -> None: + def warmup(self, request: generate_pb2.WarmupRequest) -> Tuple[Optional[int], Optional[int], Optional[int]]: + assert MAX_BATCH_SIZE is not None, "MAX_BATCH_SIZE is not set, it should be set in the launcher" + MAX_BATCH_TOTAL_TOKENS = MAX_BATCH_SIZE * request.max_total_tokens + logger.info(f"MAX_BATCH_SIZE: {MAX_BATCH_SIZE}") + logger.info(f"MAX_BATCH_TOTAL_TOKENS: {MAX_BATCH_TOTAL_TOKENS}") MAX_TOTAL_TOKENS = request.max_total_tokens - MAX_BATCH_TOTAL_TOKENS = request.max_batch_total_tokens + batch = self.batch_type.from_pb( request.batch, self.tokenizer, self.dtype, self.device ) @@ -1308,18 +1312,18 @@ class CausalLM(Model): del prefill_batch # Warmup prefill batch_size - max_input_length = request.max_input_length + max_input_tokens = request.max_input_tokens prefill_batch_size_list = [batch for batch in range(PREFILL_BATCH_BUCKET_SIZE, max_prefill_batch_size, PREFILL_BATCH_BUCKET_SIZE)] prefill_batch_size_list.append(max_prefill_batch_size) prefill_seqlen_list = [ seq for seq in range( PAD_SEQUENCE_TO_MULTIPLE_OF, - max_input_length, + max_input_tokens, PAD_SEQUENCE_TO_MULTIPLE_OF, ) ] - prefill_seqlen_list.append(max_input_length) + prefill_seqlen_list.append(max_input_tokens) prefill_batch_size_list.sort(reverse=True) prefill_seqlen_list.sort(reverse=True) try: @@ -1345,8 +1349,7 @@ class CausalLM(Model): f"Prefill sequence length list:{prefill_seqlen_list}\n" f"Memory stats: {mem_stats} " ) - - # warmup decode batch size + max_decode_batch_size = math.floor(MAX_BATCH_TOTAL_TOKENS / MAX_TOTAL_TOKENS) max_decode_batch_size = round_up(max_decode_batch_size, BATCH_BUCKET_SIZE) decode_batch_size_list = [ @@ -1388,12 +1391,15 @@ class CausalLM(Model): ) decode_batch_size_list.sort() - MAX_BATCH_TOTAL_TOKENS = MAX_TOTAL_TOKENS * decode_batch_size_list[-1] + max_supported_total_tokens = MAX_TOTAL_TOKENS * decode_batch_size_list[-1] mem_stats = get_hpu_memory_stats(self.device) logger.info( f"\nFollowing decode warmup successfully.\n" f"Decode batch size list:{decode_batch_size_list}\n" f"Memory stats: {mem_stats} " ) - - return MAX_BATCH_TOTAL_TOKENS + + max_input_tokens=max_input_tokens + max_total_tokens=MAX_TOTAL_TOKENS + + return max_supported_total_tokens, max_input_tokens, max_total_tokens diff --git a/backends/gaudi/server/text_generation_server/models/model.py b/backends/gaudi/server/text_generation_server/models/model.py index 4568b8dd..04172c74 100644 --- a/backends/gaudi/server/text_generation_server/models/model.py +++ b/backends/gaudi/server/text_generation_server/models/model.py @@ -8,9 +8,11 @@ from 
collections import defaultdict from transformers import PreTrainedTokenizerBase from text_generation_server.models.types import Batch, Generation +from text_generation_server.models.globals import BLOCK_SIZE from text_generation_server.utils.speculate import get_speculate from text_generation_server.pb.generate_pb2 import InfoResponse from text_generation_server.adapters.weights import LayerAdapterWeights +from text_generation_server.pb import generate_pb2 import time BASE_MODEL_ADAPTER_ID = "__base_model__" @@ -79,6 +81,7 @@ class Model(ABC): device_type=self.device.type, window_size=self.sliding_window, speculate=self.speculate, + block_size=BLOCK_SIZE, ) @property @@ -92,9 +95,9 @@ class Model(ABC): ) -> Tuple[List[Generation], Optional[B], Tuple[int, int]]: raise NotImplementedError - def warmup(self, batch: B) -> Optional[int]: + def warmup(self, batch: generate_pb2.WarmupRequest) -> Tuple[Optional[int], Optional[int], Optional[int]]: self.generate_token(batch) - return None + return None, None, None def decode_token( self, diff --git a/backends/gaudi/server/text_generation_server/server.py b/backends/gaudi/server/text_generation_server/server.py index 159e6af1..7e15c0cf 100644 --- a/backends/gaudi/server/text_generation_server/server.py +++ b/backends/gaudi/server/text_generation_server/server.py @@ -102,7 +102,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer): return generate_pb2.FilterBatchResponse(batch=filtered_batch.to_pb()) async def Warmup(self, request, context): - max_supported_total_tokens = self.model.warmup(request) + max_supported_total_tokens, max_input_tokens, max_total_tokens = self.model.warmup(request) # W/A for the skip tokenizer path # We need to call make_tokenizer_optional after the warmup, @@ -110,7 +110,9 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer): make_tokenizer_optional(self.model.tokenizer) return generate_pb2.WarmupResponse( - max_supported_total_tokens=max_supported_total_tokens + max_supported_total_tokens=max_supported_total_tokens, + max_input_tokens=max_input_tokens, + max_total_tokens=max_total_tokens, ) async def Prefill(self, request, context): diff --git a/backends/gaudi/tgi-entrypoint.sh b/backends/gaudi/tgi-entrypoint.sh new file mode 100644 index 00000000..ea94dcd9 --- /dev/null +++ b/backends/gaudi/tgi-entrypoint.sh @@ -0,0 +1,5 @@ +#!/bin/bash + +ldconfig 2>/dev/null || echo 'unable to refresh ld cache, not a big deal in most cases' + +text-generation-launcher $@ diff --git a/launcher/src/env_runtime.rs b/launcher/src/env_runtime.rs index 08fb301c..58080bd1 100644 --- a/launcher/src/env_runtime.rs +++ b/launcher/src/env_runtime.rs @@ -8,22 +8,29 @@ pub(crate) struct Env { docker_label: &'static str, nvidia_env: String, xpu_env: String, + hpu_env: String, } impl Env { pub fn new() -> Self { let nvidia_env = nvidia_smi(); let xpu_env = xpu_smi(); + let hpu_env = hl_smi(); Self { nvidia_env: nvidia_env.unwrap_or("N/A".to_string()), xpu_env: xpu_env.unwrap_or("N/A".to_string()), + hpu_env: hpu_env.unwrap_or("N/A".to_string()), cargo_target: env!("VERGEN_CARGO_TARGET_TRIPLE"), cargo_version: env!("VERGEN_RUSTC_SEMVER"), git_sha: option_env!("VERGEN_GIT_SHA").unwrap_or("N/A"), docker_label: option_env!("DOCKER_LABEL").unwrap_or("N/A"), } } + + pub fn is_hpu_device(&self) -> bool { + self.hpu_env != "N/A" + } } impl fmt::Display for Env { @@ -35,7 +42,8 @@ impl fmt::Display for Env { writeln!(f, "Commit sha: {}", self.git_sha)?; writeln!(f, "Docker label: {}", self.docker_label)?; 
writeln!(f, "nvidia-smi:\n{}", self.nvidia_env)?; - write!(f, "xpu-smi:\n{}", self.xpu_env)?; + writeln!(f, "xpu-smi:\n{}", self.xpu_env)?; + writeln!(f, "hpu-smi:\n{}", self.hpu_env)?; Ok(()) } @@ -54,3 +62,10 @@ fn xpu_smi() -> Option { let output = xpu_smi.replace('\n', "\n "); Some(output.trim().to_string()) } + +fn hl_smi() -> Option { + let output = Command::new("hl-smi").output().ok()?; + let hl_smi = String::from_utf8(output.stdout).ok()?; + let output = hl_smi.replace('\n', "\n "); + Some(output.trim().to_string()) +} \ No newline at end of file diff --git a/launcher/src/main.rs b/launcher/src/main.rs index fbbe8a2d..d5d1ba83 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -1531,6 +1531,11 @@ fn spawn_shards( ) -> Result<(), LauncherError> { // Start shard processes for rank in 0..num_shard { + if rank != 0 && env_runtime::Env::new().is_hpu_device() { + tracing::info!("Running on HPU, the launcher will not do any sharding as actual sharding is done in the server"); + break; + } + let model_id = args.model_id.clone(); let revision = args.revision.clone(); let uds_path = args.shard_uds_path.clone(); @@ -1605,6 +1610,10 @@ fn spawn_shards( if shard_ready == num_shard { break; } + if env_runtime::Env::new().is_hpu_device() { + tracing::info!("HPU detected, shard is ready"); + break; + } } Err(TryRecvError::Empty) => { sleep(Duration::from_millis(100));