diff --git a/Dockerfile_gaudi b/Dockerfile_gaudi
index 5bb5b606..54a0bb7c 100644
--- a/Dockerfile_gaudi
+++ b/Dockerfile_gaudi
@@ -61,6 +61,7 @@ ENV ATTENTION=default
 ENV PREFIX_CACHING=0
 ENV PREFILL_CHUNKING=0
 ENV PT_HPU_LAZY_MODE=1
+ENV PT_HPU_WEIGHT_SHARING=0
 
 # Text Generation Inference base env
 ENV HF_HOME=/data \
@@ -96,7 +97,8 @@ RUN cd server && \
     pip install "git+https://github.com/HabanaAI/DeepSpeed.git@${HABANA_VERSION}" && \
     BUILD_CUDA_EXT=0 pip install git+https://github.com/AutoGPTQ/AutoGPTQ.git@097dd04e --no-build-isolation && \
     pip install . --no-cache-dir
-RUN pip install git+https://github.com/sywangyi/vllm-hpu-extension.git
+RUN pip install git+https://github.com/HabanaAI/vllm-hpu-extension.git@a060794
+
 # Install benchmarker
 COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark
 # Install router
diff --git a/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py b/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py
index cafaae23..9a0f789a 100644
--- a/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py
+++ b/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py
@@ -72,7 +72,7 @@ from text_generation_server.utils.import_utils import (
 import vllm_hpu_extension.environment as environment
 import habana_frameworks.torch as htorch
 import itertools
-from vllm_hpu_extension.bucketing import HPUBucketingContext
+from vllm_hpu_extension.bucketing.common import get_bucketing_context
 
 tracer = trace.get_tracer(__name__)
 
@@ -1497,6 +1497,11 @@ class FlashCausalLM(Model):
         max_input_tokens: Optional[int],
         max_total_tokens: Optional[int],
     ):
+        if os.environ.get("MAX_BATCH_SIZE") is None:
+            raise RuntimeError(
+                "MAX_BATCH_SIZE is not set, it should be set in the launcher "
+                "using `--max-batch-size xxx`"
+            )
         # The warmup batch is the biggest batch we could ever receive
         self.kv_cache = []
         empty_cache()
@@ -1566,25 +1571,28 @@ class FlashCausalLM(Model):
             self.device,
         )
         self.max_batch_prefill_tokens = max_input_tokens * len(batch)
-
-        max_num_seqs = int(os.getenv("MAX_BATCH_SIZE", 128))
-        if os.getenv("VLLM_PROMPT_SEQ_BUCKET_MAX") is None:
-            os.environ["VLLM_PROMPT_SEQ_BUCKET_MAX"] = str(max_input_tokens)
-        if os.getenv("VLLM_DECODE_BLOCK_BUCKET_MAX") is None:
-            max_total_blocks = (
-                math.ceil(max_total_tokens / BLOCK_SIZE) * max_num_seqs + 1
-            )
-            os.environ["VLLM_DECODE_BLOCK_BUCKET_MAX"] = str(max_total_blocks)
-
+        max_num_seqs = int(os.getenv("MAX_BATCH_SIZE"))
+        HPUBucketingContext = get_bucketing_context()
+        max_total_tokens_aligned = math.ceil(max_total_tokens / BLOCK_SIZE) * BLOCK_SIZE
         self.bucketing_ctx = HPUBucketingContext(
             max_num_seqs,
-            os.getenv("PREFILL_MAX_BS", 64),  # self.max_num_prefill_seqs, #TODO
+            max_num_seqs,  # self.max_num_prefill_seqs, #TODO
             BLOCK_SIZE,
-            num_blocks * BLOCK_SIZE,
+            max_num_seqs * max_total_tokens_aligned,
             False,
+            self.tokenizer.model_max_length,
+            max_input_tokens,
+            max_total_tokens_aligned,
         )
-        self.bucketing_ctx.num_hpu_blocks = num_blocks
+        max_blocks = max(
+            BLOCK_SIZE, max_num_seqs * max_total_tokens_aligned // BLOCK_SIZE
+        )
+        self.bucketing_ctx.num_hpu_blocks = max_blocks
         if os.getenv("VLLM_SKIP_WARMUP", "false").lower() == "true":
+            self.bucketing_ctx.generate_prompt_buckets()
+            self.bucketing_ctx.generate_decode_buckets(
+                self.bucketing_ctx.num_hpu_blocks
+            )
             logger.info("skip warmup hpu graph, not recommmended")
             del _batch, batch
             return int(num_blocks * BLOCK_SIZE), max_input_tokens, max_total_tokens
@@ -1606,8 +1614,6 @@ class FlashCausalLM(Model):
         for i, (batch_size, seq_len) in enumerate(
             reversed(self.bucketing_ctx.prompt_buckets)
         ):
-            if batch_size * seq_len > self.max_batch_prefill_tokens:
-                continue
             log_master(logger.info, f"warmup prefill seq {seq_len} bs {batch_size}")
             for index in range(warmup_times):
                 self.warmup_prefill(seq_len, batch_size, batch)
diff --git a/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py b/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py
index dac65fea..f9186450 100644
--- a/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py
+++ b/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py
@@ -350,8 +350,6 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
         for i, (batch_size, seq_len) in enumerate(
             reversed(self.bucketing_ctx.prompt_buckets)
         ):
-            if batch_size * seq_len > self.max_batch_prefill_tokens:
-                continue
             log_master(logger.info, f"warmup prefill seq {seq_len} bs {batch_size}")
             for index in range(warmup_times):
                 self.warmup_prefill(seq_len, batch_size, batch)