Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-04-22 15:32:08 +00:00
Fix the warmup issue of llama2-7B.
Signed-off-by: yuanwu <yuan.wu@intel.com>
This commit is contained in:
parent c6f023a06b
commit c922ef9534
@@ -1184,6 +1184,7 @@ class CausalLM(Model):
         MAX_TOTAL_TOKENS = request.max_total_tokens
         MAX_BATCH_TOTAL_TOKENS = request.max_batch_total_tokens
         batch = self.batch_type.from_pb(request.batch, self.tokenizer, self.dtype, self.device)
+        max_prefill_batch_size = batch.input_ids.shape[0]
         try:
             # max prefill batch size warmup
             _, prefill_batch, _ = self.generate_token([batch])
@@ -1192,9 +1193,9 @@ class CausalLM(Model):
                 f"Not enough memory to handle {len(batch.input_ids)} prefill tokens. "
                 f"You need to decrease `--max-batch-prefill-tokens`"
             )

         del prefill_batch

-        max_prefill_batch_size = batch.input_ids.shape[0]
         # Warmup prefill batch_size
         max_input_length = request.max_input_length
         prefill_batch_size_list = [batch for batch in range(BATCH_BUCKET_SIZE, max_prefill_batch_size, BATCH_BUCKET_SIZE)]
@@ -1221,7 +1222,7 @@ class CausalLM(Model):
         prefill_batch_size_list.sort()
         mem_stats = get_hpu_memory_stats(self.device)
         logger.info(
-            f"\nFollowing prefill and decode warmup successfully.\n"
+            f"\nFollowing prefill warmup successfully.\n"
             f"Prefill batch size list:{prefill_batch_size_list}\n"
             f"Prefill sequence length list:{prefill_seqlen_list}\n"
             f"Memory stats: {mem_stats} "
@@ -1235,7 +1236,6 @@ class CausalLM(Model):
         decode_batch_size_list.sort(reverse=True)

         try:
-            for i in range(2):
             for batch_size in decode_batch_size_list:
                 batches= []
                 iters = math.floor(batch_size/max_prefill_batch_size)
@@ -1250,6 +1250,7 @@ class CausalLM(Model):
                     batches.append(prefill_batch)

                 _, decode_batch, _ = self.generate_token(batches)
+                _, decode_batch, _ = self.generate_token([decode_batch])
                 del decode_batch
                 batches.clear()

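The substance of the fix is in the last two hunks: the outer `for i in range(2):` pass is dropped, and after the bucketed prefill batches are combined by `generate_token`, one additional `generate_token` call is made on the combined batch, so a decode step is run for each bucketed decode batch size during warmup. Below is a minimal, self-contained sketch of that control flow; StubModel, StubBatch, and new_prefill_batch are placeholders invented for the example (the real prefill-batch construction sits between those two hunks and is not part of this diff), so this illustrates the shape of the loop rather than the actual file contents.

import math


class StubBatch:
    """Placeholder for the real batch object; only tracks a batch size."""

    def __init__(self, size):
        self.size = size


class StubModel:
    """Placeholder for CausalLM; generate_token here only merges batch sizes."""

    def generate_token(self, batches):
        # The real method runs a prefill/decode pass on the device and returns
        # (generations, batch, timings); here we just return a merged stub batch.
        merged = StubBatch(sum(b.size for b in batches))
        return None, merged, None

    def new_prefill_batch(self, size):
        # Placeholder: the real warmup builds prefill batches from the warmup
        # request in code that is not shown in this diff.
        return StubBatch(size)

    def warmup_decode(self, decode_batch_size_list, max_prefill_batch_size):
        # Largest bucket first, as in the patched code.
        decode_batch_size_list.sort(reverse=True)
        for batch_size in decode_batch_size_list:
            batches = []
            # Build enough max-size prefill batches to cover this decode batch size
            # (the inner loop is assumed; only the iters computation appears in the diff).
            iters = math.floor(batch_size / max_prefill_batch_size)
            for _ in range(iters):
                _, prefill_batch, _ = self.generate_token(
                    [self.new_prefill_batch(max_prefill_batch_size)]
                )
                batches.append(prefill_batch)
            # Combine the prefill batches into one decode batch ...
            _, decode_batch, _ = self.generate_token(batches)
            # ... and, as added in this commit, run one more decode step on the
            # combined batch at this bucketed batch size.
            _, decode_batch, _ = self.generate_token([decode_batch])
            del decode_batch
            batches.clear()


if __name__ == "__main__":
    StubModel().warmup_decode([8, 16, 4], max_prefill_batch_size=4)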