diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py
index e0084be3..db3264e7 100644
--- a/server/text_generation_server/models/causal_lm.py
+++ b/server/text_generation_server/models/causal_lm.py
@@ -54,7 +54,7 @@ def round_up(number, k):
 
 
 def to_tensor_indices(indices, device):
-    return [torch.tensor(idx, dtype=torch.int32, device=device) for idx in indices]
+    return torch.tensor(indices, dtype=torch.int32, device=device)
 
 
 def calculate_chunks(offset):
@@ -86,7 +86,7 @@ def grouped_pad(tensor_groups, dims, values):
         else:
             result = [t for t in tensors]
         grouped_result.append(result)
-        htorch.core.mark_step()
+    htorch.core.mark_step()
     return grouped_result
 
 
@@ -113,28 +113,19 @@ def grouped_shift(tensor_groups, dims, offset, merge_graphs):
     return tensor_groups
 
 
-def move(dst_tensors, dst_dim, dst_indices, src_tensors, src_dim, src_indices):
-    if dst_dim == 1 and src_dim == 0:
-        # Case 1: Only destination is merged
-        dst_tensors = dst_tensors[0]
-        dst_dim = 0
-    elif dst_dim == 0 and src_dim == 1:
-        # Case 2: Only source is merged
-        src_tensors = src_tensors[0]
-        src_dim = 0
-    else:
-        # Other cases don't need special support
-        pass
-    for dst_t, src_t in zip(dst_tensors, src_tensors):
-        for dst_idx, src_idx in zip(dst_indices, src_indices):
-            dst_t.index_copy_(dst_dim, dst_idx, torch.index_select(src_t, src_dim, src_idx))
-
-
-def grouped_move(dst_tensor_groups, dst_dims, dst_indices, src_tensor_groups, src_dims, src_indices):
-    for dst_tensors, dst_dim, src_tensors, src_dim in zip(dst_tensor_groups, dst_dims, src_tensor_groups, src_dims):
-        move(dst_tensors, dst_dim, dst_indices, src_tensors, src_dim, src_indices)
+def move(dst_tensors, dst_indices, src_tensors):
+    bs_dim = 0
+    num_indices = dst_indices.size(0)
+    for i, (dst_t, src_t) in enumerate(zip(dst_tensors, src_tensors)):
+        if src_t.size(bs_dim) != num_indices:
+            src_t = torch.narrow(src_t, bs_dim, 0, num_indices)
+        dst_t.index_copy_(bs_dim, dst_indices, src_t)
     htorch.core.mark_step()
-    return dst_tensor_groups
+
+
+def grouped_move(dst_tensor_groups, dst_indices, src_tensor_groups):
+    for dst_tensors, src_tensors in zip(dst_tensor_groups, src_tensor_groups):
+        move(dst_tensors, dst_indices, src_tensors)
 
 
 def extend_tensor(tensor, padding, dim):
@@ -160,6 +151,20 @@ def grouped_extend_batch(tensor_groups, target_bs, bs_dims):
     return tensor_groups
 
 
+def merge(tensor_group):
+    tensor_group = [torch.stack(tensor_group)]
+    htorch.core.mark_step()
+    return tensor_group
+
+
+def split(tensor_group, clone_data):
+    tensor_group = [t.squeeze(0) for t in torch.split(tensor_group[0], 1)]
+    if clone_data:
+        tensor_group = [t.clone() for t in tensor_group]
+    htorch.core.mark_step()
+    return tensor_group
+
+
 def remove_kv_cache_from_output(module):
     orig_fwd = module.forward
 
@@ -181,13 +186,6 @@ def remove_kv_cache_from_output(module):
     return module
 
 
-def hpu_graph_fn(fn):
-    class FnModule(torch.nn.Module):
-        def forward(self, *args, **kwargs):
-            return fn(*args, **kwargs)
-    return wrap_in_hpu_graph(FnModule(), disable_tensor_cache=False)
-
-
 @dataclass
 class CausalLMRequest:
     idx: int
@@ -265,20 +263,18 @@ class CausalLMBatch(Batch):
         small_bs = len(self.past_key_values) > self.batch_size
         if not self.merged_kv_cache and small_bs and (pad_needed or shift_needed or expand_needed):
             past_keys, past_values = self.detach_kv_cache()
-            past_keys = [torch.stack(past_keys)]
-            past_values = [torch.stack(past_values)]
+            past_keys = merge(past_keys)
+            past_values = merge(past_values)
             self.attach_kv_cache(past_keys, past_values)
             self.merged_kv_cache = True
-            htorch.core.mark_step()
 
-    def split_kv_cache_if_needed(self):
+    def split_kv_cache_if_needed(self, clone_data):
         if self.merged_kv_cache:
             past_keys, past_values = self.detach_kv_cache()
-            past_keys = [t.clone() for t in past_keys[0]]
-            past_values = [t.clone() for t in past_values[0]]
+            past_keys = split(past_keys, clone_data)
+            past_values = split(past_values, clone_data)
             self.attach_kv_cache(past_keys, past_values)
             self.merged_kv_cache = False
-            htorch.core.mark_step()
 
     def get_tensor_groups(self):
         past_keys, past_values = self.detach_kv_cache()
@@ -326,23 +322,9 @@ class CausalLMBatch(Batch):
         dst_tensors, _, dst_dims = self.get_tensor_groups()
         free_indices_gen = self.free_indices_generator()
         for src_b in src_batches:
-            src_indices = to_tensor_indices(src_b.used_indices(), self.input_ids.device)
             dst_indices = to_tensor_indices(src_b.update_indices(free_indices_gen), self.input_ids.device)
             src_tensors, _, src_dims = src_b.get_tensor_groups()
-
-            # Instead of doing one huge grouped_move we're splitting it into 3 to improve perf and mem usage
-            # Both dst_tensors and src_tensors are lists of lists which follow this pattern:
-            # [[position_ids], [attention_mask], [position_ids], past_keys, past_values]
-
-            # move only past_keys
-            dst_tensors[3:4] = grouped_move(dst_tensors[3:4], dst_dims[3:4], dst_indices,
-                                            src_tensors[3:4], src_dims[3:4], src_indices)
-            # move only past_values
-            dst_tensors[4:5] = grouped_move(dst_tensors[4:5], dst_dims[4:5], dst_indices,
-                                            src_tensors[4:5], src_dims[4:5], src_indices)
-            # move only input_ids, attention_mask and position_ids
-            dst_tensors[:3] = grouped_move(dst_tensors[:3], dst_dims[:3], dst_indices,
-                                           src_tensors[:3], src_dims[:3], src_indices)
+            grouped_move(dst_tensors, dst_indices, src_tensors)
         self.set_tensor_groups(dst_tensors)
 
     @classmethod
@@ -392,7 +374,7 @@ class CausalLMBatch(Batch):
             target_bs = new_bs if i == dst_batch_idx else batches[i].batch_size
             batches[i].merge_kv_cache_if_needed(target_bs, offsets[i])
             batches[i].realign(target_bs, offsets[i], pad_token_id)
-        batches[dst_batch_idx].split_kv_cache_if_needed()
+            batches[i].split_kv_cache_if_needed(i == dst_batch_idx)
         batches[dst_batch_idx].expand_bs(new_bs)
         batches[dst_batch_idx].move_data([batches[i] for i in range(len(batches)) if i != dst_batch_idx])
 
@@ -1078,4 +1060,4 @@ class CausalLM(Model):
         for chunk in chunk_sizes:
            batch.merge_kv_cache_if_needed(batch.batch_size, chunk)
            batch.realign(batch.batch_size, chunk, 0)
-           batch.split_kv_cache_if_needed()
+           batch.split_kv_cache_if_needed(True)
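
Reviewer note: below is a minimal CPU-only sketch of what the reworked helpers do, not the server code itself. The htorch.core.mark_step() calls are dropped, the toy shapes and two-tensor groups are assumptions, and indices use int64 here (the server builds int32 index tensors via to_tensor_indices). It illustrates the two ideas in this change: move() now always copies along batch dim 0 and narrows an oversized source instead of gathering with src_indices, and merge()/split() stack and unstack a KV-cache group, cloning on split only when clone_data is set.

```python
import torch


def move(dst_tensors, dst_indices, src_tensors):
    # Copy rows from each source tensor into the free destination slots.
    bs_dim = 0
    num_indices = dst_indices.size(0)
    for dst_t, src_t in zip(dst_tensors, src_tensors):
        # If the source batch is padded beyond the number of rows being
        # moved, trim it rather than doing an index_select with src_indices.
        if src_t.size(bs_dim) != num_indices:
            src_t = torch.narrow(src_t, bs_dim, 0, num_indices)
        dst_t.index_copy_(bs_dim, dst_indices, src_t)


def merge(tensor_group):
    # Stack a list of per-layer tensors into a single merged tensor.
    return [torch.stack(tensor_group)]


def split(tensor_group, clone_data):
    # Inverse of merge(); clone only when the caller keeps using the result.
    tensors = [t.squeeze(0) for t in torch.split(tensor_group[0], 1)]
    return [t.clone() for t in tensors] if clone_data else tensors


# Destination batch has 4 slots; the source batch holds 2 live requests
# padded out to 3 rows (hypothetical shapes, for illustration only).
dst = [torch.zeros(4, 5), torch.zeros(4, 5)]
src = [torch.ones(3, 5), torch.full((3, 5), 2.0)]
dst_indices = torch.tensor([1, 3], dtype=torch.int64)

move(dst, dst_indices, src)
assert dst[0][1].eq(1).all() and dst[0][0].eq(0).all()

# merge()/split() round-trip on a toy "KV cache" group of 3 layer tensors.
group = [torch.randn(4, 5) for _ in range(3)]
restored = split(merge(group), clone_data=True)
assert all(torch.equal(a, b) for a, b in zip(group, restored))
```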