From 2a7a967de361ca2607415e3de9ce579d26122617 Mon Sep 17 00:00:00 2001
From: Karol Damaszke
Date: Tue, 23 Jan 2024 15:19:07 +0100
Subject: [PATCH] Revert prefill optimization and fix accuracy issue in shift operation (#29)

Co-authored-by: Karol Damaszke
Co-authored-by: madamczykhabana <110973826+madamczykhabana@users.noreply.github.com>
Co-authored-by: jkaniecki <153085639+jkaniecki@users.noreply.github.com>
---
 .../models/causal_lm.py                  | 77 +++++++------
 server/text_generation_server/server.py |  4 +-
 2 files changed, 30 insertions(+), 51 deletions(-)

diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py
index 038f340e..c018e37b 100644
--- a/server/text_generation_server/models/causal_lm.py
+++ b/server/text_generation_server/models/causal_lm.py
@@ -101,9 +101,12 @@ def shift(tensor, dim, offset):
     if offset == 0 or abs(offset) > elements:
         return tensor
     htorch.core.mark_step()
+    # We generate indices from (0 - offset + elements) to (elements - offset + elements)
+    # so that next modulo operation operates on positive values
     indices = torch.arange(0, elements, dtype=torch.int32, device=tensor.device)
-    offset = torch.tensor(offset, dtype=torch.int32, device=tensor.device)
-    indices = torch.clamp(indices - offset, 0, elements - 1)
+    offset = torch.tensor(-offset + elements, dtype=torch.int32, device=tensor.device)
+    indices.add_(offset)
+    indices.remainder_(elements)
     target_shape = [1,] * len(tensor.shape)
     target_shape[dim] = elements
     indices = indices.view(target_shape).expand(shape)
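
The accuracy issue named in the subject line is in the hunk above: clamping the gather indices pinned every out-of-range position to the edge element, so a shifted sequence had its boundary value duplicated instead of rolled around. The modulo form is a true circular shift. A minimal CPU sketch of the two behaviours in plain PyTorch (shift_clamp and shift_modulo are illustrative names, not functions from this repo; the real shift() expands the indices to the tensor's shape and gathers on HPU):

    import torch

    def shift_clamp(tensor, dim, offset):
        # Old behaviour: out-of-range positions clamp to the edge element.
        n = tensor.size(dim)
        return tensor.index_select(dim, torch.clamp(torch.arange(n) - offset, 0, n - 1))

    def shift_modulo(tensor, dim, offset):
        # New behaviour: a circular roll, same result as torch.roll(tensor, offset, dim).
        n = tensor.size(dim)
        return tensor.index_select(dim, (torch.arange(n) - offset + n) % n)

    x = torch.tensor([1, 2, 3, 4])
    shift_clamp(x, 0, 2)   # tensor([1, 1, 1, 2]) -- edge value duplicated
    shift_modulo(x, 0, 2)  # tensor([3, 4, 1, 2]) -- matches torch.roll(x, 2, 0)
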
@@ -137,15 +140,6 @@ def remove_kv_cache_from_output(module):
     return module
 
 
-def pad_tensors(tensors, paddings, dim, value):
-    for i, (tensor, padding) in enumerate(zip(tensors, paddings)):
-        if padding > 0:
-            pad_shape = (0, 0, 0, padding) if dim == -2 else (0, padding)
-            tensors[i] = torch.nn.functional.pad(tensor, pad_shape, value=value)
-            htorch.core.mark_step()
-    return tensors
-
-
 @dataclass
 class CausalLMRequest:
     idx: int
@@ -202,7 +196,7 @@ class CausalLMBatch(Batch):
         )
 
     @classmethod
-    def recombine(cls, batches: List["CausalLMBatch"], pad_token_id: int) -> "CausalLMBatch":
+    def recombine(cls, batches: List["CausalLMBatch"], is_optimized_for_gaudi: bool = False) -> "CausalLMBatch":
         total_requests = sum(len(b) for b in batches)
         new_bs = round_up(total_requests, BATCH_BUCKET_SIZE)
         batch_id = batches[0].batch_id
@@ -221,7 +215,7 @@ class CausalLMBatch(Batch):
             scenario = 'CONCAT'
         elif batches[0].batch_size != new_bs:
             scenario = 'RESHAPE'
-        elif padding[0] <= 1:
+        elif padding[0] <= 0:
             scenario = 'SHIFT'
             offsets = [b.max_input_length - max_input_length for b in batches]
             max_input_length = max(b.max_input_length for b in batches)
@@ -234,7 +228,7 @@ class CausalLMBatch(Batch):
         grouped_requests = [[req for req in batch.requests] for batch in batches]
         flat_requests = list(itertools.chain(*grouped_requests))
 
-        if inplace and scenario != 'SHIFT':
+        if inplace:
             # The data is already present in the batch. No need to move it
             grouped_requests[target_batch_idx] = []
             free_indices = batches[target_batch_idx].free_indices()
@@ -244,6 +238,10 @@ class CausalLMBatch(Batch):
         to_tensors = lambda ind: (torch.tensor(ind[0], device=device), torch.tensor(ind[1], device=device))
         indices = [[to_tensors(req.update_idx(next(free_indices))) for req in batch_reqs] for batch_reqs in grouped_requests]
 
+        max_seq_len = batches[0].attention_mask.size(1)
+        input_length = max_input_length
+        right_padding = max_seq_len - input_length
+
         chunk_size = batches[0].past_key_values[0][0].size(0) // batches[0].batch_size
         num_layers = len(batches[0].past_key_values)
         past_key_values_type = type(batches[0].past_key_values)
@@ -259,14 +257,9 @@ class CausalLMBatch(Batch):
         for b in batches:
             b.past_key_values = list(b.past_key_values)
 
-        # For prefill there is a space allocated only for first token
-        # Need to add padding to the max total tokens before first decode
-        paddings = [(batch.input_length + batch.right_padding) - batch.seq_length for batch in batches]
-
         src = [b.input_ids for b in batches]
         for b in batches:
             del b.input_ids
-        src = pad_tensors(src, paddings, seq_dim, pad_token_id)
         src = shift_all(src, seq_dim, offsets)
         input_ids = prepare_memory(new_bs, src[target_batch_idx], inplace)
         input_ids = move_data(input_ids, 1, indices, src)
@@ -274,7 +267,6 @@ class CausalLMBatch(Batch):
         src = [b.attention_mask for b in batches]
         for b in batches:
             del b.attention_mask
-        src = pad_tensors(src, paddings, seq_dim, 0)
         src = shift_all(src, seq_dim, offsets)
         attention_mask = prepare_memory(new_bs, src[target_batch_idx], inplace)
         attention_mask = move_data(attention_mask, 1, indices, src)
@@ -289,13 +281,11 @@ class CausalLMBatch(Batch):
         past_key_values = []
         for layer_num in range(num_layers):
             src = [b.past_key_values[layer_num][0] for b in batches]
-            src = pad_tensors(src, paddings, key_dim, 0)
             src = shift_all(src, key_dim, offsets)
             updated_key = prepare_memory(new_bs * chunk_size, src[target_batch_idx], inplace)
             updated_key = move_data(updated_key, chunk_size, indices, src)
 
             src = [b.past_key_values[layer_num][1] for b in batches]
-            src = pad_tensors(src, paddings, value_dim, 0)
             src = shift_all(src, value_dim, offsets)
             updated_value = prepare_memory(new_bs * chunk_size, src[target_batch_idx], inplace)
             updated_value = move_data(updated_value, chunk_size, indices, src)
@@ -310,14 +300,10 @@ class CausalLMBatch(Batch):
         top_n_tokens_tensor = torch.tensor(top_n_tokens, device=device, dtype=torch.int64)
         next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
             [r.data.parameters for r in flat_requests],
-            batches[0].next_token_chooser.device,
-            batches[0].next_token_chooser.dtype
+            batches[0].next_token_chooser.dtype,
+            batches[0].next_token_chooser.device
         )
 
-        max_seq_len = attention_mask.size(1)
-        input_length = max_input_length
-        right_padding = max_seq_len - input_length
-
         htorch.core.mark_step()
 
         return cls(
@@ -392,16 +378,12 @@ class CausalLMBatch(Batch):
         attention_mask = tokenized_inputs["attention_mask"]
 
         if is_optimized_for_gaudi:
-            # Allocate space for first token
             input_ids = torch.nn.functional.pad(
-                input_ids, (0, 1), value=tokenizer.pad_token_id
+                input_ids, (0, max_new_tokens + extra_padding), value=tokenizer.pad_token_id
             )
             attention_mask = torch.nn.functional.pad(
-                attention_mask, (0, 1), value=0
-            )
-            all_input_ids = torch.nn.functional.pad(
-                input_ids, (0, max_new_tokens + extra_padding - 1), value=tokenizer.pad_token_id
-            ).T.split(1, dim=1)
+                attention_mask, (0, max_new_tokens + extra_padding), value=0)
+            all_input_ids = input_ids.T.split(1, dim=1)
         else:
             all_input_ids = input_ids.clone().T.split(1, dim=1)
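
The revert above replaces the old two-step allocation (one slot reserved at prefill, the rest padded in later by recombine) with a single allocation of the whole generation budget at tokenization time: input_ids and attention_mask immediately grow by max_new_tokens + extra_padding, which is why pad_tensors and the paddings bookkeeping could be deleted earlier in this patch. A shape sketch with made-up numbers (pad_token_id = 0 is assumed purely for illustration):

    import torch

    max_new_tokens, extra_padding = 3, 0
    input_ids = torch.tensor([[11, 12, 13]])  # prompt of length 3
    input_ids = torch.nn.functional.pad(
        input_ids, (0, max_new_tokens + extra_padding), value=0
    )
    # -> tensor([[11, 12, 13, 0, 0, 0]])
    # right_padding = max_seq_len - input_length = 6 - 3 = 3 free decode slots
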
@@ -430,16 +412,16 @@ class CausalLMBatch(Batch):
         )
 
     @tracer.start_as_current_span("filter")
-    def filter(self, request_ids: List[int], pad_token_id: int = 0) -> Optional["CausalLMBatch"]:
+    def filter(self, request_ids: List[int], is_optimized_for_gaudi: bool = False) -> Optional["CausalLMBatch"]:
         dbg_trace('FILTER', f'num_reqs:{len(self.requests)} -> {len(request_ids)}')
         request_ids = set(request_ids)
         self.requests = [req for req in self.requests if req.data.id in request_ids]
-        return self.__class__.recombine([self], pad_token_id)
+        return self
 
     @classmethod
     @tracer.start_as_current_span("concatenate")
-    def concatenate(cls, batches: List["CausalLMBatch"], pad_token_id: int = 0) -> "CausalLMBatch":
-        return cls.recombine(batches, pad_token_id)
+    def concatenate(cls, batches: List["CausalLMBatch"], is_optimized_for_gaudi: bool = False) -> "CausalLMBatch":
+        return cls.recombine(batches, is_optimized_for_gaudi)
 
     def __len__(self):
         return len(self.requests)
@@ -664,30 +646,27 @@ class CausalLM(Model):
         prefill = batch.past_key_values is None
         # Check if we need to do any bookkeeping first
         if not prefill:
-            batch = batch.__class__.recombine([batch], self.tokenizer.pad_token_id)
+            batch = batch.__class__.recombine([batch], self.is_optimized_for_gaudi)
 
         scenario = 'PREFILL' if prefill else 'GENERATE'
-        dbg_trace(scenario, f'bs:{batch.batch_size} num_reqs:{len(batch.requests)} seq_len:{batch.seq_length}')
+        dbg_trace(scenario, f'bs:{batch.batch_size} num_reqs:{len(batch.requests)} seq_len:{batch.seq_length} padding:{batch.right_padding}')
+        assert batch.right_padding > 0, 'No more room for next token!'
         self.step = self.step + 1
         if self.hb_profer_started == True and self.step > self.profiling_warmup_steps + self.profiling_steps:
             self.hb_profer.stop()
             self.hb_profer_started = False
 
         if self.is_optimized_for_gaudi:
-            if prefill:
-                # no right padding for prefill
-                token_idx = torch.tensor(batch.attention_mask.shape[-1] - 1).to(self.device)
-            else:
-                token_idx = torch.tensor(batch.attention_mask.shape[-1] - batch.right_padding).to(self.device)
+            token_idx = torch.tensor(batch.attention_mask.shape[-1] - batch.right_padding).to(self.device)
             attention_mask = batch.attention_mask
         else:
             token_idx = None
             # slice the attention mask to the correct shape
             # TODO fix me!
             attention_mask = batch.attention_mask[:, : -batch.padding_right_offset]
-
-        if not prefill and token_idx is not None:
-            input_ids = torch.index_select(batch.input_ids, 1, token_idx - 1)
+        if batch.past_key_values:
+            if token_idx is not None:
+                input_ids = torch.index_select(batch.input_ids, 1, token_idx - 1)
         else:
             input_ids = batch.input_ids
 
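
With the whole budget allocated up front, forward no longer needs a prefill special case: token_idx always points at the first free slot, attention_mask.shape[-1] - right_padding, and the new assert fails loudly if a batch arrives with no room left for the next token. Continuing the made-up buffer from the previous sketch:

    import torch

    input_ids      = torch.tensor([[11, 12, 13, 0, 0, 0]])
    attention_mask = torch.tensor([[ 1,  1,  1, 0, 0, 0]])
    right_padding = 3
    token_idx = torch.tensor(attention_mask.shape[-1] - right_padding)  # tensor(3)
    # On decode steps (past_key_values present), the model only needs the
    # most recently generated token, one position left of the write slot:
    torch.index_select(input_ids, 1, token_idx - 1)  # tensor([[13]])
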
diff --git a/server/text_generation_server/server.py b/server/text_generation_server/server.py
index e54f4610..1e17784e 100644
--- a/server/text_generation_server/server.py
+++ b/server/text_generation_server/server.py
@@ -59,7 +59,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
                                       {"util": len(batch.requests)}):
             if batch is None:
                 raise ValueError(f"Batch ID {request.batch_id} not found in cache.")
-            filtered_batch = batch.filter(request.request_ids, self.model.tokenizer.pad_token_id)
+            filtered_batch = batch.filter(request.request_ids, self.model.is_optimized_for_gaudi)
             self.cache.set(filtered_batch)
 
         return generate_pb2.FilterBatchResponse(batch=filtered_batch.to_pb())
@@ -113,7 +113,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
 
         if len(batches) > 1:
             with self.profiler.record_event("internal", "concatenate"):
-                batch = self.model.batch_type.concatenate(batches, self.model.tokenizer.pad_token_id)
+                batch = self.model.batch_type.concatenate(batches, self.model.is_optimized_for_gaudi)
         else:
             batch = batches[0]
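
Both server.py call sites change for the same reason: recombine() no longer pads, so filter() and concatenate() take is_optimized_for_gaudi instead of a pad token. Note that filter() now only drops finished requests and returns self; the actual RESHAPE/SHIFT compaction happens lazily in the recombine() call at the start of the next decode step, as wired in the causal_lm.py hunks above. A condensed sketch of the resulting flow (illustrative, not repo code):

    # FilterBatch RPC: just drop finished requests, no tensor movement yet
    filtered_batch = batch.filter(request.request_ids, self.model.is_optimized_for_gaudi)

    # Next decode step in the model: reshape/shift the batch only when needed
    if not prefill:
        batch = batch.__class__.recombine([batch], self.is_optimized_for_gaudi)
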