From 7dbf4bf7a4498ab3262d831a82da30e23ea78465 Mon Sep 17 00:00:00 2001
From: Karol Damaszke
Date: Thu, 29 Feb 2024 10:48:54 +0100
Subject: [PATCH] Improve tensor slicing performance (#66) (#87)

Co-authored-by: mswiniarsk <156412439+mswiniarsk@users.noreply.github.com>
---
 .../models/causal_lm.py | 27 +++++++++++--------
 1 file changed, 16 insertions(+), 11 deletions(-)

diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py
index b19170cb..e0084be3 100644
--- a/server/text_generation_server/models/causal_lm.py
+++ b/server/text_generation_server/models/causal_lm.py
@@ -70,7 +70,7 @@ def calculate_chunks(offset):
 def biggest_single_chunk(offset):
     if offset != 0:
         idx = bisect.bisect(CHUNK_SIZES, abs(offset))
-        return int(math.copysign(CHUNK_SIZES[idx-1], offset))
+        return int(math.copysign(CHUNK_SIZES[idx - 1], offset))
     else:
         return 0
 
@@ -145,7 +145,7 @@ def extend_tensor(tensor, padding, dim):
 
 def extend_batch(tensors, target_bs, dim):
     diff = target_bs - tensors[0].size(dim)
-    #TODO: add support for shrinking bs
+    # TODO: add support for shrinking bs
     if diff <= 0:
         return tensors
     shape = list(tensors[0].shape)
@@ -283,7 +283,7 @@ class CausalLMBatch(Batch):
     def get_tensor_groups(self):
         past_keys, past_values = self.detach_kv_cache()
         seq_dim = -1
-        key_dim = -2 # TODO: Add case for Bloom and other models
+        key_dim = -2  # TODO: Add case for Bloom and other models
         value_dim = -2
         tensors = [[self.input_ids], [self.attention_mask], [self.position_ids], past_keys, past_values]
         # We don't need to align position_ids
@@ -335,11 +335,14 @@ class CausalLMBatch(Batch):
 
         # [[position_ids], [attention_mask], [position_ids], past_keys, past_values]
         # move only past_keys
-        dst_tensors[3:4] = grouped_move(dst_tensors[3:4], dst_dims[3:4], dst_indices, src_tensors[3:4], src_dims[3:4], src_indices)
+        dst_tensors[3:4] = grouped_move(dst_tensors[3:4], dst_dims[3:4], dst_indices,
+                                        src_tensors[3:4], src_dims[3:4], src_indices)
         # move only past_values
-        dst_tensors[4:5] = grouped_move(dst_tensors[4:5], dst_dims[4:5], dst_indices, src_tensors[4:5], src_dims[4:5], src_indices)
+        dst_tensors[4:5] = grouped_move(dst_tensors[4:5], dst_dims[4:5], dst_indices,
+                                        src_tensors[4:5], src_dims[4:5], src_indices)
         # move only input_ids, attention_mask and position_ids
-        dst_tensors[:3] = grouped_move(dst_tensors[:3], dst_dims[:3], dst_indices, src_tensors[:3], src_dims[:3], src_indices)
+        dst_tensors[:3] = grouped_move(dst_tensors[:3], dst_dims[:3], dst_indices,
+                                       src_tensors[:3], src_dims[:3], src_indices)
         self.set_tensor_groups(dst_tensors)
 
     @classmethod
@@ -792,9 +795,11 @@ class CausalLM(Model):
         if self.is_optimized_for_gaudi:
             if prefill:
                 # no right padding for prefill
-                token_idx = torch.tensor(batch.attention_mask.shape[-1] - 1).to(self.device)
+                token_idx_scalar = batch.attention_mask.shape[-1] - 1
+                token_idx = torch.tensor(token_idx_scalar).to(self.device)
             else:
-                token_idx = torch.tensor(batch.attention_mask.shape[-1] - batch.right_padding).to(self.device)
+                token_idx_scalar = batch.attention_mask.shape[-1] - batch.right_padding
+                token_idx = torch.tensor(token_idx_scalar).to(self.device)
         else:
             token_idx = None
 
@@ -802,11 +807,11 @@
         input_length = batch.input_length
         if self.is_optimized_for_gaudi and logits.shape[-2] > 1:
             next_token_ids, next_token_logprobs, logprobs = batch.next_token_chooser(
-                batch.input_ids[:, :token_idx], logits[:, input_length - 1: input_length, :].squeeze(-2)
+                batch.input_ids[:, :token_idx_scalar], logits[:, input_length - 1: input_length, :].squeeze(-2)
             )
         else:
             next_token_ids, next_token_logprobs, logprobs = batch.next_token_chooser(
-                batch.input_ids[:, :token_idx], logits.squeeze(-2)
+                batch.input_ids[:, :token_idx_scalar], logits.squeeze(-2)
             )
         batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
             batch.top_n_tokens,
@@ -853,7 +858,7 @@ class CausalLM(Model):
 
         # Update position_ids
         if prefill:
-            batch.position_ids = batch.position_ids[:, token_idx - 1: token_idx] + 1
+            batch.position_ids = torch.index_select(batch.position_ids, 1, token_idx - 1) + 1
        else:
             batch.position_ids += 1
         # Update past key values
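
For reviewers, a minimal sketch (not part of the change itself) of the two slicing
patterns this patch targets; plain CPU tensors stand in for HPU ones here, and
batch_size, seq_len, and the random input_ids are made-up illustration values:

    import torch

    batch_size, seq_len = 4, 128
    input_ids = torch.randint(0, 1000, (batch_size, seq_len))
    position_ids = torch.arange(seq_len).unsqueeze(0).expand(batch_size, -1)

    # Pattern 1: slice bounds. A tensor used as a slice bound is converted to a
    # host integer via __index__() (an implicit .item()), so when the scalar
    # already exists on the host it is cheaper to slice with it directly.
    token_idx_scalar = seq_len - 1               # host-side Python int
    token_idx = torch.tensor(token_idx_scalar)   # tensor copy of the same value
    assert torch.equal(input_ids[:, :token_idx], input_ids[:, :token_idx_scalar])

    # Pattern 2: selecting a single column. index_select accepts the index as a
    # tensor, so the value never has to round-trip through the host.
    sliced = position_ids[:, token_idx - 1: token_idx] + 1
    selected = torch.index_select(position_ids, 1, token_idx - 1) + 1
    assert torch.equal(sliced, selected)

Both replacements are value-equivalent, as the asserts show; the intent is to
avoid the device-to-host synchronization that tensor-valued slice bounds force,
which would otherwise stall each decode step on Gaudi.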