diff --git a/Dockerfile_gaudi b/Dockerfile_gaudi
index 9009f95b..39fdd703 100644
--- a/Dockerfile_gaudi
+++ b/Dockerfile_gaudi
@@ -122,5 +122,5 @@ ENV OMPI_MCA_btl_vader_single_copy_mechanism NONE
 COPY backends/gaudi/tgi-entrypoint.sh /tgi-entrypoint.sh
 RUN chmod +x /tgi-entrypoint.sh
 
-ENTRYPOINT ["/tgi-entrypoint.sh"]
-CMD ["--json-output"]
+#ENTRYPOINT ["/tgi-entrypoint.sh"]
+#CMD ["--json-output"]
diff --git a/backends/gaudi/Makefile b/backends/gaudi/Makefile
index 6e38c19e..027bb8b2 100644
--- a/backends/gaudi/Makefile
+++ b/backends/gaudi/Makefile
@@ -8,7 +8,7 @@ PYTORCH_VERSION := 2.6.0
 .PHONY: image run-local-dev-container install-dependencies install-server install-router install-launcher local-dev-install
 
 image:
-	docker build -t tgi-gaudi -f ${root_dir}/Dockerfile_gaudi ${root_dir} --build-arg HABANA_VERSION=$(HABANA_VERSION) --build-arg PYTORCH_VERSION=$(PYTORCH_VERSION)
+	docker build -t tgi-gaudi -f ${root_dir}/Dockerfile_gaudi ${root_dir} --build-arg HABANA_VERSION=$(HABANA_VERSION) --build-arg PYTORCH_VERSION=$(PYTORCH_VERSION) --build-arg http_proxy=${http_proxy} --build-arg https_proxy=${https_proxy} --build-arg no_proxy=${no_proxy}
 
 run-local-dev-container:
 	docker run -it \
diff --git a/backends/gaudi/server/text_generation_server/models/causal_lm.py b/backends/gaudi/server/text_generation_server/models/causal_lm.py
index 776c109f..c1ce3335 100644
--- a/backends/gaudi/server/text_generation_server/models/causal_lm.py
+++ b/backends/gaudi/server/text_generation_server/models/causal_lm.py
@@ -57,8 +57,7 @@ MAX_TOTAL_TOKENS = int(os.environ.get("MAX_TOTAL_TOKENS", 2048))
 PAD_SEQUENCE_TO_MULTIPLE_OF = int(os.environ.get("PAD_SEQUENCE_TO_MULTIPLE_OF", 256))
 CHUNK_SIZES = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048]
 LAZY_MODE = int(os.environ.get("PT_HPU_LAZY_MODE", 1))
-BATCH_BUCKET_SIZE = int(os.environ.get("BATCH_BUCKET_SIZE", 8))
-PREFILL_BATCH_BUCKET_SIZE = int(os.environ.get("PREFILL_BATCH_BUCKET_SIZE", 2))
+BATCH_SIZE_EXPONENT_BASE = int(os.environ.get("BATCH_SIZE_EXPONENT_BASE", 2))
 MAX_BATCH_SIZE = (
     int(os.environ.get("MAX_BATCH_SIZE"))
     if os.environ.get("MAX_BATCH_SIZE") is not None
@@ -74,10 +73,16 @@ def torch_compile_for_eager(func):
     )
 
 
-def round_up(number, k):
+def round_up_seq(number, k):
     return (number + k - 1) // k * k
 
 
+def round_up_batch(number):
+    return BATCH_SIZE_EXPONENT_BASE ** (
+        math.ceil(math.log(number, BATCH_SIZE_EXPONENT_BASE))
+    )
+
+
 def to_tensor_indices(indices, device):
     return torch.tensor(indices, dtype=torch.long, device=device)
 
@@ -399,7 +404,7 @@ class CausalLMBatch(Batch):
 
         total_requests = sum(len(b) for b in batches)
         new_bs = total_requests
-        new_bs = round_up(total_requests, BATCH_BUCKET_SIZE)
+        new_bs = round_up_batch(total_requests)
 
         batch_id = batches[0].batch_id
         device = batches[0].input_ids.device
@@ -540,7 +545,7 @@ class CausalLMBatch(Batch):
         # TODO: by tokenizing all inputs at once we loose information on actual input lengths
         # this means that we cannot shift inputs to the left after a long input sequence
         # was filtered out
-        new_bs = round_up(len(requests), PREFILL_BATCH_BUCKET_SIZE)
+        new_bs = round_up_batch(len(requests))
         missing_inputs = new_bs - len(inputs)
         dummy_inputs = ["?"] * missing_inputs
         parameters = [r.parameters for r in pb.requests]
@@ -572,7 +577,7 @@ class CausalLMBatch(Batch):
         assert (
             PAD_SEQUENCE_TO_MULTIPLE_OF <= max_input_length
         ), "PAD_SEQUENCE_TO_MULTIPLE_OF cannot be higher than max_input_length"
-        rounded_seq_len = round_up(input_len + 1, PAD_SEQUENCE_TO_MULTIPLE_OF)
+        rounded_seq_len = round_up_seq(input_len + 1, PAD_SEQUENCE_TO_MULTIPLE_OF)
         if rounded_seq_len <= max_input_length:
             bucket_size = rounded_seq_len - 1
         else:
@@ -1068,10 +1073,10 @@ class CausalLM(Model):
         if (
             self.enable_hpu_graph
             and self.limit_hpu_graph
-            and round_up(batch.batch_size, BATCH_BUCKET_SIZE) != self.prev_bs
+            and round_up_batch(batch.batch_size) != self.prev_bs
         ):
             self.model.clear_cache()
-            self.prev_bs = round_up(batch.batch_size, BATCH_BUCKET_SIZE)
+            self.prev_bs = round_up_batch(batch.batch_size)
         dbg_trace(
             scenario,
             f"bs:{batch.batch_size} num_reqs:{len(batch.requests)} seq_len:{batch.seq_length} padding:{batch.right_padding}",
@@ -1325,15 +1330,14 @@ class CausalLM(Model):
 
         # Warmup prefill batch_size
        max_input_tokens = request.max_input_tokens
+        max_exp = math.ceil(math.log(max_prefill_batch_size, BATCH_SIZE_EXPONENT_BASE))
         prefill_batch_size_list = [
-            batch
-            for batch in range(
-                PREFILL_BATCH_BUCKET_SIZE,
-                max_prefill_batch_size,
-                PREFILL_BATCH_BUCKET_SIZE,
+            BATCH_SIZE_EXPONENT_BASE**exp
+            for exp in range(
+                0,
+                max_exp + 1,
             )
         ]
-        prefill_batch_size_list.append(max_prefill_batch_size)
         prefill_seqlen_list = [
             seq
             for seq in range(
@@ -1370,12 +1374,10 @@ class CausalLM(Model):
         )
 
         max_decode_batch_size = math.floor(MAX_BATCH_TOTAL_TOKENS / MAX_TOTAL_TOKENS)
-        max_decode_batch_size = round_up(max_decode_batch_size, BATCH_BUCKET_SIZE)
+        max_exp = math.ceil(math.log(max_decode_batch_size, BATCH_SIZE_EXPONENT_BASE))
         decode_batch_size_list = [
-            i
-            for i in range(BATCH_BUCKET_SIZE, max_decode_batch_size, BATCH_BUCKET_SIZE)
+            BATCH_SIZE_EXPONENT_BASE**exp for exp in range(0, max_exp + 1)
         ]
-        decode_batch_size_list.append(max_decode_batch_size)
         decode_batch_size_list.sort(reverse=True)
 
         try:
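
Note: a minimal, standalone sketch (not part of the patch) of the exponential bucketing introduced above, assuming the default BATCH_SIZE_EXPONENT_BASE of 2; round_up_batch here simply mirrors the helper added in causal_lm.py.

import math

BATCH_SIZE_EXPONENT_BASE = 2  # default in the diff when the env var is unset

def round_up_batch(number):
    # Round a batch size up to the next power of BATCH_SIZE_EXPONENT_BASE,
    # e.g. 3 -> 4, 5 -> 8, 9 -> 16 with base 2.
    return BATCH_SIZE_EXPONENT_BASE ** math.ceil(
        math.log(number, BATCH_SIZE_EXPONENT_BASE)
    )

print([round_up_batch(n) for n in (3, 5, 9, 17)])  # [4, 8, 16, 32]

With this scheme the prefill and decode warmup lists become powers of the base, [1, 2, 4, ..., BATCH_SIZE_EXPONENT_BASE**max_exp], rather than fixed multiples of the old BATCH_BUCKET_SIZE / PREFILL_BATCH_BUCKET_SIZE values.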