Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-04-24 08:22:07 +00:00
Gaudi: Use exponential growth to replace BATCH_BUCKET_SIZE
Signed-off-by: yuanwu <yuan.wu@intel.com>
This commit is contained in: commit 2d2c56361d (parent e497bc09f6)
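In short, batch-size bucketing in the Gaudi backend no longer pads up to the next multiple of BATCH_BUCKET_SIZE (decode) or PREFILL_BATCH_BUCKET_SIZE (prefill); it pads up to the next power of a single BATCH_SIZE_EXPONENT_BASE (default 2). A minimal sketch contrasting the two schemes, assuming the default values shown in the diff; the function names below are illustrative, only round_up_batch exists in the new code:

import math

BATCH_BUCKET_SIZE = 8          # old default, removed by this commit
BATCH_SIZE_EXPONENT_BASE = 2   # new default introduced by this commit

def old_round_up(number, k=BATCH_BUCKET_SIZE):
    # Old scheme: pad the batch to the next multiple of the bucket size.
    return (number + k - 1) // k * k

def new_round_up_batch(number):
    # New scheme: pad the batch to the next power of the exponent base.
    return BATCH_SIZE_EXPONENT_BASE ** math.ceil(
        math.log(number, BATCH_SIZE_EXPONENT_BASE)
    )

for bs in (1, 3, 5, 9, 17, 33):
    print(bs, old_round_up(bs), new_round_up_batch(bs))
# bs=3  -> old 8,  new 4
# bs=9  -> old 16, new 16
# bs=33 -> old 40, new 64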
@@ -122,5 +122,5 @@ ENV OMPI_MCA_btl_vader_single_copy_mechanism NONE
 COPY backends/gaudi/tgi-entrypoint.sh /tgi-entrypoint.sh
 RUN chmod +x /tgi-entrypoint.sh
 
-ENTRYPOINT ["/tgi-entrypoint.sh"]
-CMD ["--json-output"]
+#ENTRYPOINT ["/tgi-entrypoint.sh"]
+#CMD ["--json-output"]
@@ -8,7 +8,7 @@ PYTORCH_VERSION := 2.6.0
 .PHONY: image run-local-dev-container install-dependencies install-server install-router install-launcher local-dev-install
 
 image:
-	docker build -t tgi-gaudi -f ${root_dir}/Dockerfile_gaudi ${root_dir} --build-arg HABANA_VERSION=$(HABANA_VERSION) --build-arg PYTORCH_VERSION=$(PYTORCH_VERSION)
+	docker build -t tgi-gaudi -f ${root_dir}/Dockerfile_gaudi ${root_dir} --build-arg HABANA_VERSION=$(HABANA_VERSION) --build-arg PYTORCH_VERSION=$(PYTORCH_VERSION) --build-arg http_proxy=${http_proxy} --build-arg https_proxy=${https_proxy} --build-arg no_proxy=${no_proxy}
 
 run-local-dev-container:
 	docker run -it \
@@ -57,8 +57,7 @@ MAX_TOTAL_TOKENS = int(os.environ.get("MAX_TOTAL_TOKENS", 2048))
 PAD_SEQUENCE_TO_MULTIPLE_OF = int(os.environ.get("PAD_SEQUENCE_TO_MULTIPLE_OF", 256))
 CHUNK_SIZES = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048]
 LAZY_MODE = int(os.environ.get("PT_HPU_LAZY_MODE", 1))
-BATCH_BUCKET_SIZE = int(os.environ.get("BATCH_BUCKET_SIZE", 8))
-PREFILL_BATCH_BUCKET_SIZE = int(os.environ.get("PREFILL_BATCH_BUCKET_SIZE", 2))
+BATCH_SIZE_EXPONENT_BASE = int(os.environ.get("BATCH_SIZE_EXPONENT_BASE", 2))
 MAX_BATCH_SIZE = (
     int(os.environ.get("MAX_BATCH_SIZE"))
     if os.environ.get("MAX_BATCH_SIZE") is not None
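BATCH_SIZE_EXPONENT_BASE is now the only bucketing knob; BATCH_BUCKET_SIZE and PREFILL_BATCH_BUCKET_SIZE are removed. A rough sketch of what a non-default base would do (the override value of 4 is illustrative, not taken from the diff):

import math
import os

os.environ["BATCH_SIZE_EXPONENT_BASE"] = "4"             # hypothetical override
base = int(os.environ.get("BATCH_SIZE_EXPONENT_BASE", 2))
print([base ** math.ceil(math.log(n, base)) for n in (3, 5, 9, 17)])
# [4, 16, 16, 64] -- coarser buckets (fewer graph shapes) than with base 2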
@@ -74,10 +73,16 @@ def torch_compile_for_eager(func):
     )
 
 
-def round_up(number, k):
+def round_up_seq(number, k):
     return (number + k - 1) // k * k
 
 
+def round_up_batch(number):
+    return BATCH_SIZE_EXPONENT_BASE ** (
+        math.ceil(math.log(number, BATCH_SIZE_EXPONENT_BASE))
+    )
+
+
 def to_tensor_indices(indices, device):
     return torch.tensor(indices, dtype=torch.long, device=device)
 
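A few hedged examples of how the renamed round_up_seq and the new round_up_batch behave, assuming the two helpers above are in scope and BATCH_SIZE_EXPONENT_BASE keeps its default of 2 (these calls are illustrative, not part of the diff):

# Sequence padding keeps the old multiple-of-k behaviour; only the name changed.
assert round_up_seq(257, 256) == 512
# Batch padding now rounds up to a power of BATCH_SIZE_EXPONENT_BASE;
# exact powers map to themselves, and number is assumed to be >= 1
# (math.log(0, base) would raise ValueError).
assert round_up_batch(5) == 8
assert round_up_batch(9) == 16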
@@ -399,7 +404,7 @@ class CausalLMBatch(Batch):
 
         total_requests = sum(len(b) for b in batches)
         new_bs = total_requests
-        new_bs = round_up(total_requests, BATCH_BUCKET_SIZE)
+        new_bs = round_up_batch(total_requests)
 
         batch_id = batches[0].batch_id
         device = batches[0].input_ids.device
@@ -540,7 +545,7 @@ class CausalLMBatch(Batch):
         # TODO: by tokenizing all inputs at once we loose information on actual input lengths
         # this means that we cannot shift inputs to the left after a long input sequence
         # was filtered out
-        new_bs = round_up(len(requests), PREFILL_BATCH_BUCKET_SIZE)
+        new_bs = round_up_batch(len(requests))
         missing_inputs = new_bs - len(inputs)
         dummy_inputs = ["?"] * missing_inputs
         parameters = [r.parameters for r in pb.requests]
@@ -572,7 +577,7 @@ class CausalLMBatch(Batch):
         assert (
             PAD_SEQUENCE_TO_MULTIPLE_OF <= max_input_length
         ), "PAD_SEQUENCE_TO_MULTIPLE_OF cannot be higher than max_input_length"
-        rounded_seq_len = round_up(input_len + 1, PAD_SEQUENCE_TO_MULTIPLE_OF)
+        rounded_seq_len = round_up_seq(input_len + 1, PAD_SEQUENCE_TO_MULTIPLE_OF)
         if rounded_seq_len <= max_input_length:
             bucket_size = rounded_seq_len - 1
         else:
@@ -1068,10 +1073,10 @@ class CausalLM(Model):
         if (
             self.enable_hpu_graph
             and self.limit_hpu_graph
-            and round_up(batch.batch_size, BATCH_BUCKET_SIZE) != self.prev_bs
+            and round_up_batch(batch.batch_size) != self.prev_bs
         ):
             self.model.clear_cache()
-            self.prev_bs = round_up(batch.batch_size, BATCH_BUCKET_SIZE)
+            self.prev_bs = round_up_batch(batch.batch_size)
         dbg_trace(
             scenario,
             f"bs:{batch.batch_size} num_reqs:{len(batch.requests)} seq_len:{batch.seq_length} padding:{batch.right_padding}",
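The HPU-graph bookkeeping above keeps its structure; only the bucketing function changed, so the graph cache is now cleared when the padded batch size crosses a power-of-the-base boundary rather than a multiple-of-BATCH_BUCKET_SIZE boundary. A rough, self-contained illustration of that condition (the batch-size sequence is made up, and limit_hpu_graph is assumed enabled):

import math

BATCH_SIZE_EXPONENT_BASE = 2

def round_up_batch(number):
    return BATCH_SIZE_EXPONENT_BASE ** math.ceil(math.log(number, BATCH_SIZE_EXPONENT_BASE))

prev_bs = None
for bs in (3, 5, 7, 9, 12, 17):        # hypothetical live batch sizes
    bucketed = round_up_batch(bs)
    if bucketed != prev_bs:
        # This is the point where the real code calls self.model.clear_cache().
        print(f"bs={bs}: bucket -> {bucketed}, HPU graph cache cleared")
        prev_bs = bucketed
# Cache is cleared at bs=3, 5, 9 and 17; bs=7 and 12 reuse the existing bucket.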
@@ -1325,15 +1330,14 @@ class CausalLM(Model):
 
         # Warmup prefill batch_size
         max_input_tokens = request.max_input_tokens
+        max_exp = math.ceil(math.log(max_prefill_batch_size, BATCH_SIZE_EXPONENT_BASE))
         prefill_batch_size_list = [
-            batch
-            for batch in range(
-                PREFILL_BATCH_BUCKET_SIZE,
-                max_prefill_batch_size,
-                PREFILL_BATCH_BUCKET_SIZE,
+            BATCH_SIZE_EXPONENT_BASE**exp
+            for exp in range(
+                0,
+                max_exp + 1,
             )
         ]
-        prefill_batch_size_list.append(max_prefill_batch_size)
         prefill_seqlen_list = [
             seq
             for seq in range(
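With this change the prefill warmup list covers every power of the base up to max_prefill_batch_size (rounding the maximum itself up to a power), instead of every multiple of PREFILL_BATCH_BUCKET_SIZE plus the maximum. A hedged example with an assumed max_prefill_batch_size of 13:

import math

BATCH_SIZE_EXPONENT_BASE = 2
max_prefill_batch_size = 13    # illustrative value, not from the diff

max_exp = math.ceil(math.log(max_prefill_batch_size, BATCH_SIZE_EXPONENT_BASE))
prefill_batch_size_list = [BATCH_SIZE_EXPONENT_BASE**exp for exp in range(0, max_exp + 1)]
print(prefill_batch_size_list)   # [1, 2, 4, 8, 16]
# Old behaviour for comparison: list(range(2, 13, 2)) + [13]
#   -> [2, 4, 6, 8, 10, 12, 13] with the default PREFILL_BATCH_BUCKET_SIZE of 2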
@@ -1370,12 +1374,10 @@ class CausalLM(Model):
         )
 
         max_decode_batch_size = math.floor(MAX_BATCH_TOTAL_TOKENS / MAX_TOTAL_TOKENS)
-        max_decode_batch_size = round_up(max_decode_batch_size, BATCH_BUCKET_SIZE)
+        max_exp = math.ceil(math.log(max_decode_batch_size, BATCH_SIZE_EXPONENT_BASE))
         decode_batch_size_list = [
-            i
-            for i in range(BATCH_BUCKET_SIZE, max_decode_batch_size, BATCH_BUCKET_SIZE)
+            BATCH_SIZE_EXPONENT_BASE**exp for exp in range(0, max_exp + 1)
         ]
-        decode_batch_size_list.append(max_decode_batch_size)
         decode_batch_size_list.sort(reverse=True)
 
         try:
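The decode warmup list is built the same way, with the largest bucket derived from the token budget. A hedged example with illustrative token limits (not values from the diff):

import math

BATCH_SIZE_EXPONENT_BASE = 2
MAX_BATCH_TOTAL_TOKENS = 100000   # illustrative
MAX_TOTAL_TOKENS = 2048           # illustrative

max_decode_batch_size = math.floor(MAX_BATCH_TOTAL_TOKENS / MAX_TOTAL_TOKENS)   # 48
max_exp = math.ceil(math.log(max_decode_batch_size, BATCH_SIZE_EXPONENT_BASE))  # 6
decode_batch_size_list = [BATCH_SIZE_EXPONENT_BASE**exp for exp in range(0, max_exp + 1)]
decode_batch_size_list.sort(reverse=True)
print(decode_batch_size_list)   # [64, 32, 16, 8, 4, 2, 1]
# Note that the largest warmup bucket (64) rounds the 48-request budget up to a power of the base.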