mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-24 16:32:12 +00:00
Prefill optimization by allocating space only for the first token (#17)
This commit is contained in:
parent
0b96da89aa
commit
60f63262db
@ -114,6 +114,15 @@ def shift_all(srcs, dim, offsets):
|
||||
return [shift(src, dim, offset) for src, offset in zip(srcs, offsets)]
|
||||
|
||||
|
||||
def pad_tensors(tensors, paddings, dim, value):
|
||||
for i, (tensor, padding) in enumerate(zip(tensors, paddings)):
|
||||
if padding > 0:
|
||||
pad_shape = (0, 0, 0, padding) if dim == -2 else (0, padding)
|
||||
tensors[i] = torch.nn.functional.pad(tensor, pad_shape, value=value)
|
||||
htorch.core.mark_step()
|
||||
return tensors
|
||||
|
||||
|
||||
@dataclass
|
||||
class CausalLMRequest:
|
||||
idx: int
|
||||
@ -170,7 +179,7 @@ class CausalLMBatch(Batch):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def recombine(cls, batches: List["CausalLMBatch"], is_optimized_for_gaudi: bool = False) -> "CausalLMBatch":
|
||||
def recombine(cls, batches: List["CausalLMBatch"], pad_token_id: int) -> "CausalLMBatch":
|
||||
total_requests = sum(len(b) for b in batches)
|
||||
new_bs = round_up(total_requests, BATCH_BUCKET_SIZE)
|
||||
batch_id = batches[0].batch_id
|
||||
@ -212,10 +221,6 @@ class CausalLMBatch(Batch):
|
||||
to_tensors = lambda ind: (torch.tensor(ind[0], device=device), torch.tensor(ind[1], device=device))
|
||||
indices = [[to_tensors(req.update_idx(next(free_indices))) for req in batch_reqs] for batch_reqs in grouped_requests]
|
||||
|
||||
max_seq_len = batches[0].attention_mask.size(1)
|
||||
input_length = max_input_length
|
||||
right_padding = max_seq_len - input_length
|
||||
|
||||
chunk_size = batches[0].past_key_values[0][0].size(0) // batches[0].batch_size
|
||||
num_layers = len(batches[0].past_key_values)
|
||||
past_key_values_type = type(batches[0].past_key_values)
|
||||
@ -231,9 +236,14 @@ class CausalLMBatch(Batch):
|
||||
for b in batches:
|
||||
b.past_key_values = list(b.past_key_values)
|
||||
|
||||
# For prefill there is a space allocated only for first token
|
||||
# Need to add padding to the max total tokens before first decode
|
||||
paddings = [(batch.input_length + batch.right_padding) - batch.seq_length for batch in batches]
|
||||
|
||||
src = [b.input_ids for b in batches]
|
||||
for b in batches:
|
||||
del b.input_ids
|
||||
src = pad_tensors(src, paddings, seq_dim, pad_token_id)
|
||||
src = shift_all(src, seq_dim, offsets)
|
||||
input_ids = prepare_memory(new_bs, src[target_batch_idx], inplace)
|
||||
input_ids = move_data(input_ids, 1, indices, src)
|
||||
@ -241,6 +251,7 @@ class CausalLMBatch(Batch):
|
||||
src = [b.attention_mask for b in batches]
|
||||
for b in batches:
|
||||
del b.attention_mask
|
||||
src = pad_tensors(src, paddings, seq_dim, 0)
|
||||
src = shift_all(src, seq_dim, offsets)
|
||||
attention_mask = prepare_memory(new_bs, src[target_batch_idx], inplace)
|
||||
attention_mask = move_data(attention_mask, 1, indices, src)
|
||||
@ -255,11 +266,13 @@ class CausalLMBatch(Batch):
|
||||
past_key_values = []
|
||||
for layer_num in range(num_layers):
|
||||
src = [b.past_key_values[layer_num][0] for b in batches]
|
||||
src = pad_tensors(src, paddings, key_dim, 0)
|
||||
src = shift_all(src, key_dim, offsets)
|
||||
updated_key = prepare_memory(new_bs * chunk_size, src[target_batch_idx], inplace)
|
||||
updated_key = move_data(updated_key, chunk_size, indices, src)
|
||||
|
||||
src = [b.past_key_values[layer_num][1] for b in batches]
|
||||
src = pad_tensors(src, paddings, value_dim, 0)
|
||||
src = shift_all(src, value_dim, offsets)
|
||||
updated_value = prepare_memory(new_bs * chunk_size, src[target_batch_idx], inplace)
|
||||
updated_value = move_data(updated_value, chunk_size, indices, src)
|
||||
@ -278,6 +291,10 @@ class CausalLMBatch(Batch):
|
||||
batches[0].next_token_chooser.dtype
|
||||
)
|
||||
|
||||
max_seq_len = attention_mask.size(1)
|
||||
input_length = max_input_length
|
||||
right_padding = max_seq_len - input_length
|
||||
|
||||
htorch.core.mark_step()
|
||||
|
||||
return cls(
|
||||
@ -352,12 +369,16 @@ class CausalLMBatch(Batch):
|
||||
attention_mask = tokenized_inputs["attention_mask"]
|
||||
|
||||
if is_optimized_for_gaudi:
|
||||
# Allocate space for first token
|
||||
input_ids = torch.nn.functional.pad(
|
||||
input_ids, (0, max_new_tokens + extra_padding), value=tokenizer.pad_token_id
|
||||
input_ids, (0, 1), value=tokenizer.pad_token_id
|
||||
)
|
||||
attention_mask = torch.nn.functional.pad(
|
||||
attention_mask, (0, max_new_tokens + extra_padding), value=0)
|
||||
all_input_ids = input_ids.T.split(1, dim=1)
|
||||
attention_mask, (0, 1), value=0
|
||||
)
|
||||
all_input_ids = torch.nn.functional.pad(
|
||||
input_ids, (0, max_new_tokens + extra_padding - 1), value=tokenizer.pad_token_id
|
||||
).T.split(1, dim=1)
|
||||
else:
|
||||
all_input_ids = input_ids.clone().T.split(1, dim=1)
|
||||
|
||||
@ -386,16 +407,16 @@ class CausalLMBatch(Batch):
|
||||
)
|
||||
|
||||
@tracer.start_as_current_span("filter")
|
||||
def filter(self, request_ids: List[int], is_optimized_for_gaudi: bool = False) -> Optional["CausalLMBatch"]:
|
||||
def filter(self, request_ids: List[int], pad_token_id: int = 0) -> Optional["CausalLMBatch"]:
|
||||
dbg_trace('FILTER', f'num_reqs:{len(self.requests)} -> {len(request_ids)}')
|
||||
request_ids = set(request_ids)
|
||||
self.requests = [req for req in self.requests if req.data.id in request_ids]
|
||||
return self.__class__.recombine([self], is_optimized_for_gaudi)
|
||||
return self.__class__.recombine([self], pad_token_id)
|
||||
|
||||
@classmethod
|
||||
@tracer.start_as_current_span("concatenate")
|
||||
def concatenate(cls, batches: List["CausalLMBatch"], is_optimized_for_gaudi: bool = False) -> "CausalLMBatch":
|
||||
return cls.recombine(batches, is_optimized_for_gaudi)
|
||||
def concatenate(cls, batches: List["CausalLMBatch"], pad_token_id: int = 0) -> "CausalLMBatch":
|
||||
return cls.recombine(batches, pad_token_id)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.requests)
|
||||
@ -611,7 +632,7 @@ class CausalLM(Model):
|
||||
prefill = batch.past_key_values is None
|
||||
# Check if we need to do any bookkeeping first
|
||||
if not prefill:
|
||||
batch = batch.__class__.recombine([batch], self.is_optimized_for_gaudi)
|
||||
batch = batch.__class__.recombine([batch], self.tokenizer.pad_token_id)
|
||||
|
||||
scenario = 'PREFILL' if prefill else 'GENERATE'
|
||||
dbg_trace(scenario, f'bs:{batch.batch_size} num_reqs:{len(batch.requests)} seq_len:{batch.seq_length}')
|
||||
@ -621,16 +642,20 @@ class CausalLM(Model):
|
||||
self.hb_profer_started = False
|
||||
|
||||
if self.is_optimized_for_gaudi:
|
||||
token_idx = torch.tensor(batch.attention_mask.shape[-1] - batch.right_padding).to(self.device)
|
||||
if prefill:
|
||||
# no right padding for prefill
|
||||
token_idx = torch.tensor(batch.attention_mask.shape[-1] - 1).to(self.device)
|
||||
else:
|
||||
token_idx = torch.tensor(batch.attention_mask.shape[-1] - batch.right_padding).to(self.device)
|
||||
attention_mask = batch.attention_mask
|
||||
else:
|
||||
token_idx = None
|
||||
# slice the attention mask to the correct shape
|
||||
# TODO fix me!
|
||||
attention_mask = batch.attention_mask[:, : -batch.padding_right_offset]
|
||||
if batch.past_key_values:
|
||||
if token_idx is not None:
|
||||
input_ids = torch.index_select(batch.input_ids, 1, token_idx - 1)
|
||||
|
||||
if not prefill and token_idx is not None:
|
||||
input_ids = torch.index_select(batch.input_ids, 1, token_idx - 1)
|
||||
else:
|
||||
input_ids = batch.input_ids
|
||||
|
||||
|
@ -59,7 +59,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
|
||||
{"util": len(batch.requests)}):
|
||||
if batch is None:
|
||||
raise ValueError(f"Batch ID {request.batch_id} not found in cache.")
|
||||
filtered_batch = batch.filter(request.request_ids, self.model.is_optimized_for_gaudi)
|
||||
filtered_batch = batch.filter(request.request_ids, self.model.tokenizer.pad_token_id)
|
||||
self.cache.set(filtered_batch)
|
||||
|
||||
return generate_pb2.FilterBatchResponse(batch=filtered_batch.to_pb())
|
||||
@ -113,7 +113,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
|
||||
|
||||
if len(batches) > 1:
|
||||
with self.profiler.record_event("internal", "concatenate"):
|
||||
batch = self.model.batch_type.concatenate(batches, self.model.is_optimized_for_gaudi)
|
||||
batch = self.model.batch_type.concatenate(batches, self.model.tokenizer.pad_token_id)
|
||||
else:
|
||||
batch = batches[0]
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user