Mirror of https://github.com/huggingface/text-generation-inference.git
Fix the StarCoder warmup issue

Signed-off-by: yuanwu <yuan.wu@intel.com>
parent b83419a769
commit 4586325a34

@@ -61,28 +61,13 @@ LAZY_MODE = int(os.environ.get('PT_HPU_LAZY_MODE', 1))
 BATCH_BUCKET_SIZE = int(os.environ.get('BATCH_BUCKET_SIZE', 8))
 PREFILL_BATCH_BUCKET_SIZE = int(os.environ.get('PREFILL_BATCH_BUCKET_SIZE', 2))
 
-PREFILL_WARMUP_BATCH_SIZE_LIST = []
-PREFILL_WARMUP_SEQLEN_LIST = []
-DECODE_WARMUP_BATCH_SIZE_LIST = []
 
 def torch_compile_for_eager(func):
     if LAZY_MODE == 1:
         return func
     return torch.compile(func, backend="hpu_backend", options={"keep_input_mutations": True})
 
-def round_up(warmup_list: list, num):
-    i = 0
-    if len(warmup_list) == 0:
-        return num
-    for i in warmup_list:
-        if num <= i:
-            break
-    return i
+def round_up(number, k):
+    return (number + k - 1) // k * k
 
 def to_tensor_indices(indices, device):
     return torch.tensor(indices, dtype=torch.long, device=device)
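
Note on this hunk: the removed round_up picked the smallest pre-warmed value out of a warmup list (falling back to the raw value when the list was empty), while the added round_up rounds up to the next multiple of a bucket size. A minimal, standalone sketch of the two behaviors; the sample values are illustrative and not taken from the commit:

def bucket_round_up(number, k):
    # New behavior: next multiple of the bucket size k.
    return (number + k - 1) // k * k

def list_round_up(warmup_list, num):
    # Old behavior: smallest pre-warmed value >= num; clamps to the last entry otherwise.
    if len(warmup_list) == 0:
        return num
    for i in warmup_list:
        if num <= i:
            break
    return i

print(bucket_round_up(5, 8))           # 8
print(list_round_up([4, 8, 16], 5))    # 8
print(list_round_up([4, 8, 16], 20))   # 16 (largest warmed value)
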
@@ -372,14 +357,13 @@ class CausalLMBatch(Batch):
         self.set_tensor_groups(dst_tensors)
 
     @classmethod
-    def recombine(cls, batches: List["CausalLMBatch"], pad_token_id: int, is_warmup: bool = False) -> "CausalLMBatch":
+    def recombine(cls, batches: List["CausalLMBatch"], pad_token_id: int) -> "CausalLMBatch":
         if not all(b.past_key_values is not None for b in batches):
             raise ValueError("KV cache not allocated! Cannot recombine before prefill!")
 
         total_requests = sum(len(b) for b in batches)
         new_bs = total_requests
-        if is_warmup is False:
-            new_bs = round_up(DECODE_WARMUP_BATCH_SIZE_LIST, total_requests)
+        new_bs = round_up(total_requests, BATCH_BUCKET_SIZE)
 
         batch_id = batches[0].batch_id
         device = batches[0].input_ids.device
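
A small standalone sketch of the batch-size bucketing that recombine applies on the new side; BATCH_BUCKET_SIZE mirrors the default from the first hunk, and the per-batch request counts are invented:

BATCH_BUCKET_SIZE = 8

def round_up(number, k):
    return (number + k - 1) // k * k

# Merging three in-flight batches holding 3, 5 and 6 requests:
total_requests = 3 + 5 + 6                            # 14
new_bs = round_up(total_requests, BATCH_BUCKET_SIZE)  # 16
print(new_bs)  # the merged batch is padded out to 16 slots
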
@@ -481,7 +465,6 @@ class CausalLMBatch(Batch):
         tokenizer: PreTrainedTokenizerBase,
         dtype: torch.dtype,
         device: torch.device,
-        is_warmup: bool = False,
     ) -> "CausalLMBatch":
         dbg_trace('FROM_PB', f'num_reqs:{len(pb.requests)}')
         requests = [CausalLMRequest.from_pb(idx, req, tokenizer) for idx, req in enumerate(pb.requests)]
@@ -503,7 +486,7 @@ class CausalLMBatch(Batch):
         # TODO: by tokenizing all inputs at once we loose information on actual input lengths
         # this means that we cannot shift inputs to the left after a long input sequence
         # was filtered out
-        new_bs = round_up(PREFILL_WARMUP_BATCH_SIZE_LIST, len(requests))
+        new_bs = round_up(len(requests), PREFILL_BATCH_BUCKET_SIZE)
         missing_inputs = new_bs - len(inputs)
         dummy_inputs = ["?"] * missing_inputs
         parameters = [r.parameters for r in pb.requests]
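
The context lines above pad the incoming prompts with dummy "?" entries up to the bucketed batch size; a rough standalone illustration assuming the default PREFILL_BATCH_BUCKET_SIZE of 2 and made-up prompts:

PREFILL_BATCH_BUCKET_SIZE = 2

def round_up(number, k):
    return (number + k - 1) // k * k

inputs = ["Hello", "def fib(n):", "SELECT 1"]              # 3 real prompts (invented)
new_bs = round_up(len(inputs), PREFILL_BATCH_BUCKET_SIZE)  # 4
missing_inputs = new_bs - len(inputs)                      # 1
dummy_inputs = ["?"] * missing_inputs
print(new_bs, inputs + dummy_inputs)  # 4 prompts go to the tokenizer
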
@@ -533,7 +516,7 @@ class CausalLMBatch(Batch):
             left_padding = max_input_length - input_len
             if input_len < max_input_length and PAD_SEQUENCE_TO_MULTIPLE_OF != 0:
                 assert PAD_SEQUENCE_TO_MULTIPLE_OF <= max_input_length, "PAD_SEQUENCE_TO_MULTIPLE_OF cannot be higher than max_input_length"
-                rounded_seq_len = round_up(PREFILL_WARMUP_SEQLEN_LIST, input_len + 1)
+                rounded_seq_len = round_up(input_len + 1, PREFILL_BATCH_BUCKET_SIZE)
                 if rounded_seq_len <= max_input_length:
                     bucket_size = rounded_seq_len - 1
                 else:
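
A toy walk-through of the sequence-length rounding on the new side of this hunk, following only the branch where rounded_seq_len fits under max_input_length (all numbers invented; the else branch is cut off by the hunk and is not shown here):

def round_up(number, k):
    return (number + k - 1) // k * k

PREFILL_BATCH_BUCKET_SIZE = 2
input_len = 37                                                        # invented
rounded_seq_len = round_up(input_len + 1, PREFILL_BATCH_BUCKET_SIZE)  # 38
bucket_size = rounded_seq_len - 1                                     # 37
print(rounded_seq_len, bucket_size)
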
@@ -593,8 +576,8 @@ class CausalLMBatch(Batch):
 
     @classmethod
     @tracer.start_as_current_span("concatenate")
-    def concatenate(cls, batches: List["CausalLMBatch"], pad_token_id: int = 0, is_warmup: bool = False) -> "CausalLMBatch":
-        return cls.recombine(batches, pad_token_id, is_warmup)
+    def concatenate(cls, batches: List["CausalLMBatch"], pad_token_id: int = 0) -> "CausalLMBatch":
+        return cls.recombine(batches, pad_token_id)
 
     def __len__(self):
         return len(self.requests)
@@ -895,7 +878,7 @@ class CausalLM(Model):
 
     @tracer.start_as_current_span("generate_token")
     def generate_token(
-        self, batches: List[CausalLMBatch], is_warmup: bool = False
+        self, batches: List[CausalLMBatch]
     ) -> Tuple[List[Generation], Optional[CausalLMBatch], Tuple[int, int]]:
         start = time.time_ns()
         # Results
@@ -979,7 +962,7 @@ class CausalLM(Model):
 
         # Stage 2. Prepare new batch for speculative scheduling
         if len(batches) > 1:
-            batch = self.batch_type.concatenate(batches, self.tokenizer.pad_token_id, is_warmup)
+            batch = self.batch_type.concatenate(batches, self.tokenizer.pad_token_id)
         else:
             batch = batches[0]
 
@@ -987,12 +970,12 @@ class CausalLM(Model):
 
         # Check if we need to do any bookkeeping first
         if not prefill:
-            batch = batch.__class__.recombine([batch], self.tokenizer.pad_token_id, is_warmup)
+            batch = batch.__class__.recombine([batch], self.tokenizer.pad_token_id)
 
         scenario = 'PREFILL' if prefill else 'GENERATE'
-        if self.enable_hpu_graph and self.limit_hpu_graph and round_up(DECODE_WARMUP_BATCH_SIZE_LIST, batch.batch_size) != self.prev_bs:
+        if self.enable_hpu_graph and self.limit_hpu_graph and round_up(batch.batch_size, BATCH_BUCKET_SIZE) != self.prev_bs:
             self.model.clear_cache()
-            self.prev_bs = round_up(DECODE_WARMUP_BATCH_SIZE_LIST, batch.batch_size)
+            self.prev_bs = round_up(batch.batch_size, BATCH_BUCKET_SIZE)
         dbg_trace(
             scenario, f'bs:{batch.batch_size} num_reqs:{len(batch.requests)} seq_len:{batch.seq_length} padding:{batch.right_padding}')
         assert batch.right_padding > 0, 'No more room for next token!'
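
The new condition clears cached HPU graphs only when the bucketed batch size changes between steps; a toy trace of that check in isolation, with the enable_hpu_graph/limit_hpu_graph flags dropped and an invented sequence of batch sizes:

BATCH_BUCKET_SIZE = 8

def round_up(number, k):
    return (number + k - 1) // k * k

prev_bs = 0
for batch_size in [5, 7, 9, 16]:  # invented decode batch sizes
    bucketed = round_up(batch_size, BATCH_BUCKET_SIZE)
    if bucketed != prev_bs:
        print(f"bs={batch_size}: bucket {prev_bs} -> {bucketed}, graph cache would be cleared")
        prev_bs = bucketed
    else:
        print(f"bs={batch_size}: bucket {bucketed} unchanged, cache kept")
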
@@ -1194,100 +1177,93 @@ class CausalLM(Model):
         for i in range(len(batch.requests) - batch_size):
             batch.requests.pop()
 
-        return CausalLMBatch.from_pb(batch, self.tokenizer, self.dtype, self.device, is_warmup=True)
+        return self.batch_type.from_pb(batch, self.tokenizer, self.dtype, self.device)
 
 
     def warmup(self, request) -> None:
-        is_warmup = True
         MAX_TOTAL_TOKENS = request.max_total_tokens
         MAX_BATCH_TOTAL_TOKENS = request.max_batch_total_tokens
-        batch = CausalLMBatch.from_pb(request.batch, self.tokenizer, self.dtype, self.device, is_warmup=is_warmup)
+        batch = self.batch_type.from_pb(request.batch, self.tokenizer, self.dtype, self.device)
         try:
             # max prefill batch size warmup
-            _, prefill_batch, _ = self.generate_token([batch], is_warmup)
+            _, prefill_batch, _ = self.generate_token([batch])
         except:
             raise RuntimeError(
                 f"Not enough memory to handle {len(batch.input_ids)} prefill tokens. "
                 f"You need to decrease `--max-batch-prefill-tokens`"
             )
+        del prefill_batch
 
         #warmup decode batch size
         max_prefill_batch_size = batch.input_ids.shape[0]
+        del batch
         max_decode_batch_size = math.floor(MAX_BATCH_TOTAL_TOKENS / MAX_TOTAL_TOKENS)
+        max_decode_batch_size = round_up(max_decode_batch_size, BATCH_BUCKET_SIZE)
+        decode_batch_size_list = [i for i in range(BATCH_BUCKET_SIZE, max_decode_batch_size, BATCH_BUCKET_SIZE)]
+        decode_batch_size_list.append(max_decode_batch_size)
+        decode_batch_size_list.sort(reverse=True)
 
         self.limit_hpu_graph = True
         try:
-            for batch_size in range(max_decode_batch_size, 0, -BATCH_BUCKET_SIZE):
+            for batch_size in decode_batch_size_list:
                 batches = []
                 iters = math.floor(batch_size/max_prefill_batch_size)
-                DECODE_WARMUP_BATCH_SIZE_LIST.append(batch_size)
                 for i in range(iters):
                     batch = self.generate_warmup_batch(request, PAD_SEQUENCE_TO_MULTIPLE_OF - 1, max_prefill_batch_size)
-                    _, prefill_batch, _ = self.generate_token([batch], is_warmup)
+                    _, prefill_batch, _ = self.generate_token([batch])
                     batches.append(prefill_batch)
 
                 if batch_size % max_prefill_batch_size != 0:
                     batch = self.generate_warmup_batch(request, PAD_SEQUENCE_TO_MULTIPLE_OF - 1, batch_size % max_prefill_batch_size)
-                    _, prefill_batch, _ = self.generate_token([batch], is_warmup)
-                    batches.append(batch)
+                    _, prefill_batch, _ = self.generate_token([batch])
+                    batches.append(prefill_batch)
 
-                _, decode_batch, _ = self.generate_token(batches, is_warmup)
+                _, decode_batch, _ = self.generate_token(batches)
+                del decode_batch
+                batches.clear()
         except:
-            DECODE_WARMUP_BATCH_SIZE_LIST.pop(-1)
-            self.model.clear_cache()
-            if len(DECODE_WARMUP_BATCH_SIZE_LIST) > 0:
-                logger.warning(
-                    f"Not enough memory to warmup all batch size of decode."
-                    f"You need to decrease `--max-batch-total-tokens`"
-                )
-            else:
-                raise RuntimeError(
-                    f"Not enough memory to warmup decode batch_size({max_decode_batch_size})."
-                    f"You need to decrease `--max-batch-total-tokens`"
-                )
+            raise RuntimeError(
+                f"Not enough memory to warmup decode batch_sizes({decode_batch_size_list})."
+                f"You need to decrease `--max-batch-total-tokens`"
+            )
 
-        DECODE_WARMUP_BATCH_SIZE_LIST.sort()
+        decode_batch_size_list.sort()
+        MAX_BATCH_TOTAL_TOKENS = MAX_TOTAL_TOKENS * decode_batch_size_list[-1]
         mem_stats = get_hpu_memory_stats(self.device)
         logger.info(
             f"\nFollowing decode warmup successfully.\n"
-            f"Decode batch size list:{DECODE_WARMUP_BATCH_SIZE_LIST}\n"
+            f"Decode batch size list:{decode_batch_size_list}\n"
             f"Memory stats: {mem_stats} "
         )
 
         # Warmup prefill batch_size
         max_input_length = request.max_input_length
-        max_prefill_batch_size = batch.input_ids.shape[0]
-        seq_len = PAD_SEQUENCE_TO_MULTIPLE_OF
-
-        i = 0
-        while seq_len <= max_input_length:
-            PREFILL_WARMUP_SEQLEN_LIST.append(seq_len)
-            seq_len += PAD_SEQUENCE_TO_MULTIPLE_OF*(2**i)
-            i += 1
-
-        if PREFILL_WARMUP_SEQLEN_LIST[-1] < max_input_length:
-            PREFILL_WARMUP_SEQLEN_LIST.append(max_input_length)
+        prefill_batch_size_list = []
+        prefill_seqlen_list = []
 
         #Prefill and decode warmup
         try:
-            for batch_size in range(PREFILL_BATCH_BUCKET_SIZE, max_prefill_batch_size + 1, PREFILL_BATCH_BUCKET_SIZE):
-                PREFILL_WARMUP_BATCH_SIZE_LIST.append(batch_size)
-                for seq_len in PREFILL_WARMUP_SEQLEN_LIST:
-                    batch = self.generate_warmup_batch(request, seq_len - 1, batch_size)
-                    _, prefill_batch, _ = self.generate_token([batch], is_warmup)
+            for batch_size in range(max_prefill_batch_size, 0, -PREFILL_BATCH_BUCKET_SIZE):
+                prefill_batch_size_list.append(batch_size)
+                for seq_len in range(max_input_length, 0, -PAD_SEQUENCE_TO_MULTIPLE_OF):
+                    prefill_seqlen_list.append(seq_len)
+                    batch = self.generate_warmup_batch(request, seq_len, batch_size)
+                    _, prefill_batch, _ = self.generate_token([batch])
+                    del batch
+                    del prefill_batch
         except:
             raise RuntimeError(
                 f"Not enough memory to run following prefill batch_size."
-                f"Prefill batch size list:{PREFILL_WARMUP_BATCH_SIZE_LIST}"
-                f"Prefill sequence length list:{PREFILL_WARMUP_SEQLEN_LIST}"
+                f"Prefill batch size list:{prefill_batch_size_list}"
+                f"Prefill sequence length list:{prefill_seqlen_list}"
                 f"You need to decrease `--max-batch-prefill-tokens`"
             )
 
+        prefill_batch_size_list.sort()
+        prefill_seqlen_list.sort()
         limit_hpu_graph = os.getenv("LIMIT_HPU_GRAPH", "false").lower() == "true"
         if limit_hpu_graph == False:
             mem_stats = get_hpu_memory_stats(self.device)
             logger.info(
                 f"\nFollowing prefill and decode warmup successfully.\n"
-                f"Prefill batch size list:{PREFILL_WARMUP_BATCH_SIZE_LIST}\n"
-                f"Prefill sequence length list:{PREFILL_WARMUP_SEQLEN_LIST}\n"
+                f"Prefill batch size list:{prefill_batch_size_list}\n"
+                f"Prefill sequence length list:{prefill_seqlen_list}\n"
                 f"Memory stats: {mem_stats} "
             )
 
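
To see what the new decode warmup loop iterates over, here is the decode_batch_size_list construction from this hunk run in isolation, with invented token budgets:

import math

BATCH_BUCKET_SIZE = 8

def round_up(number, k):
    return (number + k - 1) // k * k

MAX_TOTAL_TOKENS = 2048            # invented
MAX_BATCH_TOTAL_TOKENS = 65536     # invented

max_decode_batch_size = math.floor(MAX_BATCH_TOTAL_TOKENS / MAX_TOTAL_TOKENS)  # 32
max_decode_batch_size = round_up(max_decode_batch_size, BATCH_BUCKET_SIZE)     # 32
decode_batch_size_list = [i for i in range(BATCH_BUCKET_SIZE, max_decode_batch_size, BATCH_BUCKET_SIZE)]
decode_batch_size_list.append(max_decode_batch_size)
decode_batch_size_list.sort(reverse=True)
print(decode_batch_size_list)      # [32, 24, 16, 8] -- warmed from largest to smallest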