mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-24 08:22:07 +00:00
Gaudi: Use exponential growth to replace BATCH_BUCKET_SIZE
Signed-off-by: yuanwu <yuan.wu@intel.com>
This commit is contained in:
parent
e497bc09f6
commit
2d2c56361d
@ -122,5 +122,5 @@ ENV OMPI_MCA_btl_vader_single_copy_mechanism NONE
|
||||
COPY backends/gaudi/tgi-entrypoint.sh /tgi-entrypoint.sh
|
||||
RUN chmod +x /tgi-entrypoint.sh
|
||||
|
||||
ENTRYPOINT ["/tgi-entrypoint.sh"]
|
||||
CMD ["--json-output"]
|
||||
#ENTRYPOINT ["/tgi-entrypoint.sh"]
|
||||
#CMD ["--json-output"]
|
||||
|
@ -8,7 +8,7 @@ PYTORCH_VERSION := 2.6.0
|
||||
.PHONY: image run-local-dev-container install-dependencies install-server install-router install-launcher local-dev-install
|
||||
|
||||
image:
|
||||
docker build -t tgi-gaudi -f ${root_dir}/Dockerfile_gaudi ${root_dir} --build-arg HABANA_VERSION=$(HABANA_VERSION) --build-arg PYTORCH_VERSION=$(PYTORCH_VERSION)
|
||||
docker build -t tgi-gaudi -f ${root_dir}/Dockerfile_gaudi ${root_dir} --build-arg HABANA_VERSION=$(HABANA_VERSION) --build-arg PYTORCH_VERSION=$(PYTORCH_VERSION) --build-arg http_proxy=${http_proxy} --build-arg https_proxy=${https_proxy} --build-arg no_proxy=${no_proxy}
|
||||
|
||||
run-local-dev-container:
|
||||
docker run -it \
|
||||
|
@ -57,8 +57,7 @@ MAX_TOTAL_TOKENS = int(os.environ.get("MAX_TOTAL_TOKENS", 2048))
|
||||
PAD_SEQUENCE_TO_MULTIPLE_OF = int(os.environ.get("PAD_SEQUENCE_TO_MULTIPLE_OF", 256))
|
||||
CHUNK_SIZES = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048]
|
||||
LAZY_MODE = int(os.environ.get("PT_HPU_LAZY_MODE", 1))
|
||||
BATCH_BUCKET_SIZE = int(os.environ.get("BATCH_BUCKET_SIZE", 8))
|
||||
PREFILL_BATCH_BUCKET_SIZE = int(os.environ.get("PREFILL_BATCH_BUCKET_SIZE", 2))
|
||||
BATCH_SIZE_EXPONENT_BASE = int(os.environ.get("BATCH_SIZE_EXPONENT_BASE", 2))
|
||||
MAX_BATCH_SIZE = (
|
||||
int(os.environ.get("MAX_BATCH_SIZE"))
|
||||
if os.environ.get("MAX_BATCH_SIZE") is not None
|
||||
@ -74,10 +73,16 @@ def torch_compile_for_eager(func):
|
||||
)
|
||||
|
||||
|
||||
def round_up(number, k):
|
||||
def round_up_seq(number, k):
|
||||
return (number + k - 1) // k * k
|
||||
|
||||
|
||||
def round_up_batch(number):
|
||||
return BATCH_SIZE_EXPONENT_BASE ** (
|
||||
math.ceil(math.log(number, BATCH_SIZE_EXPONENT_BASE))
|
||||
)
|
||||
|
||||
|
||||
def to_tensor_indices(indices, device):
|
||||
return torch.tensor(indices, dtype=torch.long, device=device)
|
||||
|
||||
@ -399,7 +404,7 @@ class CausalLMBatch(Batch):
|
||||
|
||||
total_requests = sum(len(b) for b in batches)
|
||||
new_bs = total_requests
|
||||
new_bs = round_up(total_requests, BATCH_BUCKET_SIZE)
|
||||
new_bs = round_up_batch(total_requests)
|
||||
|
||||
batch_id = batches[0].batch_id
|
||||
device = batches[0].input_ids.device
|
||||
@ -540,7 +545,7 @@ class CausalLMBatch(Batch):
|
||||
# TODO: by tokenizing all inputs at once we loose information on actual input lengths
|
||||
# this means that we cannot shift inputs to the left after a long input sequence
|
||||
# was filtered out
|
||||
new_bs = round_up(len(requests), PREFILL_BATCH_BUCKET_SIZE)
|
||||
new_bs = round_up_batch(len(requests))
|
||||
missing_inputs = new_bs - len(inputs)
|
||||
dummy_inputs = ["?"] * missing_inputs
|
||||
parameters = [r.parameters for r in pb.requests]
|
||||
@ -572,7 +577,7 @@ class CausalLMBatch(Batch):
|
||||
assert (
|
||||
PAD_SEQUENCE_TO_MULTIPLE_OF <= max_input_length
|
||||
), "PAD_SEQUENCE_TO_MULTIPLE_OF cannot be higher than max_input_length"
|
||||
rounded_seq_len = round_up(input_len + 1, PAD_SEQUENCE_TO_MULTIPLE_OF)
|
||||
rounded_seq_len = round_up_seq(input_len + 1, PAD_SEQUENCE_TO_MULTIPLE_OF)
|
||||
if rounded_seq_len <= max_input_length:
|
||||
bucket_size = rounded_seq_len - 1
|
||||
else:
|
||||
@ -1068,10 +1073,10 @@ class CausalLM(Model):
|
||||
if (
|
||||
self.enable_hpu_graph
|
||||
and self.limit_hpu_graph
|
||||
and round_up(batch.batch_size, BATCH_BUCKET_SIZE) != self.prev_bs
|
||||
and round_up_batch(batch.batch_size) != self.prev_bs
|
||||
):
|
||||
self.model.clear_cache()
|
||||
self.prev_bs = round_up(batch.batch_size, BATCH_BUCKET_SIZE)
|
||||
self.prev_bs = round_up_batch(batch.batch_size)
|
||||
dbg_trace(
|
||||
scenario,
|
||||
f"bs:{batch.batch_size} num_reqs:{len(batch.requests)} seq_len:{batch.seq_length} padding:{batch.right_padding}",
|
||||
@ -1325,15 +1330,14 @@ class CausalLM(Model):
|
||||
|
||||
# Warmup prefill batch_size
|
||||
max_input_tokens = request.max_input_tokens
|
||||
max_exp = math.ceil(math.log(max_prefill_batch_size, BATCH_SIZE_EXPONENT_BASE))
|
||||
prefill_batch_size_list = [
|
||||
batch
|
||||
for batch in range(
|
||||
PREFILL_BATCH_BUCKET_SIZE,
|
||||
max_prefill_batch_size,
|
||||
PREFILL_BATCH_BUCKET_SIZE,
|
||||
BATCH_SIZE_EXPONENT_BASE**exp
|
||||
for exp in range(
|
||||
0,
|
||||
max_exp + 1,
|
||||
)
|
||||
]
|
||||
prefill_batch_size_list.append(max_prefill_batch_size)
|
||||
prefill_seqlen_list = [
|
||||
seq
|
||||
for seq in range(
|
||||
@ -1370,12 +1374,10 @@ class CausalLM(Model):
|
||||
)
|
||||
|
||||
max_decode_batch_size = math.floor(MAX_BATCH_TOTAL_TOKENS / MAX_TOTAL_TOKENS)
|
||||
max_decode_batch_size = round_up(max_decode_batch_size, BATCH_BUCKET_SIZE)
|
||||
max_exp = math.ceil(math.log(max_decode_batch_size, BATCH_SIZE_EXPONENT_BASE))
|
||||
decode_batch_size_list = [
|
||||
i
|
||||
for i in range(BATCH_BUCKET_SIZE, max_decode_batch_size, BATCH_BUCKET_SIZE)
|
||||
BATCH_SIZE_EXPONENT_BASE**exp for exp in range(0, max_exp + 1)
|
||||
]
|
||||
decode_batch_size_list.append(max_decode_batch_size)
|
||||
decode_batch_size_list.sort(reverse=True)
|
||||
|
||||
try:
|
||||
|
Loading…
Reference in New Issue
Block a user