Prefill optimization by allocating space only for the first output token (#34) (#62)

Co-authored-by: mswiniarsk <156412439+mswiniarsk@users.noreply.github.com>
Co-authored-by: Karol Damaszke <karol.damaszke@intel.com>
jkaniecki 2024-02-22 04:55:43 +01:00 committed by GitHub
parent 80303b469c
commit 8f590759e3
2 changed files with 57 additions and 32 deletions
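
The idea behind the change: previously, from_pb() right-padded input_ids and attention_mask by max_new_tokens + extra_padding at prefill time, so the whole output buffer was allocated before the first token was generated. With this commit, prefill allocates a single extra slot for the first output token, and the buffers are only grown to the max total tokens inside recombine() (via the new pad_tensors helper) right before the first decode step. A minimal sketch of the padding arithmetic, using plain PyTorch and made-up sizes (not repository code):

import torch
import torch.nn.functional as F

# Illustrative values only.
batch_size, input_length = 2, 5
max_new_tokens, extra_padding = 8, 0
pad_token_id = 0

input_ids = torch.randint(1, 100, (batch_size, input_length))

# Old behaviour: reserve room for all future tokens at prefill time.
old_prefill = F.pad(input_ids, (0, max_new_tokens + extra_padding), value=pad_token_id)

# New behaviour: reserve room for just the first output token at prefill...
new_prefill = F.pad(input_ids, (0, 1), value=pad_token_id)

# ...and pad out to the max total tokens only before the first decode,
# which is what recombine() now does through pad_tensors().
padding = (input_length + max_new_tokens + extra_padding) - new_prefill.size(1)
full_buffer = F.pad(new_prefill, (0, padding), value=pad_token_id)

assert old_prefill.shape == full_buffer.shape  # (2, 13) either way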

Changed file 1 of 2

@@ -140,6 +140,15 @@ def remove_kv_cache_from_output(module):
     return module
 
+
+def pad_tensors(tensors, paddings, dim, value):
+    for i, (tensor, padding) in enumerate(zip(tensors, paddings)):
+        if padding > 0:
+            pad_shape = (0, 0, 0, padding) if dim == -2 else (0, padding)
+            tensors[i] = torch.nn.functional.pad(tensor, pad_shape, value=value)
+            htorch.core.mark_step()
+    return tensors
+
 @dataclass
 class CausalLMRequest:
     idx: int
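
For reference, the helper restated as a standalone, runnable sketch under a different name so it is clearly an illustration; htorch.core.mark_step() is the Habana-specific graph break and is left out so this runs on plain PyTorch:

import torch
import torch.nn.functional as F

def pad_tensors_sketch(tensors, paddings, dim, value):
    # Mirrors the new pad_tensors(): dim == -2 pads the sequence axis of a
    # 4-D KV-cache tensor, any other dim pads the last axis.
    for i, (tensor, padding) in enumerate(zip(tensors, paddings)):
        if padding > 0:
            pad_shape = (0, 0, 0, padding) if dim == -2 else (0, padding)
            tensors[i] = F.pad(tensor, pad_shape, value=value)
    return tensors

ids = [torch.ones(2, 6, dtype=torch.long), torch.ones(2, 9, dtype=torch.long)]
padded = pad_tensors_sketch(ids, paddings=[7, 4], dim=-1, value=0)
assert [t.size(-1) for t in padded] == [13, 13]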
@@ -196,7 +205,7 @@ class CausalLMBatch(Batch):
         )
 
     @classmethod
-    def recombine(cls, batches: List["CausalLMBatch"], is_optimized_for_gaudi: bool = False) -> "CausalLMBatch":
+    def recombine(cls, batches: List["CausalLMBatch"], pad_token_id: int) -> "CausalLMBatch":
         total_requests = sum(len(b) for b in batches)
         new_bs = round_up(total_requests, BATCH_BUCKET_SIZE)
         batch_id = batches[0].batch_id
@@ -224,7 +233,8 @@ class CausalLMBatch(Batch):
             return batches[0]
 
         inplace = batches[target_batch_idx].batch_size == new_bs
-        dbg_trace(scenario, f'bs:{[b.batch_size for b in batches]}->{new_bs} reqs:{[len(b) for b in batches]} offsets:{offsets} padding:{padding} moves_needed:{moves_needed} inplace:{inplace}')
+        dbg_trace(
+            scenario, f'bs:{[b.batch_size for b in batches]}->{new_bs} reqs:{[len(b) for b in batches]} offsets:{offsets} padding:{padding} moves_needed:{moves_needed} inplace:{inplace}')
 
         grouped_requests = [[req for req in batch.requests] for batch in batches]
         flat_requests = list(itertools.chain(*grouped_requests))
@@ -235,12 +245,9 @@ class CausalLMBatch(Batch):
         else:
             free_indices = itertools.count(0)
 
-        to_tensors = lambda ind: (torch.tensor(ind[0], device=device), torch.tensor(ind[1], device=device))
-        indices = [[to_tensors(req.update_idx(next(free_indices))) for req in batch_reqs] for batch_reqs in grouped_requests]
-
-        max_seq_len = batches[0].attention_mask.size(1)
-        input_length = max_input_length
-        right_padding = max_seq_len - input_length
+        def to_tensors(ind): return (torch.tensor(ind[0], device=device), torch.tensor(ind[1], device=device))
+        indices = [[to_tensors(req.update_idx(next(free_indices))) for req in batch_reqs]
+                   for batch_reqs in grouped_requests]
 
         chunk_size = batches[0].past_key_values[0][0].size(0) // batches[0].batch_size
         num_layers = len(batches[0].past_key_values)
@@ -257,9 +264,14 @@ class CausalLMBatch(Batch):
         for b in batches:
             b.past_key_values = list(b.past_key_values)
 
+        # For prefill there is a space allocated only for first token
+        # Need to add padding to the max total tokens before first decode
+        paddings = [(batch.input_length + batch.right_padding) - batch.seq_length for batch in batches]
+
         src = [b.input_ids for b in batches]
         for b in batches:
             del b.input_ids
+        src = pad_tensors(src, paddings, seq_dim, pad_token_id)
         src = shift_all(src, seq_dim, offsets)
         input_ids = prepare_memory(new_bs, src[target_batch_idx], inplace)
         input_ids = move_data(input_ids, 1, indices, src)
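
The paddings list above is where the prefill optimization is paid back: a batch straight out of prefill is only seq_length columns wide (prompt plus one slot), so it must grow by (input_length + right_padding) - seq_length columns to reach the max total tokens before the first decode, while a batch that has already decoded is full-width and gets zero padding. A small numeric illustration with made-up lengths:

# Made-up bookkeeping for one batch straight out of prefill.
input_length = 5     # longest prompt in the batch
right_padding = 8    # slots reserved for future tokens
seq_length = 6       # current width: prompt + the single slot allocated at prefill

padding = (input_length + right_padding) - seq_length   # same formula as the list comprehension
assert padding == 7                                     # grow by 7 columns before the first decode

# A batch that has already decoded is full-width, so no extra padding is needed.
seq_length_full = input_length + right_padding
assert (input_length + right_padding) - seq_length_full == 0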
@@ -267,6 +279,7 @@ class CausalLMBatch(Batch):
         src = [b.attention_mask for b in batches]
         for b in batches:
             del b.attention_mask
+        src = pad_tensors(src, paddings, seq_dim, 0)
         src = shift_all(src, seq_dim, offsets)
         attention_mask = prepare_memory(new_bs, src[target_batch_idx], inplace)
         attention_mask = move_data(attention_mask, 1, indices, src)
@@ -281,11 +294,13 @@ class CausalLMBatch(Batch):
         past_key_values = []
         for layer_num in range(num_layers):
             src = [b.past_key_values[layer_num][0] for b in batches]
+            src = pad_tensors(src, paddings, key_dim, 0)
             src = shift_all(src, key_dim, offsets)
             updated_key = prepare_memory(new_bs * chunk_size, src[target_batch_idx], inplace)
             updated_key = move_data(updated_key, chunk_size, indices, src)
 
             src = [b.past_key_values[layer_num][1] for b in batches]
+            src = pad_tensors(src, paddings, value_dim, 0)
             src = shift_all(src, value_dim, offsets)
             updated_value = prepare_memory(new_bs * chunk_size, src[target_batch_idx], inplace)
             updated_value = move_data(updated_value, chunk_size, indices, src)
@@ -304,6 +319,10 @@ class CausalLMBatch(Batch):
             batches[0].next_token_chooser.device
         )
 
+        max_seq_len = attention_mask.size(1)
+        input_length = max_input_length
+        right_padding = max_seq_len - input_length
+
         htorch.core.mark_step()
 
         return cls(
@@ -320,7 +339,6 @@ class CausalLMBatch(Batch):
             right_padding=right_padding
         )
-
     @classmethod
     def from_pb(
         cls,
@@ -378,12 +396,16 @@ class CausalLMBatch(Batch):
         attention_mask = tokenized_inputs["attention_mask"]
 
         if is_optimized_for_gaudi:
+            # Allocate space for first token
             input_ids = torch.nn.functional.pad(
-                input_ids, (0, max_new_tokens + extra_padding), value=tokenizer.pad_token_id
+                input_ids, (0, 1), value=tokenizer.pad_token_id
             )
             attention_mask = torch.nn.functional.pad(
-                attention_mask, (0, max_new_tokens + extra_padding), value=0)
-            all_input_ids = input_ids.T.split(1, dim=1)
+                attention_mask, (0, 1), value=0
+            )
+            all_input_ids = torch.nn.functional.pad(
+                input_ids, (0, max_new_tokens + extra_padding - 1), value=tokenizer.pad_token_id
+            ).T.split(1, dim=1)
         else:
             all_input_ids = input_ids.clone().T.split(1, dim=1)
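
Note the asymmetry introduced here: input_ids and attention_mask now get only one extra slot, while the per-request all_input_ids buffers are still padded out to the full max_new_tokens + extra_padding length they had before. A shape sketch with illustrative numbers (plain PyTorch, hypothetical pad_token_id):

import torch
import torch.nn.functional as F

pad_token_id = 0                      # stand-in for tokenizer.pad_token_id
max_new_tokens, extra_padding = 8, 0  # illustrative limits
input_ids = torch.randint(1, 100, (2, 5))   # 2 requests, 5 prompt tokens each

prefill_ids = F.pad(input_ids, (0, 1), value=pad_token_id)   # (2, 6): one slot for the first token
all_input_ids = F.pad(
    prefill_ids, (0, max_new_tokens + extra_padding - 1), value=pad_token_id
).T.split(1, dim=1)                                          # full-length per-request buffers

assert prefill_ids.shape == (2, 6)
assert len(all_input_ids) == 2 and all_input_ids[0].shape == (13, 1)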
@@ -412,7 +434,7 @@ class CausalLMBatch(Batch):
         )
 
     @tracer.start_as_current_span("filter")
-    def filter(self, request_ids: List[int], is_optimized_for_gaudi: bool = False) -> Optional["CausalLMBatch"]:
+    def filter(self, request_ids: List[int]) -> Optional["CausalLMBatch"]:
         dbg_trace('FILTER', f'num_reqs:{len(self.requests)} -> {len(request_ids)}')
         request_ids = set(request_ids)
         self.requests = [req for req in self.requests if req.data.id in request_ids]
@@ -420,8 +442,8 @@ class CausalLMBatch(Batch):
     @classmethod
     @tracer.start_as_current_span("concatenate")
-    def concatenate(cls, batches: List["CausalLMBatch"], is_optimized_for_gaudi: bool = False) -> "CausalLMBatch":
-        return cls.recombine(batches, is_optimized_for_gaudi)
+    def concatenate(cls, batches: List["CausalLMBatch"], pad_token_id: int = 0) -> "CausalLMBatch":
+        return cls.recombine(batches, pad_token_id)
 
     def __len__(self):
         return len(self.requests)
@@ -517,7 +539,6 @@ class CausalLM(Model):
             ds_inference_kwargs["tensor_parallel"] = {"tp_size": world_size}
             ds_inference_kwargs["enable_cuda_graph"] = False
-
             if load_to_meta:
                 # model loaded to meta is managed differently
                 checkpoints_json = tempfile.NamedTemporaryFile(suffix=".json", mode="+w")
@@ -537,7 +558,7 @@ class CausalLM(Model):
                 torch_dtype=dtype,
             )
         model = model.eval().to(device)
-        #wrap in hpu_graph only if self.enable_hpu_graph is set
+        # wrap in hpu_graph only if self.enable_hpu_graph is set
         model = remove_kv_cache_from_output(model)
         if self.enable_hpu_graph:
             model = wrap_in_hpu_graph(model, disable_tensor_cache=True)
@@ -608,7 +629,6 @@ class CausalLM(Model):
         else:
             return super().decode_token(all_input_ids, prefix_offset, read_offset)
-
     def forward(
         self,
         input_ids,
@@ -646,10 +666,11 @@
         prefill = batch.past_key_values is None
         # Check if we need to do any bookkeeping first
         if not prefill:
-            batch = batch.__class__.recombine([batch], self.is_optimized_for_gaudi)
+            batch = batch.__class__.recombine([batch], self.tokenizer.pad_token_id)
 
         scenario = 'PREFILL' if prefill else 'GENERATE'
-        dbg_trace(scenario, f'bs:{batch.batch_size} num_reqs:{len(batch.requests)} seq_len:{batch.seq_length} padding:{batch.right_padding}')
+        dbg_trace(
+            scenario, f'bs:{batch.batch_size} num_reqs:{len(batch.requests)} seq_len:{batch.seq_length} padding:{batch.right_padding}')
         assert batch.right_padding > 0, 'No more room for next token!'
         self.step = self.step + 1
         if self.hb_profer_started == True and self.step > self.profiling_warmup_steps + self.profiling_steps:
@@ -657,6 +678,10 @@
             self.hb_profer_started = False
 
         if self.is_optimized_for_gaudi:
-            token_idx = torch.tensor(batch.attention_mask.shape[-1] - batch.right_padding).to(self.device)
+            if prefill:
+                # no right padding for prefill
+                token_idx = torch.tensor(batch.attention_mask.shape[-1] - 1).to(self.device)
+            else:
+                token_idx = torch.tensor(batch.attention_mask.shape[-1] - batch.right_padding).to(self.device)
             attention_mask = batch.attention_mask
         else:
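
The two token_idx branches mirror the allocation scheme: at prefill the only free slot is the last column of the (prompt + 1)-wide buffer, while at decode time the buffer is already max-total-tokens wide and the next free slot sits right_padding columns from the end. Illustrative arithmetic with made-up shapes:

import torch

prefill_mask = torch.ones(1, 6)    # 5 prompt tokens + 1 slot reserved at prefill
decode_mask = torch.ones(1, 13)    # grown to max total tokens before the first decode
right_padding = 7                  # columns still unused in the decode-phase buffer

prefill_token_idx = prefill_mask.shape[-1] - 1             # 5: the single reserved slot
decode_token_idx = decode_mask.shape[-1] - right_padding   # 6: first unused slot after padding
assert (prefill_token_idx, decode_token_idx) == (5, 6)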
@@ -664,8 +689,8 @@
             # slice the attention mask to the correct shape
             # TODO fix me!
             attention_mask = batch.attention_mask[:, : -batch.padding_right_offset]
 
-        if batch.past_key_values:
-            if token_idx is not None:
-                input_ids = torch.index_select(batch.input_ids, 1, token_idx - 1)
-            else:
-                input_ids = batch.input_ids
+        if not prefill and token_idx is not None:
+            input_ids = torch.index_select(batch.input_ids, 1, token_idx - 1)
+        else:
+            input_ids = batch.input_ids
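
During decode only the most recently written token is fed back in: torch.index_select pulls column token_idx - 1 from every row, whereas prefill passes the whole prompt through. A tiny example with a 1-D index tensor:

import torch

batch_input_ids = torch.tensor([[11, 12, 13, 0, 0],
                                [21, 22, 23, 0, 0]])
token_idx = 3                          # next free column in each row
last_token = torch.index_select(batch_input_ids, 1, torch.tensor([token_idx - 1]))
assert last_token.tolist() == [[13], [23]]   # shape (batch, 1) for the next forward pass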
@@ -677,7 +702,7 @@
                 batch.position_ids,
                 token_idx,
                 batch.past_key_values,
-                bypass_hpu_graph = prefill and self.limit_hpu_graph if self.enable_hpu_graph else None
+                bypass_hpu_graph=prefill and self.limit_hpu_graph if self.enable_hpu_graph else None
             )
         else:
             logits = self.forward(
@@ -686,7 +711,7 @@
                 batch.position_ids,
                 token_idx,
                 batch.past_key_values,
-                bypass_hpu_graph = prefill and self.limit_hpu_graph if self.enable_hpu_graph else None
+                bypass_hpu_graph=prefill and self.limit_hpu_graph if self.enable_hpu_graph else None
             )
 
         # Results
@@ -697,7 +722,7 @@
         input_length = batch.input_length
         if self.is_optimized_for_gaudi and logits.shape[-2] > 1:
             next_token_ids, next_token_logprobs, logprobs = batch.next_token_chooser(
-                batch.input_ids[:, :token_idx], logits[:, input_length - 1 : input_length, :].squeeze(-2)
+                batch.input_ids[:, :token_idx], logits[:, input_length - 1: input_length, :].squeeze(-2)
             )
         else:
             next_token_ids, next_token_logprobs, logprobs = batch.next_token_chooser(
@@ -757,7 +782,7 @@
             if stop:
                 # Decode generated tokens
                 output_text = self.decode(
-                    all_input_ids[new_input_length - stopping_criteria.current_tokens : new_input_length, 0]
+                    all_input_ids[new_input_length - stopping_criteria.current_tokens: new_input_length, 0]
                 )
                 generated_text = GeneratedText(
                     output_text,
@@ -772,7 +797,7 @@
                 if stopping_criteria.current_tokens == 1 and request.prefill_logprobs:
                     # Remove generated token to only have prefill and add nan for first prompt token
                     prefill_logprobs = [float("nan")] + next_token_logprobs
-                    prefill_token_ids = all_input_ids[0 : new_input_length - 1]
+                    prefill_token_ids = all_input_ids[0: new_input_length - 1]
                     prefill_texts = self.tokenizer.batch_decode(
                         prefill_token_ids,
                         clean_up_tokenization_spaces=False,
@@ -846,7 +871,7 @@
         # Update position_ids
         if prefill:
-            batch.position_ids = batch.position_ids[:, token_idx - 1 : token_idx] + 1
+            batch.position_ids = batch.position_ids[:, token_idx - 1: token_idx] + 1
         else:
             batch.position_ids += 1
 
         # Update past key values

Changed file 2 of 2

@@ -59,7 +59,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
                                               {"util": len(batch.requests)}):
             if batch is None:
                 raise ValueError(f"Batch ID {request.batch_id} not found in cache.")
-            filtered_batch = batch.filter(request.request_ids, self.model.is_optimized_for_gaudi)
+            filtered_batch = batch.filter(request.request_ids, self.model.tokenizer.pad_token_id)
             self.cache.set(filtered_batch)
             return generate_pb2.FilterBatchResponse(batch=filtered_batch.to_pb())
@@ -113,7 +113,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
         if len(batches) > 1:
             with self.profiler.record_event("internal", "concatenate"):
-                batch = self.model.batch_type.concatenate(batches, self.model.is_optimized_for_gaudi)
+                batch = self.model.batch_type.concatenate(batches, self.model.tokenizer.pad_token_id)
         else:
             batch = batches[0]