Mirror of https://github.com/huggingface/text-generation-inference.git, last synced 2025-04-21 23:12:07 +00:00.
Co-authored-by: madamczykhabana <110973826+madamczykhabana@users.noreply.github.com>
This commit is contained in:
parent
8f6564ce0e
commit
03c2123244
@ -54,7 +54,7 @@ def round_up(number, k):
|
||||
|
||||
|
||||
def to_tensor_indices(indices, device):
    """Convert grouped index lists to int32 torch tensors on *device*.

    Args:
        indices: an iterable of index groups (each group an iterable of ints)
            — presumably per-request batch positions; TODO confirm with callers.
        device: target torch device for the created tensors.

    Returns:
        A list with one int32 tensor per index group.
    """
    # NOTE(review): the original had a second, unreachable `return` after this
    # one (diff residue) — removed as dead code; literal behavior kept.
    return [torch.tensor(idx, dtype=torch.int32, device=device) for idx in indices]
|
||||
|
||||
|
||||
def calculate_chunks(offset):
|
||||
@ -86,7 +86,7 @@ def grouped_pad(tensor_groups, dims, values):
|
||||
else:
|
||||
result = [t for t in tensors]
|
||||
grouped_result.append(result)
|
||||
htorch.core.mark_step()
|
||||
htorch.core.mark_step()
|
||||
return grouped_result
|
||||
|
||||
|
||||
@ -113,28 +113,19 @@ def grouped_shift(tensor_groups, dims, offset, merge_graphs):
|
||||
return tensor_groups
|
||||
|
||||
|
||||
def move(dst_tensors, dst_dim, dst_indices, src_tensors, src_dim, src_indices):
    """In-place copy: gather rows of each source tensor at *src_indices*
    (along *src_dim*) and write them into each destination tensor at
    *dst_indices* (along *dst_dim*).

    A dim value of 1 appears to mark a "merged" group holding one stacked
    tensor; when exactly one side is merged, it is unwrapped and the copy
    falls back to dim 0.
    """
    dst_only_merged = dst_dim == 1 and src_dim == 0
    src_only_merged = dst_dim == 0 and src_dim == 1
    if dst_only_merged:
        # Unwrap the single stacked destination tensor.
        dst_tensors, dst_dim = dst_tensors[0], 0
    elif src_only_merged:
        # Unwrap the single stacked source tensor.
        src_tensors, src_dim = src_tensors[0], 0
    # Any other dim combination needs no unwrapping.
    for dst_t, src_t in zip(dst_tensors, src_tensors):
        for dst_idx, src_idx in zip(dst_indices, src_indices):
            selected = torch.index_select(src_t, src_dim, src_idx)
            dst_t.index_copy_(dst_dim, dst_idx, selected)
|
||||
|
||||
|
||||
def grouped_move(dst_tensor_groups, dst_dims, dst_indices, src_tensor_groups, src_dims, src_indices):
    """Apply move() to each paired (destination, source) tensor group,
    using the per-group dims; the same index lists are shared by all groups."""
    paired = zip(dst_tensor_groups, dst_dims, src_tensor_groups, src_dims)
    for dst_group, dst_dim, src_group, src_dim in paired:
        move(dst_group, dst_dim, dst_indices, src_group, src_dim, src_indices)
|
||||
def move(dst_tensors, dst_indices, src_tensors):
    """In-place copy of batch rows (dim 0): write the leading rows of each
    source tensor into the matching destination tensor at *dst_indices*.

    Args:
        dst_tensors: tensors to update in place.
        src_tensors: tensors to copy from, zipped pairwise with dst_tensors.
        dst_indices: 1-D index tensor; its length fixes how many rows move.
    """
    bs_dim = 0
    num_indices = dst_indices.size(0)
    # Fix: the original looped with `enumerate` but never used the index —
    # iterate the zipped pairs directly.
    for dst_t, src_t in zip(dst_tensors, src_tensors):
        # Trim over-long sources so the copy matches the index count.
        if src_t.size(bs_dim) != num_indices:
            src_t = torch.narrow(src_t, bs_dim, 0, num_indices)
        dst_t.index_copy_(bs_dim, dst_indices, src_t)
    htorch.core.mark_step()
|
||||
return dst_tensor_groups
|
||||
|
||||
|
||||
def grouped_move(dst_tensor_groups, dst_indices, src_tensor_groups):
    """Run move() over every paired (destination, source) group of tensors,
    sharing the same destination indices across all groups."""
    for dst_group, src_group in zip(dst_tensor_groups, src_tensor_groups):
        move(dst_group, dst_indices, src_group)
|
||||
|
||||
|
||||
def extend_tensor(tensor, padding, dim):
|
||||
@ -160,6 +151,20 @@ def grouped_extend_batch(tensor_groups, target_bs, bs_dims):
|
||||
return tensor_groups
|
||||
|
||||
|
||||
def merge(tensor_group):
    """Stack a list of per-request tensors into one merged tensor.

    The result is kept inside a single-element list so callers can treat
    merged and unmerged groups uniformly.
    """
    merged = torch.stack(tensor_group)
    htorch.core.mark_step()
    return [merged]
|
||||
|
||||
|
||||
def split(tensor_group, clone_data):
    """Undo merge(): slice the single stacked tensor back into a list of
    per-request tensors.

    With clone_data=True each slice is cloned, detaching its storage from
    the merged tensor.
    """
    slices = torch.split(tensor_group[0], 1)
    tensor_group = [s.squeeze(0) for s in slices]
    if clone_data:
        tensor_group = [s.clone() for s in tensor_group]
    htorch.core.mark_step()
    return tensor_group
|
||||
|
||||
|
||||
def remove_kv_cache_from_output(module):
|
||||
orig_fwd = module.forward
|
||||
|
||||
@ -181,13 +186,6 @@ def remove_kv_cache_from_output(module):
|
||||
return module
|
||||
|
||||
|
||||
def hpu_graph_fn(fn):
    """Wrap a plain callable in an nn.Module and capture it as an HPU graph."""

    class _FnWrapper(torch.nn.Module):
        # Delegate every call straight through to the wrapped function.
        def forward(self, *call_args, **call_kwargs):
            return fn(*call_args, **call_kwargs)

    return wrap_in_hpu_graph(_FnWrapper(), disable_tensor_cache=False)
|
||||
|
||||
|
||||
@dataclass
|
||||
class CausalLMRequest:
|
||||
idx: int
|
||||
@ -265,20 +263,18 @@ class CausalLMBatch(Batch):
|
||||
small_bs = len(self.past_key_values) > self.batch_size
|
||||
if not self.merged_kv_cache and small_bs and (pad_needed or shift_needed or expand_needed):
|
||||
past_keys, past_values = self.detach_kv_cache()
|
||||
past_keys = [torch.stack(past_keys)]
|
||||
past_values = [torch.stack(past_values)]
|
||||
past_keys = merge(past_keys)
|
||||
past_values = merge(past_values)
|
||||
self.attach_kv_cache(past_keys, past_values)
|
||||
self.merged_kv_cache = True
|
||||
htorch.core.mark_step()
|
||||
|
||||
def split_kv_cache_if_needed(self):
|
||||
def split_kv_cache_if_needed(self, clone_data):
|
||||
if self.merged_kv_cache:
|
||||
past_keys, past_values = self.detach_kv_cache()
|
||||
past_keys = [t.clone() for t in past_keys[0]]
|
||||
past_values = [t.clone() for t in past_values[0]]
|
||||
past_keys = split(past_keys, clone_data)
|
||||
past_values = split(past_values, clone_data)
|
||||
self.attach_kv_cache(past_keys, past_values)
|
||||
self.merged_kv_cache = False
|
||||
htorch.core.mark_step()
|
||||
|
||||
def get_tensor_groups(self):
|
||||
past_keys, past_values = self.detach_kv_cache()
|
||||
@ -326,23 +322,9 @@ class CausalLMBatch(Batch):
|
||||
dst_tensors, _, dst_dims = self.get_tensor_groups()
|
||||
free_indices_gen = self.free_indices_generator()
|
||||
for src_b in src_batches:
|
||||
src_indices = to_tensor_indices(src_b.used_indices(), self.input_ids.device)
|
||||
dst_indices = to_tensor_indices(src_b.update_indices(free_indices_gen), self.input_ids.device)
|
||||
src_tensors, _, src_dims = src_b.get_tensor_groups()
|
||||
|
||||
# Instead of doing one huge grouped_move we're splitting it into 3 to improve perf and mem usage
|
||||
# Both dst_tensors and src_tensors are lists of lists which follow this pattern:
|
||||
# [[position_ids], [attention_mask], [position_ids], past_keys, past_values]
|
||||
|
||||
# move only past_keys
|
||||
dst_tensors[3:4] = grouped_move(dst_tensors[3:4], dst_dims[3:4], dst_indices,
|
||||
src_tensors[3:4], src_dims[3:4], src_indices)
|
||||
# move only past_values
|
||||
dst_tensors[4:5] = grouped_move(dst_tensors[4:5], dst_dims[4:5], dst_indices,
|
||||
src_tensors[4:5], src_dims[4:5], src_indices)
|
||||
# move only input_ids, attention_mask and position_ids
|
||||
dst_tensors[:3] = grouped_move(dst_tensors[:3], dst_dims[:3], dst_indices,
|
||||
src_tensors[:3], src_dims[:3], src_indices)
|
||||
grouped_move(dst_tensors, dst_indices, src_tensors)
|
||||
self.set_tensor_groups(dst_tensors)
|
||||
|
||||
@classmethod
|
||||
@ -392,7 +374,7 @@ class CausalLMBatch(Batch):
|
||||
target_bs = new_bs if i == dst_batch_idx else batches[i].batch_size
|
||||
batches[i].merge_kv_cache_if_needed(target_bs, offsets[i])
|
||||
batches[i].realign(target_bs, offsets[i], pad_token_id)
|
||||
batches[dst_batch_idx].split_kv_cache_if_needed()
|
||||
batches[i].split_kv_cache_if_needed(i == dst_batch_idx)
|
||||
batches[dst_batch_idx].expand_bs(new_bs)
|
||||
batches[dst_batch_idx].move_data([batches[i] for i in range(len(batches)) if i != dst_batch_idx])
|
||||
|
||||
@ -1078,4 +1060,4 @@ class CausalLM(Model):
|
||||
for chunk in chunk_sizes:
|
||||
batch.merge_kv_cache_if_needed(batch.batch_size, chunk)
|
||||
batch.realign(batch.batch_size, chunk, 0)
|
||||
batch.split_kv_cache_if_needed()
|
||||
batch.split_kv_cache_if_needed(True)
|
||||
|
Loading…
Reference in New Issue
Block a user