Improve tensor slicing performance (#66) (#87)

Co-authored-by: mswiniarsk <156412439+mswiniarsk@users.noreply.github.com>
Karol Damaszke 2024-02-29 10:48:54 +01:00 committed by GitHub
parent 3831f1bed5
commit 7dbf4bf7a4

@@ -70,7 +70,7 @@ def calculate_chunks(offset):
 def biggest_single_chunk(offset):
     if offset != 0:
         idx = bisect.bisect(CHUNK_SIZES, abs(offset))
-        return int(math.copysign(CHUNK_SIZES[idx-1], offset))
+        return int(math.copysign(CHUNK_SIZES[idx - 1], offset))
     else:
         return 0
@@ -145,7 +145,7 @@ def extend_tensor(tensor, padding, dim):
 def extend_batch(tensors, target_bs, dim):
     diff = target_bs - tensors[0].size(dim)
-    #TODO: add support for shrinking bs
+    # TODO: add support for shrinking bs
     if diff <= 0:
         return tensors
     shape = list(tensors[0].shape)
@@ -335,11 +335,14 @@ class CausalLMBatch(Batch):
         # [[input_ids], [attention_mask], [position_ids], past_keys, past_values]
         # move only past_keys
-        dst_tensors[3:4] = grouped_move(dst_tensors[3:4], dst_dims[3:4], dst_indices, src_tensors[3:4], src_dims[3:4], src_indices)
+        dst_tensors[3:4] = grouped_move(dst_tensors[3:4], dst_dims[3:4], dst_indices,
+                                        src_tensors[3:4], src_dims[3:4], src_indices)
         # move only past_values
-        dst_tensors[4:5] = grouped_move(dst_tensors[4:5], dst_dims[4:5], dst_indices, src_tensors[4:5], src_dims[4:5], src_indices)
+        dst_tensors[4:5] = grouped_move(dst_tensors[4:5], dst_dims[4:5], dst_indices,
+                                        src_tensors[4:5], src_dims[4:5], src_indices)
         # move only input_ids, attention_mask and position_ids
-        dst_tensors[:3] = grouped_move(dst_tensors[:3], dst_dims[:3], dst_indices, src_tensors[:3], src_dims[:3], src_indices)
+        dst_tensors[:3] = grouped_move(dst_tensors[:3], dst_dims[:3], dst_indices,
+                                       src_tensors[:3], src_dims[:3], src_indices)
         self.set_tensor_groups(dst_tensors)

     @classmethod
@@ -792,9 +795,11 @@ class CausalLM(Model):
         if self.is_optimized_for_gaudi:
             if prefill:
                 # no right padding for prefill
-                token_idx = torch.tensor(batch.attention_mask.shape[-1] - 1).to(self.device)
+                token_idx_scalar = batch.attention_mask.shape[-1] - 1
+                token_idx = torch.tensor(token_idx_scalar).to(self.device)
             else:
-                token_idx = torch.tensor(batch.attention_mask.shape[-1] - batch.right_padding).to(self.device)
+                token_idx_scalar = batch.attention_mask.shape[-1] - batch.right_padding
+                token_idx = torch.tensor(token_idx_scalar).to(self.device)
         else:
             token_idx = None
@@ -802,11 +807,11 @@ class CausalLM(Model):
         input_length = batch.input_length
         if self.is_optimized_for_gaudi and logits.shape[-2] > 1:
             next_token_ids, next_token_logprobs, logprobs = batch.next_token_chooser(
-                batch.input_ids[:, :token_idx], logits[:, input_length - 1: input_length, :].squeeze(-2)
+                batch.input_ids[:, :token_idx_scalar], logits[:, input_length - 1: input_length, :].squeeze(-2)
             )
         else:
             next_token_ids, next_token_logprobs, logprobs = batch.next_token_chooser(
-                batch.input_ids[:, :token_idx], logits.squeeze(-2)
+                batch.input_ids[:, :token_idx_scalar], logits.squeeze(-2)
             )
         batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
             batch.top_n_tokens,
@@ -853,7 +858,7 @@ class CausalLM(Model):
         # Update position_ids
         if prefill:
-            batch.position_ids = batch.position_ids[:, token_idx - 1: token_idx] + 1
+            batch.position_ids = torch.index_select(batch.position_ids, 1, token_idx - 1) + 1
         else:
             batch.position_ids += 1
         # Update past key values
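
A minimal standalone sketch of the pattern these hunks adopt (illustrative only, not code from this repository; shapes and variable names are made up): keep the sequence index as a plain Python int for host-side slicing, keep a tensor copy only where a device-side op consumes it, and replace a tensor-based slice with torch.index_select. The performance motivation is presumably, per the commit title, to avoid slicing tensors with a 0-dim index tensor on every decode step.

import torch

# Made-up shapes, for illustration only.
attention_mask = torch.ones(4, 128, dtype=torch.int64)
position_ids = torch.arange(128).unsqueeze(0).expand(4, -1)

# Keep the index as a plain Python int for slicing ...
token_idx_scalar = attention_mask.shape[-1] - 1
# ... and as a tensor only for ops that take a tensor index.
token_idx = torch.tensor(token_idx_scalar)

# Slicing with the Python int avoids indexing through a 0-dim tensor.
prompt_positions = position_ids[:, :token_idx_scalar]

# index_select keeps single-column selection a tensor op; it is equivalent to
# position_ids[:, token_idx - 1:token_idx] + 1 but takes the index as a tensor.
next_position_ids = torch.index_select(position_ids, 1, token_idx - 1) + 1

print(prompt_positions.shape, next_position_ids.shape)  # torch.Size([4, 127]) torch.Size([4, 1])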