Use batched index_copy (#73) (#89)

Co-authored-by: madamczykhabana <110973826+madamczykhabana@users.noreply.github.com>
Author: Karol Damaszke
Date:   2024-02-29 15:45:16 +01:00 (committed by GitHub)
Parent: 8f6564ce0e
Commit: 03c2123244


@@ -54,7 +54,7 @@ def round_up(number, k):
 def to_tensor_indices(indices, device):
-    return [torch.tensor(idx, dtype=torch.int32, device=device) for idx in indices]
+    return torch.tensor(indices, dtype=torch.int32, device=device)

 def calculate_chunks(offset):
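The heart of the change: to_tensor_indices now builds one batched index tensor instead of a list of single-element tensors, so downstream index_copy_ calls run once per tensor rather than once per index. A minimal sketch of the two API shapes in plain CPU PyTorch (toy values; the real code targets HPU and uses int32 indices, whereas CPU index_copy_ expects int64):

    import torch

    # Old shape: one tensor per index -> one copy kernel launch per index.
    per_index = [torch.tensor([i], dtype=torch.int64) for i in [2, 0]]

    # New shape: a single batched index tensor -> one copy kernel in total.
    batched = torch.tensor([2, 0], dtype=torch.int64)

    dst = torch.zeros(4, 3)
    src = torch.ones(2, 3)
    dst.index_copy_(0, batched, src)  # rows 2 and 0 written in one call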
@@ -86,7 +86,7 @@ def grouped_pad(tensor_groups, dims, values):
         else:
             result = [t for t in tensors]
         grouped_result.append(result)
-        htorch.core.mark_step()
+    htorch.core.mark_step()
     return grouped_result
@@ -113,28 +113,19 @@ def grouped_shift(tensor_groups, dims, offset, merge_graphs):
     return tensor_groups

-def move(dst_tensors, dst_dim, dst_indices, src_tensors, src_dim, src_indices):
-    if dst_dim == 1 and src_dim == 0:
-        # Case 1: Only destination is merged
-        dst_tensors = dst_tensors[0]
-        dst_dim = 0
-    elif dst_dim == 0 and src_dim == 1:
-        # Case 2: Only source is merged
-        src_tensors = src_tensors[0]
-        src_dim = 0
-    else:
-        # Other cases don't need special support
-        pass
-    for dst_t, src_t in zip(dst_tensors, src_tensors):
-        for dst_idx, src_idx in zip(dst_indices, src_indices):
-            dst_t.index_copy_(dst_dim, dst_idx, torch.index_select(src_t, src_dim, src_idx))
-
-def grouped_move(dst_tensor_groups, dst_dims, dst_indices, src_tensor_groups, src_dims, src_indices):
-    for dst_tensors, dst_dim, src_tensors, src_dim in zip(dst_tensor_groups, dst_dims, src_tensor_groups, src_dims):
-        move(dst_tensors, dst_dim, dst_indices, src_tensors, src_dim, src_indices)
+def move(dst_tensors, dst_indices, src_tensors):
+    bs_dim = 0
+    num_indices = dst_indices.size(0)
+    for i, (dst_t, src_t) in enumerate(zip(dst_tensors, src_tensors)):
+        if src_t.size(bs_dim) != num_indices:
+            src_t = torch.narrow(src_t, bs_dim, 0, num_indices)
+        dst_t.index_copy_(bs_dim, dst_indices, src_t)
     htorch.core.mark_step()
-    return dst_tensor_groups
+
+def grouped_move(dst_tensor_groups, dst_indices, src_tensor_groups):
+    for dst_tensors, src_tensors in zip(dst_tensor_groups, src_tensor_groups):
+        move(dst_tensors, dst_indices, src_tensors)

 def extend_tensor(tensor, padding, dim):
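What changed in move: the old version looped over index pairs and carried dim bookkeeping for merged groups; the new one assumes everything is already split along the batch dimension (bs_dim 0) and issues one batched index_copy_ per tensor, narrowing oversized sources down to the first num_indices rows. A self-contained toy run (plain PyTorch, htorch.core.mark_step() omitted, int64 indices for CPU):

    import torch

    def move(dst_tensors, dst_indices, src_tensors):
        bs_dim = 0
        num_indices = dst_indices.size(0)
        for dst_t, src_t in zip(dst_tensors, src_tensors):
            if src_t.size(bs_dim) != num_indices:
                # Used rows are assumed to sit at the front of the source.
                src_t = torch.narrow(src_t, bs_dim, 0, num_indices)
            dst_t.index_copy_(bs_dim, dst_indices, src_t)

    dst = [torch.zeros(4, 3)]
    src = [torch.arange(9, dtype=torch.float32).reshape(3, 3)]  # 3 rows, 2 used
    move(dst, torch.tensor([3, 1]), src)  # src rows 0 and 1 land in dst rows 3 and 1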
@@ -160,6 +151,20 @@ def grouped_extend_batch(tensor_groups, target_bs, bs_dims):
     return tensor_groups

+def merge(tensor_group):
+    tensor_group = [torch.stack(tensor_group)]
+    htorch.core.mark_step()
+    return tensor_group
+
+def split(tensor_group, clone_data):
+    tensor_group = [t.squeeze(0) for t in torch.split(tensor_group[0], 1)]
+    if clone_data:
+        tensor_group = [t.clone() for t in tensor_group]
+    htorch.core.mark_step()
+    return tensor_group
+
 def remove_kv_cache_from_output(module):
     orig_fwd = module.forward
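merge and split are inverses: merge stacks a group's per-layer tensors into one tensor with a leading layer dimension, and split slices that dimension off again. The clone_data flag decides whether the split results own their storage: torch.split returns views into the merged tensor, so without cloning, a write to a split tensor also mutates the merged one. A CPU round-trip sketch with the mark_step calls dropped:

    import torch

    def merge(tensor_group):
        return [torch.stack(tensor_group)]

    def split(tensor_group, clone_data):
        group = [t.squeeze(0) for t in torch.split(tensor_group[0], 1)]
        if clone_data:
            group = [t.clone() for t in group]
        return group

    layers = [torch.randn(2, 4) for _ in range(3)]
    merged = merge(layers)
    views = split(merged, clone_data=False)   # aliases merged[0]'s storage
    copies = split(merged, clone_data=True)   # independent tensors
    assert all(torch.equal(a, b) for a, b in zip(layers, copies))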
@@ -181,13 +186,6 @@ def remove_kv_cache_from_output(module):
     return module

-def hpu_graph_fn(fn):
-    class FnModule(torch.nn.Module):
-        def forward(self, *args, **kwargs):
-            return fn(*args, **kwargs)
-    return wrap_in_hpu_graph(FnModule(), disable_tensor_cache=False)
-
 @dataclass
 class CausalLMRequest:
     idx: int
@@ -265,20 +263,18 @@ class CausalLMBatch(Batch):
         small_bs = len(self.past_key_values) > self.batch_size
         if not self.merged_kv_cache and small_bs and (pad_needed or shift_needed or expand_needed):
             past_keys, past_values = self.detach_kv_cache()
-            past_keys = [torch.stack(past_keys)]
-            past_values = [torch.stack(past_values)]
+            past_keys = merge(past_keys)
+            past_values = merge(past_values)
             self.attach_kv_cache(past_keys, past_values)
             self.merged_kv_cache = True
             htorch.core.mark_step()

-    def split_kv_cache_if_needed(self):
+    def split_kv_cache_if_needed(self, clone_data):
         if self.merged_kv_cache:
             past_keys, past_values = self.detach_kv_cache()
-            past_keys = [t.clone() for t in past_keys[0]]
-            past_values = [t.clone() for t in past_values[0]]
+            past_keys = split(past_keys, clone_data)
+            past_values = split(past_values, clone_data)
             self.attach_kv_cache(past_keys, past_values)
             self.merged_kv_cache = False
             htorch.core.mark_step()

     def get_tensor_groups(self):
         past_keys, past_values = self.detach_kv_cache()
@@ -326,23 +322,9 @@ class CausalLMBatch(Batch):
         dst_tensors, _, dst_dims = self.get_tensor_groups()
         free_indices_gen = self.free_indices_generator()
         for src_b in src_batches:
-            src_indices = to_tensor_indices(src_b.used_indices(), self.input_ids.device)
             dst_indices = to_tensor_indices(src_b.update_indices(free_indices_gen), self.input_ids.device)
             src_tensors, _, src_dims = src_b.get_tensor_groups()
-            # Instead of doing one huge grouped_move we're splitting it into 3 to improve perf and mem usage
-            # Both dst_tensors and src_tensors are lists of lists which follow this pattern:
-            # [[position_ids], [attention_mask], [position_ids], past_keys, past_values]
-            # move only past_keys
-            dst_tensors[3:4] = grouped_move(dst_tensors[3:4], dst_dims[3:4], dst_indices,
-                                            src_tensors[3:4], src_dims[3:4], src_indices)
-            # move only past_values
-            dst_tensors[4:5] = grouped_move(dst_tensors[4:5], dst_dims[4:5], dst_indices,
-                                            src_tensors[4:5], src_dims[4:5], src_indices)
-            # move only input_ids, attention_mask and position_ids
-            dst_tensors[:3] = grouped_move(dst_tensors[:3], dst_dims[:3], dst_indices,
-                                           src_tensors[:3], src_dims[:3], src_indices)
+            grouped_move(dst_tensors, dst_indices, src_tensors)
         self.set_tensor_groups(dst_tensors)

     @classmethod
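With indices batched and source narrowing handled inside move, the three-stage copy (past_keys, then past_values, then the remaining tensors) that the removed comments justified on perf and memory grounds collapses into one grouped_move over the whole group list ([input_ids], [attention_mask], [position_ids], past_keys, past_values). A toy sketch of that grouped layout, abbreviated to three groups, with move() inlined and the narrow step dropped since these shapes already match the index count:

    import torch

    def grouped_move(dst_tensor_groups, dst_indices, src_tensor_groups):
        for dst_tensors, src_tensors in zip(dst_tensor_groups, src_tensor_groups):
            for dst_t, src_t in zip(dst_tensors, src_tensors):
                dst_t.index_copy_(0, dst_indices, src_t)

    layers, seq = 2, 8
    # Abbreviated layout: [input_ids], past_keys (per layer), past_values (per layer)
    dst_groups = [[torch.zeros(4, seq)],
                  [torch.zeros(4, seq) for _ in range(layers)],
                  [torch.zeros(4, seq) for _ in range(layers)]]
    src_groups = [[torch.ones(2, seq)],
                  [torch.ones(2, seq) for _ in range(layers)],
                  [torch.ones(2, seq) for _ in range(layers)]]
    grouped_move(dst_groups, torch.tensor([2, 3]), src_groups)
    # rows 2 and 3 of every destination tensor now hold the source rows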
@@ -392,7 +374,7 @@ class CausalLMBatch(Batch):
             target_bs = new_bs if i == dst_batch_idx else batches[i].batch_size
             batches[i].merge_kv_cache_if_needed(target_bs, offsets[i])
             batches[i].realign(target_bs, offsets[i], pad_token_id)
-        batches[dst_batch_idx].split_kv_cache_if_needed()
+            batches[i].split_kv_cache_if_needed(i == dst_batch_idx)
         batches[dst_batch_idx].expand_bs(new_bs)
         batches[dst_batch_idx].move_data([batches[i] for i in range(len(batches)) if i != dst_batch_idx])
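Call-site consequence: every batch is now split inside the loop, but only the destination passes clone_data=True. The destination's per-layer tensors outlive the merge and are written into by move_data, so they must own their storage; source batches are read once and discarded, so view-only splits avoid extra copies. A small demonstration of why the written-to side needs clones:

    import torch

    merged = [torch.zeros(2, 1, 3)]  # a merged group with two "layers"
    views = [t.squeeze(0) for t in torch.split(merged[0], 1)]
    views[0].add_(1.0)                        # writing a view...
    assert merged[0][0].sum().item() == 3.0   # ...mutates the merged storage too
    clones = [v.clone() for v in views]
    clones[0].zero_()                         # clone_data=True decouples the storage
    assert merged[0][0].sum().item() == 3.0   # merged tensor unchanged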
@@ -1078,4 +1060,4 @@ class CausalLM(Model):
         for chunk in chunk_sizes:
             batch.merge_kv_cache_if_needed(batch.batch_size, chunk)
             batch.realign(batch.batch_size, chunk, 0)
-            batch.split_kv_cache_if_needed()
+            batch.split_kv_cache_if_needed(True)