Mirror of https://github.com/huggingface/text-generation-inference.git
Fix the StarCoder warmup issue

Signed-off-by: yuanwu <yuan.wu@intel.com>

parent: b83419a769
commit: 4586325a34
@@ -61,28 +61,13 @@ LAZY_MODE = int(os.environ.get('PT_HPU_LAZY_MODE', 1))
 BATCH_BUCKET_SIZE= int(os.environ.get('BATCH_BUCKET_SIZE', 8))
 PREFILL_BATCH_BUCKET_SIZE = int(os.environ.get('PREFILL_BATCH_BUCKET_SIZE', 2))
-
-PREFILL_WARMUP_BATCH_SIZE_LIST = []
-PREFILL_WARMUP_SEQLEN_LIST = []
-DECODE_WARMUP_BATCH_SIZE_LIST = []
-
 
 def torch_compile_for_eager(func):
     if LAZY_MODE == 1:
         return func
     return torch.compile(func, backend="hpu_backend", options={"keep_input_mutations": True})
-
 
-def round_up(warmup_list:list, num) :
-    i = 0
-    if len(warmup_list) == 0:
-        return num
-
-    for i in warmup_list:
-        if num <= i :
-            break
-    return i
-
+def round_up(number, k):
+    return (number + k - 1) // k * k
 
 def to_tensor_indices(indices, device):
     return torch.tensor(indices, dtype=torch.long, device=device)
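The replacement round_up rounds a number up to the next multiple of k, whereas the removed helper walked a pre-built warmup list and returned the first entry that fits (or the last entry if none did). A standalone sketch of the two behaviors, with hypothetical function names and made-up example values:

# Illustrative snippet, not code from the file; only the two function bodies mirror the diff.
def round_up_to_bucket(number: int, k: int) -> int:
    """New behavior: round up to the next multiple of k."""
    return (number + k - 1) // k * k

def round_up_to_warmup_list(warmup_list: list, num: int) -> int:
    """Old behavior: pick the first pre-warmed value >= num, else the last entry."""
    i = 0
    if len(warmup_list) == 0:
        return num
    for i in warmup_list:
        if num <= i:
            break
    return i

print(round_up_to_bucket(5, 8))                    # 8
print(round_up_to_bucket(9, 8))                    # 16
print(round_up_to_warmup_list([2, 4, 8, 16], 5))   # 8
print(round_up_to_warmup_list([2, 4, 8, 16], 20))  # 16 (falls back to the largest entry)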
@@ -372,14 +357,13 @@ class CausalLMBatch(Batch):
         self.set_tensor_groups(dst_tensors)
 
     @classmethod
-    def recombine(cls, batches: List["CausalLMBatch"], pad_token_id: int, is_warmup: bool =False) -> "CausalLMBatch":
+    def recombine(cls, batches: List["CausalLMBatch"], pad_token_id: int) -> "CausalLMBatch":
         if not all(b.past_key_values is not None for b in batches):
             raise ValueError("KV cache not allocated! Cannot recombine before prefill!")
 
         total_requests = sum(len(b) for b in batches)
         new_bs = total_requests
-        if is_warmup is False :
-            new_bs = round_up(DECODE_WARMUP_BATCH_SIZE_LIST, total_requests)
+        new_bs = round_up(total_requests, BATCH_BUCKET_SIZE)
 
         batch_id = batches[0].batch_id
         device = batches[0].input_ids.device
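With the is_warmup flag gone, recombine always pads the merged batch up to the next BATCH_BUCKET_SIZE multiple instead of consulting DECODE_WARMUP_BATCH_SIZE_LIST. A minimal sketch of that arithmetic, assuming the documented default bucket size of 8 and made-up request counts:

# Illustrative sketch of the new bucketing in recombine(); values are examples only.
BATCH_BUCKET_SIZE = 8

def round_up(number, k):
    return (number + k - 1) // k * k

for total_requests in (3, 8, 11):
    new_bs = round_up(total_requests, BATCH_BUCKET_SIZE)
    print(f"{total_requests} live requests -> padded batch size {new_bs}")
# 3 -> 8, 8 -> 8, 11 -> 16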
@@ -481,7 +465,6 @@ class CausalLMBatch(Batch):
         tokenizer: PreTrainedTokenizerBase,
         dtype: torch.dtype,
         device: torch.device,
-        is_warmup: bool = False,
     ) -> "CausalLMBatch":
         dbg_trace('FROM_PB', f'num_reqs:{len(pb.requests)}')
         requests = [CausalLMRequest.from_pb(idx, req, tokenizer) for idx, req in enumerate(pb.requests)]
@@ -503,7 +486,7 @@ class CausalLMBatch(Batch):
         # TODO: by tokenizing all inputs at once we loose information on actual input lengths
         # this means that we cannot shift inputs to the left after a long input sequence
         # was filtered out
-        new_bs = round_up(PREFILL_WARMUP_BATCH_SIZE_LIST, len(requests))
+        new_bs = round_up(len(requests), PREFILL_BATCH_BUCKET_SIZE)
         missing_inputs = new_bs - len(inputs)
         dummy_inputs = ["?"] * missing_inputs
         parameters = [r.parameters for r in pb.requests]
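In from_pb, the prefill batch is now rounded up to a multiple of PREFILL_BATCH_BUCKET_SIZE and the shortfall is filled with dummy "?" inputs. A hedged sketch of that padding step; the helper name pad_inputs and the example prompts are illustrative, only the rounding and the "?" dummy-input trick mirror the diff:

# Sketch, assuming the documented default PREFILL_BATCH_BUCKET_SIZE of 2.
PREFILL_BATCH_BUCKET_SIZE = 2

def round_up(number, k):
    return (number + k - 1) // k * k

def pad_inputs(inputs):
    new_bs = round_up(len(inputs), PREFILL_BATCH_BUCKET_SIZE)
    missing_inputs = new_bs - len(inputs)
    return inputs + ["?"] * missing_inputs

print(pad_inputs(["Hello", "How are", "Write a"]))  # 3 requests padded to a bucket of 4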
@@ -533,7 +516,7 @@ class CausalLMBatch(Batch):
         left_padding = max_input_length - input_len
         if input_len < max_input_length and PAD_SEQUENCE_TO_MULTIPLE_OF != 0:
             assert PAD_SEQUENCE_TO_MULTIPLE_OF <= max_input_length, "PAD_SEQUENCE_TO_MULTIPLE_OF cannot be higher than max_input_length"
-            rounded_seq_len = round_up(PREFILL_WARMUP_SEQLEN_LIST, input_len + 1)
+            rounded_seq_len = round_up(input_len + 1, PREFILL_BATCH_BUCKET_SIZE)
             if rounded_seq_len <= max_input_length:
                 bucket_size = rounded_seq_len - 1
             else:
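The padded sequence length is likewise derived arithmetically now: the old call picked the next entry of PREFILL_WARMUP_SEQLEN_LIST, while the new one rounds input_len + 1 up by PREFILL_BATCH_BUCKET_SIZE (the hunk really does round a sequence length by the prefill batch bucket). A small illustrative sketch with assumed example values; the else: branch lies outside the hunk above and is not reproduced here:

# Sketch of the changed line only; 57 is a made-up input length.
PREFILL_BATCH_BUCKET_SIZE = 2

def round_up(number, k):
    return (number + k - 1) // k * k

input_len = 57
rounded_seq_len = round_up(input_len + 1, PREFILL_BATCH_BUCKET_SIZE)
bucket_size = rounded_seq_len - 1   # taken when rounded_seq_len still fits max_input_length
print(rounded_seq_len, bucket_size)  # 58 57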
@@ -593,8 +576,8 @@ class CausalLMBatch(Batch):
 
     @classmethod
     @tracer.start_as_current_span("concatenate")
-    def concatenate(cls, batches: List["CausalLMBatch"], pad_token_id: int = 0, is_warmup: bool = False) -> "CausalLMBatch":
-        return cls.recombine(batches, pad_token_id, is_warmup)
+    def concatenate(cls, batches: List["CausalLMBatch"], pad_token_id: int = 0) -> "CausalLMBatch":
+        return cls.recombine(batches, pad_token_id)
 
     def __len__(self):
         return len(self.requests)
@@ -895,7 +878,7 @@ class CausalLM(Model):
 
     @tracer.start_as_current_span("generate_token")
     def generate_token(
-        self, batches: List[CausalLMBatch], is_warmup: bool = False
+        self, batches: List[CausalLMBatch]
     ) -> Tuple[List[Generation], Optional[CausalLMBatch], Tuple[int, int]]:
         start = time.time_ns()
         # Results
@@ -979,7 +962,7 @@ class CausalLM(Model):
 
         # Stage 2. Prepare new batch for speculative scheduling
         if len(batches) > 1:
-            batch = self.batch_type.concatenate(batches, self.tokenizer.pad_token_id, is_warmup)
+            batch = self.batch_type.concatenate(batches, self.tokenizer.pad_token_id)
         else:
             batch = batches[0]
 
@@ -987,12 +970,12 @@ class CausalLM(Model):
 
         # Check if we need to do any bookkeeping first
         if not prefill:
-            batch = batch.__class__.recombine([batch], self.tokenizer.pad_token_id, is_warmup)
+            batch = batch.__class__.recombine([batch], self.tokenizer.pad_token_id)
 
         scenario = 'PREFILL' if prefill else 'GENERATE'
-        if self.enable_hpu_graph and self.limit_hpu_graph and round_up(DECODE_WARMUP_BATCH_SIZE_LIST, batch.batch_size) != self.prev_bs:
+        if self.enable_hpu_graph and self.limit_hpu_graph and round_up(batch.batch_size, BATCH_BUCKET_SIZE) != self.prev_bs:
             self.model.clear_cache()
-            self.prev_bs = round_up(DECODE_WARMUP_BATCH_SIZE_LIST, batch.batch_size)
+            self.prev_bs = round_up(batch.batch_size, BATCH_BUCKET_SIZE)
         dbg_trace(
             scenario, f'bs:{batch.batch_size} num_reqs:{len(batch.requests)} seq_len:{batch.seq_length} padding:{batch.right_padding}')
         assert batch.right_padding > 0, 'No more room for next token!'
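In generate_token, cached HPU graphs are now invalidated whenever the bucketed batch size changes, using the same round_up(batch.batch_size, BATCH_BUCKET_SIZE) expression. A toy sketch of that bookkeeping, using a stand-in class and example batch sizes rather than the real CausalLM model:

# Illustrative only: tracks the rounded batch size and reports when the cache would be cleared.
BATCH_BUCKET_SIZE = 8

def round_up(number, k):
    return (number + k - 1) // k * k

class GraphCacheTracker:
    def __init__(self):
        self.prev_bs = 0

    def maybe_clear(self, batch_size):
        rounded = round_up(batch_size, BATCH_BUCKET_SIZE)
        if rounded != self.prev_bs:
            print(f"clear_cache(): bucket changed {self.prev_bs} -> {rounded}")
            self.prev_bs = rounded

tracker = GraphCacheTracker()
for bs in (5, 7, 9):   # 5 and 7 share the 8-bucket; 9 moves to the 16-bucket
    tracker.maybe_clear(bs)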
@@ -1194,100 +1177,93 @@ class CausalLM(Model):
         for i in range(len(batch.requests) - batch_size):
             batch.requests.pop()
 
-        return CausalLMBatch.from_pb(batch, self.tokenizer, self.dtype, self.device, is_warmup=True)
+        return self.batch_type.from_pb(batch, self.tokenizer, self.dtype, self.device)
 
 
     def warmup(self, request) -> None:
-        is_warmup = True
         MAX_TOTAL_TOKENS = request.max_total_tokens
         MAX_BATCH_TOTAL_TOKENS = request.max_batch_total_tokens
-        batch = CausalLMBatch.from_pb(request.batch, self.tokenizer, self.dtype, self.device, is_warmup = is_warmup)
+        batch = self.batch_type.from_pb(request.batch, self.tokenizer, self.dtype, self.device)
         try:
             # max prefill batch size warmup
-            _, prefill_batch, _ = self.generate_token([batch], is_warmup)
+            _, prefill_batch, _ = self.generate_token([batch])
         except:
             raise RuntimeError(
                 f"Not enough memory to handle {len(batch.input_ids)} prefill tokens. "
                 f"You need to decrease `--max-batch-prefill-tokens`"
             )
 
         del prefill_batch
         #warmup decode batch size
         max_prefill_batch_size = batch.input_ids.shape[0]
         del batch
         max_decode_batch_size = math.floor(MAX_BATCH_TOTAL_TOKENS / MAX_TOTAL_TOKENS)
+        max_decode_batch_size = round_up(max_decode_batch_size, BATCH_BUCKET_SIZE)
+        decode_batch_size_list = [i for i in range(BATCH_BUCKET_SIZE, max_decode_batch_size, BATCH_BUCKET_SIZE)]
+        decode_batch_size_list.append(max_decode_batch_size)
+        decode_batch_size_list.sort(reverse=True)
 
         self.limit_hpu_graph = True
         try:
-            for batch_size in range(max_decode_batch_size, 0, -BATCH_BUCKET_SIZE):
+            for batch_size in decode_batch_size_list:
                 batches= []
                 iters = math.floor(batch_size/max_prefill_batch_size)
-                DECODE_WARMUP_BATCH_SIZE_LIST.append(batch_size)
                 for i in range(iters):
                     batch = self.generate_warmup_batch(request, PAD_SEQUENCE_TO_MULTIPLE_OF - 1, max_prefill_batch_size)
-                    _, prefill_batch, _ = self.generate_token([batch], is_warmup)
+                    _, prefill_batch, _ = self.generate_token([batch])
                     batches.append(prefill_batch)
 
                 if batch_size % max_prefill_batch_size != 0:
                     batch = self.generate_warmup_batch(request, PAD_SEQUENCE_TO_MULTIPLE_OF - 1, batch_size % max_prefill_batch_size)
-                    _, prefill_batch, _ = self.generate_token([batch], is_warmup)
-                    batches.append(batch)
+                    _, prefill_batch, _ = self.generate_token([batch])
+                    batches.append(prefill_batch)
 
-                _, decode_batch, _ = self.generate_token(batches, is_warmup)
+                _, decode_batch, _ = self.generate_token(batches)
                 del decode_batch
                 batches.clear()
         except:
-            DECODE_WARMUP_BATCH_SIZE_LIST.pop(-1)
-            self.model.clear_cache()
-        if len(DECODE_WARMUP_BATCH_SIZE_LIST) > 0:
-            logger.warning(
-                f"Not enough memory to warmup all batch size of decode."
-                f"You need to decrease `--max-batch-total-tokens`"
-            )
-        else:
             raise RuntimeError(
-                f"Not enough memory to warmup decode batch_size({max_decode_batch_size})."
+                f"Not enough memory to warmup decode batch_sizes({decode_batch_size_list})."
                 f"You need to decrease `--max-batch-total-tokens`"
             )
-        DECODE_WARMUP_BATCH_SIZE_LIST.sort()
+        decode_batch_size_list.sort()
+        MAX_BATCH_TOTAL_TOKENS = MAX_TOTAL_TOKENS * decode_batch_size_list[-1]
         mem_stats = get_hpu_memory_stats(self.device)
         logger.info(
             f"\nFollowing decode warmup successfully.\n"
-            f"Decode batch size list:{DECODE_WARMUP_BATCH_SIZE_LIST}\n"
+            f"Decode batch size list:{decode_batch_size_list}\n"
            f"Memory stats: {mem_stats} "
         )
 
         # Warmup prefill batch_size
         max_input_length = request.max_input_length
-        max_prefill_batch_size = batch.input_ids.shape[0]
-        seq_len = PAD_SEQUENCE_TO_MULTIPLE_OF
-
-        i = 0
-        while seq_len <= max_input_length:
-            PREFILL_WARMUP_SEQLEN_LIST.append(seq_len)
-            seq_len += PAD_SEQUENCE_TO_MULTIPLE_OF*(2**i)
-            i += 1
-
-        if PREFILL_WARMUP_SEQLEN_LIST[-1] < max_input_length:
-            PREFILL_WARMUP_SEQLEN_LIST.append(max_input_length)
 
+        prefill_batch_size_list = []
+        prefill_seqlen_list = []
         #Prefill and decode warmup
         try:
-            for batch_size in range(PREFILL_BATCH_BUCKET_SIZE, max_prefill_batch_size + 1, PREFILL_BATCH_BUCKET_SIZE):
-                PREFILL_WARMUP_BATCH_SIZE_LIST.append(batch_size)
-                for seq_len in PREFILL_WARMUP_SEQLEN_LIST :
-                    batch = self.generate_warmup_batch(request, seq_len - 1, batch_size)
-                    _, prefill_batch, _ = self.generate_token([batch], is_warmup)
+            for batch_size in range(max_prefill_batch_size, 0, -PREFILL_BATCH_BUCKET_SIZE):
+                prefill_batch_size_list.append(batch_size)
+                for seq_len in range(max_input_length, 0, -PAD_SEQUENCE_TO_MULTIPLE_OF):
+                    prefill_seqlen_list.append(seq_len)
+                    batch = self.generate_warmup_batch(request, seq_len, batch_size)
+                    _, prefill_batch, _ = self.generate_token([batch])
                     del batch
                     del prefill_batch
         except:
             raise RuntimeError(
                 f"Not enough memory to run following prefill batch_size."
-                f"Prefill batch size list:{PREFILL_WARMUP_BATCH_SIZE_LIST}"
-                f"Prefill sequence length list:{PREFILL_WARMUP_SEQLEN_LIST}"
+                f"Prefill batch size list:{prefill_batch_size_list}"
+                f"Prefill sequence length list:{prefill_seqlen_list}"
                 f"You need to decrease `--max-batch-prefill-tokens`"
             )
 
+        prefill_batch_size_list.sort()
+        prefill_seqlen_list.sort()
         limit_hpu_graph = os.getenv("LIMIT_HPU_GRAPH", "false").lower() == "true"
         if limit_hpu_graph == False:
             mem_stats = get_hpu_memory_stats(self.device)
             logger.info(
                 f"\nFollowing prefill and decode warmup successfully.\n"
-                f"Prefill batch size list:{PREFILL_WARMUP_BATCH_SIZE_LIST}\n"
-                f"Prefill sequence length list:{PREFILL_WARMUP_SEQLEN_LIST}\n"
+                f"Prefill batch size list:{prefill_batch_size_list}\n"
+                f"Prefill sequence length list:{prefill_seqlen_list}\n"
                 f"Memory stats: {mem_stats} "
             )
 
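The rewritten decode warmup derives its candidate batch sizes up front from the token budget instead of growing a global list as it goes. A sketch of just that list construction, using made-up token budgets; only the list arithmetic mirrors the diff:

# Illustrative example; MAX_TOTAL_TOKENS and MAX_BATCH_TOTAL_TOKENS are assumed values.
import math

BATCH_BUCKET_SIZE = 8
MAX_TOTAL_TOKENS = 2048
MAX_BATCH_TOTAL_TOKENS = 65536

def round_up(number, k):
    return (number + k - 1) // k * k

max_decode_batch_size = math.floor(MAX_BATCH_TOTAL_TOKENS / MAX_TOTAL_TOKENS)   # 32
max_decode_batch_size = round_up(max_decode_batch_size, BATCH_BUCKET_SIZE)      # 32
decode_batch_size_list = [i for i in range(BATCH_BUCKET_SIZE, max_decode_batch_size, BATCH_BUCKET_SIZE)]
decode_batch_size_list.append(max_decode_batch_size)
decode_batch_size_list.sort(reverse=True)
print(decode_batch_size_list)   # [32, 24, 16, 8] (warmed from largest to smallest)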
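For prefill, the removed code warmed an exponentially spaced PREFILL_WARMUP_SEQLEN_LIST, while the new loop covers every PAD_SEQUENCE_TO_MULTIPLE_OF bucket from max_input_length down, so intermediate shapes are no longer skipped, which is presumably what the StarCoder warmup was running into. A sketch contrasting the two schedules; 128 and 1024 are assumed example values, not values from the commit:

# Old schedule: exponentially spaced sequence lengths.
PAD_SEQUENCE_TO_MULTIPLE_OF = 128
max_input_length = 1024

old_seqlens, seq_len, i = [], PAD_SEQUENCE_TO_MULTIPLE_OF, 0
while seq_len <= max_input_length:
    old_seqlens.append(seq_len)
    seq_len += PAD_SEQUENCE_TO_MULTIPLE_OF * (2 ** i)
    i += 1
if old_seqlens[-1] < max_input_length:
    old_seqlens.append(max_input_length)

# New schedule: every PAD_SEQUENCE_TO_MULTIPLE_OF bucket, largest first.
new_seqlens = list(range(max_input_length, 0, -PAD_SEQUENCE_TO_MULTIPLE_OF))

print(old_seqlens)  # [128, 256, 512, 1024]
print(new_seqlens)  # [1024, 896, 768, 640, 512, 384, 256, 128]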