Prefill optimization by allocating space only for the first token (#17)

Karol Damaszke 2024-01-19 15:18:35 +01:00 committed by GitHub
parent 0b96da89aa
commit 60f63262db
2 changed files with 44 additions and 19 deletions
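Summary of the change: before this commit, prefill already padded input_ids and attention_mask out to the full max-total-tokens width (input length + max_new_tokens + extra_padding). With this commit, prefill reserves only a single extra slot for the first generated token, and the remaining right padding is added later in recombine(), just before the first decode. A minimal PyTorch sketch of the difference (illustrative only, not part of the diff; the sizes and pad token value are arbitrary):

import torch

pad_token_id = 0
max_new_tokens, extra_padding = 8, 2
input_ids = torch.tensor([[11, 12, 13]])  # batch_size=1, input_length=3

# Old behaviour: room for every future token is reserved already at prefill.
old_alloc = torch.nn.functional.pad(
    input_ids, (0, max_new_tokens + extra_padding), value=pad_token_id
)

# New behaviour: reserve one slot for the first generated token only; the rest
# of the padding is added later, in recombine(), before the first decode step.
new_alloc = torch.nn.functional.pad(input_ids, (0, 1), value=pad_token_id)

print(old_alloc.shape, new_alloc.shape)  # torch.Size([1, 13]) torch.Size([1, 4])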

@@ -114,6 +114,15 @@ def shift_all(srcs, dim, offsets):
     return [shift(src, dim, offset) for src, offset in zip(srcs, offsets)]
 
 
+def pad_tensors(tensors, paddings, dim, value):
+    for i, (tensor, padding) in enumerate(zip(tensors, paddings)):
+        if padding > 0:
+            pad_shape = (0, 0, 0, padding) if dim == -2 else (0, padding)
+            tensors[i] = torch.nn.functional.pad(tensor, pad_shape, value=value)
+            htorch.core.mark_step()
+    return tensors
+
+
 @dataclass
 class CausalLMRequest:
     idx: int
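The new pad_tensors helper right-pads each tensor in a list, either along the last (sequence) dimension or, when dim == -2, along the second-to-last (key/value length) dimension. A hedged usage sketch, with the Habana-specific htorch.core.mark_step() call omitted so it runs on plain PyTorch:

import torch
import torch.nn.functional as F

def pad_tensors(tensors, paddings, dim, value):
    for i, (tensor, padding) in enumerate(zip(tensors, paddings)):
        if padding > 0:
            # dim == -2 pads the second-to-last axis (KV length), otherwise the last axis.
            pad_shape = (0, 0, 0, padding) if dim == -2 else (0, padding)
            tensors[i] = F.pad(tensor, pad_shape, value=value)
    return tensors

ids = [torch.ones(1, 4, dtype=torch.long), torch.ones(1, 6, dtype=torch.long)]
padded = pad_tensors(ids, paddings=[3, 1], dim=-1, value=0)
print([t.shape for t in padded])  # [torch.Size([1, 7]), torch.Size([1, 7])]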
@@ -170,7 +179,7 @@ class CausalLMBatch(Batch):
         )
 
     @classmethod
-    def recombine(cls, batches: List["CausalLMBatch"], is_optimized_for_gaudi: bool = False) -> "CausalLMBatch":
+    def recombine(cls, batches: List["CausalLMBatch"], pad_token_id: int) -> "CausalLMBatch":
         total_requests = sum(len(b) for b in batches)
         new_bs = round_up(total_requests, BATCH_BUCKET_SIZE)
         batch_id = batches[0].batch_id
@@ -212,10 +221,6 @@ class CausalLMBatch(Batch):
         to_tensors = lambda ind: (torch.tensor(ind[0], device=device), torch.tensor(ind[1], device=device))
         indices = [[to_tensors(req.update_idx(next(free_indices))) for req in batch_reqs] for batch_reqs in grouped_requests]
 
-        max_seq_len = batches[0].attention_mask.size(1)
-        input_length = max_input_length
-        right_padding = max_seq_len - input_length
-
         chunk_size = batches[0].past_key_values[0][0].size(0) // batches[0].batch_size
         num_layers = len(batches[0].past_key_values)
         past_key_values_type = type(batches[0].past_key_values)
@@ -231,9 +236,14 @@ class CausalLMBatch(Batch):
         for b in batches:
             b.past_key_values = list(b.past_key_values)
 
+        # For prefill there is a space allocated only for first token
+        # Need to add padding to the max total tokens before first decode
+        paddings = [(batch.input_length + batch.right_padding) - batch.seq_length for batch in batches]
+
         src = [b.input_ids for b in batches]
         for b in batches:
             del b.input_ids
+        src = pad_tensors(src, paddings, seq_dim, pad_token_id)
         src = shift_all(src, seq_dim, offsets)
         input_ids = prepare_memory(new_bs, src[target_batch_idx], inplace)
         input_ids = move_data(input_ids, 1, indices, src)
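The per-batch padding is the gap between the current tensor width (seq_length) and the full max-total-tokens layout (input_length + right_padding); it is non-zero only for batches coming straight from prefill. The same paddings list is then applied to the attention mask and the key/value caches in the hunks below. An illustrative sketch of the arithmetic (attribute names mirror the diff, the numbers are invented):

from dataclasses import dataclass

@dataclass
class BatchShape:
    input_length: int   # longest prompt in the batch
    right_padding: int  # slots reserved for tokens still to be generated
    seq_length: int     # current width of the batch tensors

# A batch fresh out of prefill is only input_length + 1 wide (one slot for the
# first generated token), so it still needs padding up to the full layout.
prefill_batch = BatchShape(input_length=5, right_padding=10, seq_length=6)
# A batch that has already decoded was padded earlier and needs nothing.
decode_batch = BatchShape(input_length=5, right_padding=10, seq_length=15)

paddings = [(b.input_length + b.right_padding) - b.seq_length
            for b in (prefill_batch, decode_batch)]
print(paddings)  # [9, 0]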
@@ -241,6 +251,7 @@ class CausalLMBatch(Batch):
         src = [b.attention_mask for b in batches]
         for b in batches:
             del b.attention_mask
+        src = pad_tensors(src, paddings, seq_dim, 0)
         src = shift_all(src, seq_dim, offsets)
         attention_mask = prepare_memory(new_bs, src[target_batch_idx], inplace)
         attention_mask = move_data(attention_mask, 1, indices, src)
@@ -255,11 +266,13 @@ class CausalLMBatch(Batch):
         past_key_values = []
         for layer_num in range(num_layers):
             src = [b.past_key_values[layer_num][0] for b in batches]
+            src = pad_tensors(src, paddings, key_dim, 0)
             src = shift_all(src, key_dim, offsets)
             updated_key = prepare_memory(new_bs * chunk_size, src[target_batch_idx], inplace)
             updated_key = move_data(updated_key, chunk_size, indices, src)
 
             src = [b.past_key_values[layer_num][1] for b in batches]
+            src = pad_tensors(src, paddings, value_dim, 0)
             src = shift_all(src, value_dim, offsets)
             updated_value = prepare_memory(new_bs * chunk_size, src[target_batch_idx], inplace)
             updated_value = move_data(updated_value, chunk_size, indices, src)
@@ -278,6 +291,10 @@ class CausalLMBatch(Batch):
             batches[0].next_token_chooser.dtype
         )
 
+        max_seq_len = attention_mask.size(1)
+        input_length = max_input_length
+        right_padding = max_seq_len - input_length
+
         htorch.core.mark_step()
 
         return cls(
@@ -352,12 +369,16 @@ class CausalLMBatch(Batch):
         attention_mask = tokenized_inputs["attention_mask"]
 
         if is_optimized_for_gaudi:
             # Allocate space for first token
             input_ids = torch.nn.functional.pad(
-                input_ids, (0, max_new_tokens + extra_padding), value=tokenizer.pad_token_id
+                input_ids, (0, 1), value=tokenizer.pad_token_id
             )
             attention_mask = torch.nn.functional.pad(
-                attention_mask, (0, max_new_tokens + extra_padding), value=0)
-            all_input_ids = input_ids.T.split(1, dim=1)
+                attention_mask, (0, 1), value=0
+            )
+            all_input_ids = torch.nn.functional.pad(
+                input_ids, (0, max_new_tokens + extra_padding - 1), value=tokenizer.pad_token_id
+            ).T.split(1, dim=1)
         else:
             all_input_ids = input_ids.clone().T.split(1, dim=1)
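In from_pb, input_ids and attention_mask now grow by a single column, while all_input_ids (the per-request token history) keeps its old full width: padding the already +1-padded input_ids by max_new_tokens + extra_padding - 1 restores the previous total. A small sketch checking the shapes (values are arbitrary, not part of the diff):

import torch

pad_id, max_new_tokens, extra_padding = 0, 8, 2
tokenized = torch.tensor([[5, 6, 7], [8, 9, 0]])           # bs=2, input_length=3

input_ids = torch.nn.functional.pad(tokenized, (0, 1), value=pad_id)  # [2, 4]
all_input_ids = torch.nn.functional.pad(
    input_ids, (0, max_new_tokens + extra_padding - 1), value=pad_id
).T.split(1, dim=1)                                         # tuple of [13, 1] columns

print(input_ids.shape, len(all_input_ids), all_input_ids[0].shape)
# torch.Size([2, 4]) 2 torch.Size([13, 1])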
@@ -386,16 +407,16 @@ class CausalLMBatch(Batch):
         )
 
     @tracer.start_as_current_span("filter")
-    def filter(self, request_ids: List[int], is_optimized_for_gaudi: bool = False) -> Optional["CausalLMBatch"]:
+    def filter(self, request_ids: List[int], pad_token_id: int = 0) -> Optional["CausalLMBatch"]:
         dbg_trace('FILTER', f'num_reqs:{len(self.requests)} -> {len(request_ids)}')
         request_ids = set(request_ids)
         self.requests = [req for req in self.requests if req.data.id in request_ids]
-        return self.__class__.recombine([self], is_optimized_for_gaudi)
+        return self.__class__.recombine([self], pad_token_id)
 
     @classmethod
     @tracer.start_as_current_span("concatenate")
-    def concatenate(cls, batches: List["CausalLMBatch"], is_optimized_for_gaudi: bool = False) -> "CausalLMBatch":
-        return cls.recombine(batches, is_optimized_for_gaudi)
+    def concatenate(cls, batches: List["CausalLMBatch"], pad_token_id: int = 0) -> "CausalLMBatch":
+        return cls.recombine(batches, pad_token_id)
 
     def __len__(self):
         return len(self.requests)
@@ -611,7 +632,7 @@ class CausalLM(Model):
         prefill = batch.past_key_values is None
         # Check if we need to do any bookkeeping first
         if not prefill:
-            batch = batch.__class__.recombine([batch], self.is_optimized_for_gaudi)
+            batch = batch.__class__.recombine([batch], self.tokenizer.pad_token_id)
 
         scenario = 'PREFILL' if prefill else 'GENERATE'
         dbg_trace(scenario, f'bs:{batch.batch_size} num_reqs:{len(batch.requests)} seq_len:{batch.seq_length}')
@@ -621,16 +642,20 @@ class CausalLM(Model):
             self.hb_profer_started = False
 
         if self.is_optimized_for_gaudi:
-            token_idx = torch.tensor(batch.attention_mask.shape[-1] - batch.right_padding).to(self.device)
+            if prefill:
+                # no right padding for prefill
+                token_idx = torch.tensor(batch.attention_mask.shape[-1] - 1).to(self.device)
+            else:
+                token_idx = torch.tensor(batch.attention_mask.shape[-1] - batch.right_padding).to(self.device)
             attention_mask = batch.attention_mask
         else:
             token_idx = None
             # slice the attention mask to the correct shape
             # TODO fix me!
             attention_mask = batch.attention_mask[:, : -batch.padding_right_offset]
 
-        if batch.past_key_values:
-            if token_idx is not None:
-                input_ids = torch.index_select(batch.input_ids, 1, token_idx - 1)
+        if not prefill and token_idx is not None:
+            input_ids = torch.index_select(batch.input_ids, 1, token_idx - 1)
         else:
             input_ids = batch.input_ids
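Since prefill batches no longer carry right padding, the write position for the first generated token is simply the last column of the attention mask; decode batches keep the old right_padding-based index. A standalone sketch of that branch (plain tensors stand in for the batch attributes, and the helper name is made up for illustration):

import torch

def pick_token_idx(attention_mask: torch.Tensor, right_padding: int, prefill: bool) -> torch.Tensor:
    if prefill:
        # Prefill tensors have no right padding: the first generated token goes
        # into the single extra slot at the very end.
        return torch.tensor(attention_mask.shape[-1] - 1)
    # Decode tensors were padded to max total tokens in recombine().
    return torch.tensor(attention_mask.shape[-1] - right_padding)

prefill_mask = torch.ones(1, 6)   # input_length 5 + 1 slot for the first token
decode_mask = torch.ones(1, 15)   # padded to max total tokens in recombine()
print(pick_token_idx(prefill_mask, right_padding=0, prefill=True))   # tensor(5)
print(pick_token_idx(decode_mask, right_padding=9, prefill=False))   # tensor(6)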

@@ -59,7 +59,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
                                         {"util": len(batch.requests)}):
             if batch is None:
                 raise ValueError(f"Batch ID {request.batch_id} not found in cache.")
-            filtered_batch = batch.filter(request.request_ids, self.model.is_optimized_for_gaudi)
+            filtered_batch = batch.filter(request.request_ids, self.model.tokenizer.pad_token_id)
             self.cache.set(filtered_batch)
 
             return generate_pb2.FilterBatchResponse(batch=filtered_batch.to_pb())
@@ -113,7 +113,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
             if len(batches) > 1:
                 with self.profiler.record_event("internal", "concatenate"):
-                    batch = self.model.batch_type.concatenate(batches, self.model.is_optimized_for_gaudi)
+                    batch = self.model.batch_type.concatenate(batches, self.model.tokenizer.pad_token_id)
             else:
                 batch = batches[0]