Gaudi: Use exponential growth to replace BATCH_BUCKET_SIZE

Signed-off-by: yuanwu <yuan.wu@intel.com>
Author: yuanwu
Date:   2025-03-21 05:46:08 +00:00
parent e497bc09f6
commit 2d2c56361d
3 changed files with 23 additions and 21 deletions


@@ -122,5 +122,5 @@ ENV OMPI_MCA_btl_vader_single_copy_mechanism NONE
 COPY backends/gaudi/tgi-entrypoint.sh /tgi-entrypoint.sh
 RUN chmod +x /tgi-entrypoint.sh
 
-ENTRYPOINT ["/tgi-entrypoint.sh"]
-CMD ["--json-output"]
+#ENTRYPOINT ["/tgi-entrypoint.sh"]
+#CMD ["--json-output"]


@@ -8,7 +8,7 @@ PYTORCH_VERSION := 2.6.0
 .PHONY: image run-local-dev-container install-dependencies install-server install-router install-launcher local-dev-install
 
 image:
-	docker build -t tgi-gaudi -f ${root_dir}/Dockerfile_gaudi ${root_dir} --build-arg HABANA_VERSION=$(HABANA_VERSION) --build-arg PYTORCH_VERSION=$(PYTORCH_VERSION)
+	docker build -t tgi-gaudi -f ${root_dir}/Dockerfile_gaudi ${root_dir} --build-arg HABANA_VERSION=$(HABANA_VERSION) --build-arg PYTORCH_VERSION=$(PYTORCH_VERSION) --build-arg http_proxy=${http_proxy} --build-arg https_proxy=${https_proxy} --build-arg no_proxy=${no_proxy}
 
 run-local-dev-container:
 	docker run -it \


@@ -57,8 +57,7 @@ MAX_TOTAL_TOKENS = int(os.environ.get("MAX_TOTAL_TOKENS", 2048))
 PAD_SEQUENCE_TO_MULTIPLE_OF = int(os.environ.get("PAD_SEQUENCE_TO_MULTIPLE_OF", 256))
 CHUNK_SIZES = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048]
 LAZY_MODE = int(os.environ.get("PT_HPU_LAZY_MODE", 1))
-BATCH_BUCKET_SIZE = int(os.environ.get("BATCH_BUCKET_SIZE", 8))
-PREFILL_BATCH_BUCKET_SIZE = int(os.environ.get("PREFILL_BATCH_BUCKET_SIZE", 2))
+BATCH_SIZE_EXPONENT_BASE = int(os.environ.get("BATCH_SIZE_EXPONENT_BASE", 2))
 MAX_BATCH_SIZE = (
     int(os.environ.get("MAX_BATCH_SIZE"))
     if os.environ.get("MAX_BATCH_SIZE") is not None
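
Illustration (not part of the diff): BATCH_SIZE_EXPONENT_BASE replaces both old bucket-size knobs, and it sets the growth factor of the batch buckets rather than a fixed step. A minimal sketch of the effect of the base, using a hypothetical round_up_to_power helper that mirrors the round_up_batch function added further down:

import math

def round_up_to_power(number, base):
    # next power of `base` that is >= number (mirrors round_up_batch below)
    return base ** math.ceil(math.log(number, base))

# BATCH_SIZE_EXPONENT_BASE=2 (the default above) gives buckets 1, 2, 4, 8, 16, ...
# BATCH_SIZE_EXPONENT_BASE=4 gives coarser buckets 1, 4, 16, 64, ...
print(round_up_to_power(9, 2), round_up_to_power(9, 4))    # 16 16
print(round_up_to_power(17, 2), round_up_to_power(17, 4))  # 32 64
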
@@ -74,10 +73,16 @@ def torch_compile_for_eager(func):
     )
 
 
-def round_up(number, k):
+def round_up_seq(number, k):
     return (number + k - 1) // k * k
 
 
+def round_up_batch(number):
+    return BATCH_SIZE_EXPONENT_BASE ** (
+        math.ceil(math.log(number, BATCH_SIZE_EXPONENT_BASE))
+    )
+
+
 def to_tensor_indices(indices, device):
     return torch.tensor(indices, dtype=torch.long, device=device)
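
Illustration (not part of the diff): round_up_seq keeps the old next-multiple behaviour for sequence lengths, while round_up_batch snaps batch sizes to the next power of BATCH_SIZE_EXPONENT_BASE. A quick comparison with the default base of 2 (the old code rounded batch sizes up to multiples of BATCH_BUCKET_SIZE=8):

import math

BATCH_SIZE_EXPONENT_BASE = 2  # default from the hunk above

def round_up_seq(number, k):
    return (number + k - 1) // k * k

def round_up_batch(number):
    return BATCH_SIZE_EXPONENT_BASE ** (
        math.ceil(math.log(number, BATCH_SIZE_EXPONENT_BASE))
    )

for bs in (1, 3, 9, 33):
    # old bucket: next multiple of 8 (same arithmetic as round_up_seq)
    # new bucket: next power of two
    print(bs, round_up_seq(bs, 8), round_up_batch(bs))
# 1 -> 8 vs 1, 3 -> 8 vs 4, 9 -> 16 vs 16, 33 -> 40 vs 64
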
@@ -399,7 +404,7 @@ class CausalLMBatch(Batch):
         total_requests = sum(len(b) for b in batches)
         new_bs = total_requests
-        new_bs = round_up(total_requests, BATCH_BUCKET_SIZE)
+        new_bs = round_up_batch(total_requests)
 
         batch_id = batches[0].batch_id
         device = batches[0].input_ids.device
@@ -540,7 +545,7 @@ class CausalLMBatch(Batch):
         # TODO: by tokenizing all inputs at once we loose information on actual input lengths
         # this means that we cannot shift inputs to the left after a long input sequence
         # was filtered out
-        new_bs = round_up(len(requests), PREFILL_BATCH_BUCKET_SIZE)
+        new_bs = round_up_batch(len(requests))
         missing_inputs = new_bs - len(inputs)
         dummy_inputs = ["?"] * missing_inputs
         parameters = [r.parameters for r in pb.requests]
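
Illustration (not part of the diff): a prefill batch is padded with dummy "?" prompts up to the bucket size, which is now the next power of two instead of the next multiple of PREFILL_BATCH_BUCKET_SIZE (default 2). A minimal sketch of that padding step with five hypothetical requests:

import math

def round_up_batch(number, base=2):
    # next power of `base` >= number, as in the diff
    return base ** math.ceil(math.log(number, base))

inputs = [f"prompt {i}" for i in range(5)]   # 5 real prompts

new_bs = round_up_batch(len(inputs))         # 8 (old code: round_up(5, 2) == 6)
missing_inputs = new_bs - len(inputs)        # 3
dummy_inputs = ["?"] * missing_inputs        # filler prompts for the unused slots
print(new_bs, inputs + dummy_inputs)
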
@@ -572,7 +577,7 @@ class CausalLMBatch(Batch):
             assert (
                 PAD_SEQUENCE_TO_MULTIPLE_OF <= max_input_length
             ), "PAD_SEQUENCE_TO_MULTIPLE_OF cannot be higher than max_input_length"
-            rounded_seq_len = round_up(input_len + 1, PAD_SEQUENCE_TO_MULTIPLE_OF)
+            rounded_seq_len = round_up_seq(input_len + 1, PAD_SEQUENCE_TO_MULTIPLE_OF)
             if rounded_seq_len <= max_input_length:
                 bucket_size = rounded_seq_len - 1
             else:
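
Illustration (not part of the diff): sequence-length bucketing is unchanged apart from the rename to round_up_seq; it still pads to multiples of PAD_SEQUENCE_TO_MULTIPLE_OF. A worked example with a hypothetical 300-token input and a hypothetical 1024-token max_input_length:

def round_up_seq(number, k):
    return (number + k - 1) // k * k

PAD_SEQUENCE_TO_MULTIPLE_OF = 256
max_input_length = 1024   # hypothetical limit for this sketch
input_len = 300

rounded_seq_len = round_up_seq(input_len + 1, PAD_SEQUENCE_TO_MULTIPLE_OF)  # 512
if rounded_seq_len <= max_input_length:
    bucket_size = rounded_seq_len - 1   # 511, mirroring the +1 / -1 bookkeeping above
else:
    bucket_size = max_input_length - 1
print(rounded_seq_len, bucket_size)
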
@@ -1068,10 +1073,10 @@ class CausalLM(Model):
         if (
             self.enable_hpu_graph
             and self.limit_hpu_graph
-            and round_up(batch.batch_size, BATCH_BUCKET_SIZE) != self.prev_bs
+            and round_up_batch(batch.batch_size) != self.prev_bs
         ):
             self.model.clear_cache()
-            self.prev_bs = round_up(batch.batch_size, BATCH_BUCKET_SIZE)
+            self.prev_bs = round_up_batch(batch.batch_size)
         dbg_trace(
             scenario,
             f"bs:{batch.batch_size} num_reqs:{len(batch.requests)} seq_len:{batch.seq_length} padding:{batch.right_padding}",
@@ -1325,15 +1330,14 @@ class CausalLM(Model):
         # Warmup prefill batch_size
         max_input_tokens = request.max_input_tokens
+        max_exp = math.ceil(math.log(max_prefill_batch_size, BATCH_SIZE_EXPONENT_BASE))
         prefill_batch_size_list = [
-            batch
-            for batch in range(
-                PREFILL_BATCH_BUCKET_SIZE,
-                max_prefill_batch_size,
-                PREFILL_BATCH_BUCKET_SIZE,
+            BATCH_SIZE_EXPONENT_BASE**exp
+            for exp in range(
+                0,
+                max_exp + 1,
             )
         ]
-        prefill_batch_size_list.append(max_prefill_batch_size)
 
         prefill_seqlen_list = [
             seq
             for seq in range(
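
Illustration (not part of the diff): the prefill warmup list now enumerates powers of the base from 1 up to the first power that covers max_prefill_batch_size, so the old append of max_prefill_batch_size is no longer needed. For a hypothetical max_prefill_batch_size of 5:

import math

BATCH_SIZE_EXPONENT_BASE = 2
max_prefill_batch_size = 5   # hypothetical value for this sketch

max_exp = math.ceil(math.log(max_prefill_batch_size, BATCH_SIZE_EXPONENT_BASE))  # 3
prefill_batch_size_list = [
    BATCH_SIZE_EXPONENT_BASE**exp for exp in range(0, max_exp + 1)
]
print(prefill_batch_size_list)   # [1, 2, 4, 8]
# old behaviour for comparison: range(2, 5, 2) plus the max itself -> [2, 4, 5]
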
@@ -1370,12 +1374,10 @@
         )
 
         max_decode_batch_size = math.floor(MAX_BATCH_TOTAL_TOKENS / MAX_TOTAL_TOKENS)
-        max_decode_batch_size = round_up(max_decode_batch_size, BATCH_BUCKET_SIZE)
+        max_exp = math.ceil(math.log(max_decode_batch_size, BATCH_SIZE_EXPONENT_BASE))
         decode_batch_size_list = [
-            i
-            for i in range(BATCH_BUCKET_SIZE, max_decode_batch_size, BATCH_BUCKET_SIZE)
+            BATCH_SIZE_EXPONENT_BASE**exp for exp in range(0, max_exp + 1)
         ]
-        decode_batch_size_list.append(max_decode_batch_size)
         decode_batch_size_list.sort(reverse=True)
 
         try:
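
Illustration (not part of the diff): the decode warmup derives its largest batch from the token budget and then enumerates powers of the base, sorted largest first. With a hypothetical MAX_BATCH_TOTAL_TOKENS of 40960 and the default MAX_TOTAL_TOKENS of 2048:

import math

BATCH_SIZE_EXPONENT_BASE = 2
MAX_BATCH_TOTAL_TOKENS = 40960   # hypothetical budget for this sketch
MAX_TOTAL_TOKENS = 2048

max_decode_batch_size = math.floor(MAX_BATCH_TOTAL_TOKENS / MAX_TOTAL_TOKENS)   # 20
max_exp = math.ceil(math.log(max_decode_batch_size, BATCH_SIZE_EXPONENT_BASE))  # 5
decode_batch_size_list = [
    BATCH_SIZE_EXPONENT_BASE**exp for exp in range(0, max_exp + 1)
]
decode_batch_size_list.sort(reverse=True)
print(decode_batch_size_list)    # [32, 16, 8, 4, 2, 1]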