mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-22 23:42:06 +00:00
Co-authored-by: madamczykhabana <110973826+madamczykhabana@users.noreply.github.com>
This commit is contained in:
parent
8f6564ce0e
commit
03c2123244
@ -54,7 +54,7 @@ def round_up(number, k):
|
|||||||
|
|
||||||
|
|
||||||
def to_tensor_indices(indices, device):
|
def to_tensor_indices(indices, device):
|
||||||
return [torch.tensor(idx, dtype=torch.int32, device=device) for idx in indices]
|
return torch.tensor(indices, dtype=torch.int32, device=device)
|
||||||
|
|
||||||
|
|
||||||
def calculate_chunks(offset):
|
def calculate_chunks(offset):
|
||||||
@ -113,28 +113,19 @@ def grouped_shift(tensor_groups, dims, offset, merge_graphs):
|
|||||||
return tensor_groups
|
return tensor_groups
|
||||||
|
|
||||||
|
|
||||||
def move(dst_tensors, dst_dim, dst_indices, src_tensors, src_dim, src_indices):
|
def move(dst_tensors, dst_indices, src_tensors):
|
||||||
if dst_dim == 1 and src_dim == 0:
|
bs_dim = 0
|
||||||
# Case 1: Only destination is merged
|
num_indices = dst_indices.size(0)
|
||||||
dst_tensors = dst_tensors[0]
|
for i, (dst_t, src_t) in enumerate(zip(dst_tensors, src_tensors)):
|
||||||
dst_dim = 0
|
if src_t.size(bs_dim) != num_indices:
|
||||||
elif dst_dim == 0 and src_dim == 1:
|
src_t = torch.narrow(src_t, bs_dim, 0, num_indices)
|
||||||
# Case 2: Only source is merged
|
dst_t.index_copy_(bs_dim, dst_indices, src_t)
|
||||||
src_tensors = src_tensors[0]
|
|
||||||
src_dim = 0
|
|
||||||
else:
|
|
||||||
# Other cases don't need special support
|
|
||||||
pass
|
|
||||||
for dst_t, src_t in zip(dst_tensors, src_tensors):
|
|
||||||
for dst_idx, src_idx in zip(dst_indices, src_indices):
|
|
||||||
dst_t.index_copy_(dst_dim, dst_idx, torch.index_select(src_t, src_dim, src_idx))
|
|
||||||
|
|
||||||
|
|
||||||
def grouped_move(dst_tensor_groups, dst_dims, dst_indices, src_tensor_groups, src_dims, src_indices):
|
|
||||||
for dst_tensors, dst_dim, src_tensors, src_dim in zip(dst_tensor_groups, dst_dims, src_tensor_groups, src_dims):
|
|
||||||
move(dst_tensors, dst_dim, dst_indices, src_tensors, src_dim, src_indices)
|
|
||||||
htorch.core.mark_step()
|
htorch.core.mark_step()
|
||||||
return dst_tensor_groups
|
|
||||||
|
|
||||||
|
def grouped_move(dst_tensor_groups, dst_indices, src_tensor_groups):
|
||||||
|
for dst_tensors, src_tensors in zip(dst_tensor_groups, src_tensor_groups):
|
||||||
|
move(dst_tensors, dst_indices, src_tensors)
|
||||||
|
|
||||||
|
|
||||||
def extend_tensor(tensor, padding, dim):
|
def extend_tensor(tensor, padding, dim):
|
||||||
@ -160,6 +151,20 @@ def grouped_extend_batch(tensor_groups, target_bs, bs_dims):
|
|||||||
return tensor_groups
|
return tensor_groups
|
||||||
|
|
||||||
|
|
||||||
|
def merge(tensor_group):
|
||||||
|
tensor_group = [torch.stack(tensor_group)]
|
||||||
|
htorch.core.mark_step()
|
||||||
|
return tensor_group
|
||||||
|
|
||||||
|
|
||||||
|
def split(tensor_group, clone_data):
|
||||||
|
tensor_group = [t.squeeze(0) for t in torch.split(tensor_group[0], 1)]
|
||||||
|
if clone_data:
|
||||||
|
tensor_group = [t.clone() for t in tensor_group]
|
||||||
|
htorch.core.mark_step()
|
||||||
|
return tensor_group
|
||||||
|
|
||||||
|
|
||||||
def remove_kv_cache_from_output(module):
|
def remove_kv_cache_from_output(module):
|
||||||
orig_fwd = module.forward
|
orig_fwd = module.forward
|
||||||
|
|
||||||
@ -181,13 +186,6 @@ def remove_kv_cache_from_output(module):
|
|||||||
return module
|
return module
|
||||||
|
|
||||||
|
|
||||||
def hpu_graph_fn(fn):
|
|
||||||
class FnModule(torch.nn.Module):
|
|
||||||
def forward(self, *args, **kwargs):
|
|
||||||
return fn(*args, **kwargs)
|
|
||||||
return wrap_in_hpu_graph(FnModule(), disable_tensor_cache=False)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class CausalLMRequest:
|
class CausalLMRequest:
|
||||||
idx: int
|
idx: int
|
||||||
@ -265,20 +263,18 @@ class CausalLMBatch(Batch):
|
|||||||
small_bs = len(self.past_key_values) > self.batch_size
|
small_bs = len(self.past_key_values) > self.batch_size
|
||||||
if not self.merged_kv_cache and small_bs and (pad_needed or shift_needed or expand_needed):
|
if not self.merged_kv_cache and small_bs and (pad_needed or shift_needed or expand_needed):
|
||||||
past_keys, past_values = self.detach_kv_cache()
|
past_keys, past_values = self.detach_kv_cache()
|
||||||
past_keys = [torch.stack(past_keys)]
|
past_keys = merge(past_keys)
|
||||||
past_values = [torch.stack(past_values)]
|
past_values = merge(past_values)
|
||||||
self.attach_kv_cache(past_keys, past_values)
|
self.attach_kv_cache(past_keys, past_values)
|
||||||
self.merged_kv_cache = True
|
self.merged_kv_cache = True
|
||||||
htorch.core.mark_step()
|
|
||||||
|
|
||||||
def split_kv_cache_if_needed(self):
|
def split_kv_cache_if_needed(self, clone_data):
|
||||||
if self.merged_kv_cache:
|
if self.merged_kv_cache:
|
||||||
past_keys, past_values = self.detach_kv_cache()
|
past_keys, past_values = self.detach_kv_cache()
|
||||||
past_keys = [t.clone() for t in past_keys[0]]
|
past_keys = split(past_keys, clone_data)
|
||||||
past_values = [t.clone() for t in past_values[0]]
|
past_values = split(past_values, clone_data)
|
||||||
self.attach_kv_cache(past_keys, past_values)
|
self.attach_kv_cache(past_keys, past_values)
|
||||||
self.merged_kv_cache = False
|
self.merged_kv_cache = False
|
||||||
htorch.core.mark_step()
|
|
||||||
|
|
||||||
def get_tensor_groups(self):
|
def get_tensor_groups(self):
|
||||||
past_keys, past_values = self.detach_kv_cache()
|
past_keys, past_values = self.detach_kv_cache()
|
||||||
@ -326,23 +322,9 @@ class CausalLMBatch(Batch):
|
|||||||
dst_tensors, _, dst_dims = self.get_tensor_groups()
|
dst_tensors, _, dst_dims = self.get_tensor_groups()
|
||||||
free_indices_gen = self.free_indices_generator()
|
free_indices_gen = self.free_indices_generator()
|
||||||
for src_b in src_batches:
|
for src_b in src_batches:
|
||||||
src_indices = to_tensor_indices(src_b.used_indices(), self.input_ids.device)
|
|
||||||
dst_indices = to_tensor_indices(src_b.update_indices(free_indices_gen), self.input_ids.device)
|
dst_indices = to_tensor_indices(src_b.update_indices(free_indices_gen), self.input_ids.device)
|
||||||
src_tensors, _, src_dims = src_b.get_tensor_groups()
|
src_tensors, _, src_dims = src_b.get_tensor_groups()
|
||||||
|
grouped_move(dst_tensors, dst_indices, src_tensors)
|
||||||
# Instead of doing one huge grouped_move we're splitting it into 3 to improve perf and mem usage
|
|
||||||
# Both dst_tensors and src_tensors are lists of lists which follow this pattern:
|
|
||||||
# [[position_ids], [attention_mask], [position_ids], past_keys, past_values]
|
|
||||||
|
|
||||||
# move only past_keys
|
|
||||||
dst_tensors[3:4] = grouped_move(dst_tensors[3:4], dst_dims[3:4], dst_indices,
|
|
||||||
src_tensors[3:4], src_dims[3:4], src_indices)
|
|
||||||
# move only past_values
|
|
||||||
dst_tensors[4:5] = grouped_move(dst_tensors[4:5], dst_dims[4:5], dst_indices,
|
|
||||||
src_tensors[4:5], src_dims[4:5], src_indices)
|
|
||||||
# move only input_ids, attention_mask and position_ids
|
|
||||||
dst_tensors[:3] = grouped_move(dst_tensors[:3], dst_dims[:3], dst_indices,
|
|
||||||
src_tensors[:3], src_dims[:3], src_indices)
|
|
||||||
self.set_tensor_groups(dst_tensors)
|
self.set_tensor_groups(dst_tensors)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -392,7 +374,7 @@ class CausalLMBatch(Batch):
|
|||||||
target_bs = new_bs if i == dst_batch_idx else batches[i].batch_size
|
target_bs = new_bs if i == dst_batch_idx else batches[i].batch_size
|
||||||
batches[i].merge_kv_cache_if_needed(target_bs, offsets[i])
|
batches[i].merge_kv_cache_if_needed(target_bs, offsets[i])
|
||||||
batches[i].realign(target_bs, offsets[i], pad_token_id)
|
batches[i].realign(target_bs, offsets[i], pad_token_id)
|
||||||
batches[dst_batch_idx].split_kv_cache_if_needed()
|
batches[i].split_kv_cache_if_needed(i == dst_batch_idx)
|
||||||
batches[dst_batch_idx].expand_bs(new_bs)
|
batches[dst_batch_idx].expand_bs(new_bs)
|
||||||
batches[dst_batch_idx].move_data([batches[i] for i in range(len(batches)) if i != dst_batch_idx])
|
batches[dst_batch_idx].move_data([batches[i] for i in range(len(batches)) if i != dst_batch_idx])
|
||||||
|
|
||||||
@ -1078,4 +1060,4 @@ class CausalLM(Model):
|
|||||||
for chunk in chunk_sizes:
|
for chunk in chunk_sizes:
|
||||||
batch.merge_kv_cache_if_needed(batch.batch_size, chunk)
|
batch.merge_kv_cache_if_needed(batch.batch_size, chunk)
|
||||||
batch.realign(batch.batch_size, chunk, 0)
|
batch.realign(batch.batch_size, chunk, 0)
|
||||||
batch.split_kv_cache_if_needed()
|
batch.split_kv_cache_if_needed(True)
|
||||||
|
Loading…
Reference in New Issue
Block a user