diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py
index 42da4d06..5cc35165 100644
--- a/server/text_generation_server/models/causal_lm.py
+++ b/server/text_generation_server/models/causal_lm.py
@@ -3,6 +3,8 @@ import tempfile
 import itertools
 import time
 import glob
+import bisect
+import math
 
 import torch
 
@@ -51,6 +53,7 @@ PAD_SEQUENCE_TO_MULTIPLE_OF = int(os.environ.get('PAD_SEQUENCE_TO_MULTIPLE_OF',
 PREFILL_BATCH_BUCKET_SIZE = int(os.environ.get('PREFILL_BATCH_BUCKET_SIZE', 4))
 DBG_TRACE_FILENAME = os.environ.get('DBG_TRACE_FILENAME')
 START_TS = None
+CHUNK_SIZES = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048]
 
 
 def count_hpu_graphs():
@@ -74,61 +77,111 @@ def round_up(number, k):
     return (number + k - 1) // k * k
 
 
-def prepare_memory(new_bs, tensor, inplace):
-    if inplace:
-        return tensor
-    else:
-        return tensor.new_empty((new_bs,) + tensor.shape[1:])
+def to_tensor_indices(indices, device):
+    return [torch.tensor(idx, dtype=torch.int32, device=device) for idx in indices]
 
 
-def move_data(dst_tensor, chunk_size, indices, src_tensors):
-    batch_dim = 0
-    bs = dst_tensor.size(batch_dim)
-    assert bs % chunk_size == 0, 'Batch dim must be divisible by chunk size!'
-    result = dst_tensor
-    if chunk_size > 1:
-        dst_tensor = dst_tensor.view(bs // chunk_size, chunk_size, *dst_tensor.shape[1:])
-    htorch.core.mark_step()
-    for ind, src_t in zip(indices, src_tensors):
-        if chunk_size > 1:
-            src_t = src_t.view(bs // chunk_size, chunk_size, *src_t.shape[1:])
-        for dst_idx, src_idx in ind:
-            src_data = torch.index_select(src_t, batch_dim, src_idx)
-            dst_tensor.index_copy_(batch_dim, dst_idx, src_data)
-            htorch.core.mark_step()
-    return result
-
-
-def generate_shift_chunks(offset):
-    chunk_sizes = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048]
+def calculate_chunks(offset):
     result = []
     while offset != 0:
         sign = 1 if offset > 0 else -1
-        best_chunk = min((abs(offset - sign * c), sign * c) for c in chunk_sizes)[1]
+        best_chunk = min((abs(offset - sign * c), sign * c) for c in CHUNK_SIZES)[1]
         result.append(best_chunk)
         offset = offset - best_chunk
     return result
 
 
-def roll(tensor, dim, chunks):
-    dbg_trace('ROLL', f'shape:{list(tensor.shape)} dim:{dim} chunks:{chunks}')
-    for c in chunks:
-        tensor = torch.roll(tensor, c, dim)
+def biggest_single_chunk(offset):
+    if offset != 0:
+        idx = bisect.bisect(CHUNK_SIZES, abs(offset))
+        return int(math.copysign(CHUNK_SIZES[idx-1], offset))
+    else:
+        return 0
+
+
+def grouped_pad(tensor_groups, dims, values):
+    grouped_result = []
+    for tensors, dim, value in zip(tensor_groups, dims, values):
+        padding = MAX_TOTAL_TOKENS - tensors[0].size(dim) if dim is not None else 0
+        if padding > 0:
+            assert dim in [-1, -2], f'Only dims -1 and -2 are supported! {dim}'
+            pad_shape = (0, 0, 0, padding) if dim == -2 else (0, padding)
+            result = [torch.nn.functional.pad(t, pad_shape, value=value) for t in tensors]
+        else:
+            result = [t for t in tensors]
+        grouped_result.append(result)
+    htorch.core.mark_step()
+    return grouped_result
+
+
+def roll(tensor, chunk, dim, merge_graphs):
+    if dim is None:
+        return tensor
+    tensor = torch.roll(tensor, chunk, dim)
+    if not merge_graphs:
         htorch.core.mark_step()
     return tensor
 
 
-def shift(tensor, dim, offset):
-    assert dim < 0, 'Only negative dims are supported'
-    if offset == 0:
-        return tensor
-    chunks = generate_shift_chunks(offset)
-    tensor = roll(tensor, dim, chunks)
-    return tensor
+def grouped_roll(tensor_groups, chunk, dims, merge_graphs):
+    tensor_groups = [[roll(t, chunk, dim, merge_graphs) for t in tensors] for tensors, dim in zip(tensor_groups, dims)]
+    if merge_graphs:
+        htorch.core.mark_step()
+    return tensor_groups
 
 
-def shift_all(srcs, dim, offsets):
-    return [shift(src, dim, offset) for src, offset in zip(srcs, offsets)]
+def grouped_shift(tensor_groups, dims, offset, merge_graphs):
+    chunks = calculate_chunks(offset)
+    for c in chunks:
+        tensor_groups = grouped_roll(tensor_groups, c, dims, merge_graphs)
+    return tensor_groups
+
+
+def move(dst_tensors, dst_dim, dst_indices, src_tensors, src_dim, src_indices):
+    if dst_dim == 1 and src_dim == 0:
+        # Case 1: Only destination is merged
+        dst_tensors = dst_tensors[0]
+        dst_dim = 0
+    elif dst_dim == 0 and src_dim == 1:
+        # Case 2: Only source is merged
+        src_tensors = src_tensors[0]
+        src_dim = 0
+    else:
+        # Other cases don't need special support
+        pass
+    for dst_t, src_t in zip(dst_tensors, src_tensors):
+        for dst_idx, src_idx in zip(dst_indices, src_indices):
+            dst_t.index_copy_(dst_dim, dst_idx, torch.index_select(src_t, src_dim, src_idx))
+
+
+def grouped_move(dst_tensor_groups, dst_dims, dst_indices, src_tensor_groups, src_dims, src_indices):
+    for dst_tensors, dst_dim, src_tensors, src_dim in zip(dst_tensor_groups, dst_dims, src_tensor_groups, src_dims):
+        move(dst_tensors, dst_dim, dst_indices, src_tensors, src_dim, src_indices)
+    htorch.core.mark_step()
+    return dst_tensor_groups
+
+
+def extend_tensor(tensor, padding, dim):
+    result = torch.cat([tensor, padding], dim=dim)
+    htorch.core.mark_step()
+    return result
+
+
+def extend_batch(tensors, target_bs, dim):
+    diff = target_bs - tensors[0].size(dim)
+    #TODO: add support for shrinking bs
+    if diff <= 0:
+        return tensors
+    shape = list(tensors[0].shape)
+    shape[dim] = diff
+    padding = torch.empty(shape, device=tensors[0].device, dtype=tensors[0].dtype)
+    tensors = [extend_tensor(t, padding, dim) for t in tensors]
+    return tensors
+
+
+def grouped_extend_batch(tensor_groups, target_bs, bs_dims):
+    tensor_groups = [extend_batch(tensors, target_bs, dim) for tensors, dim in zip(tensor_groups, bs_dims)]
+    return tensor_groups
 
 
 def remove_kv_cache_from_output(module):
@@ -152,13 +205,11 @@ def remove_kv_cache_from_output(module):
     return module
 
 
-def pad_tensors(tensors, paddings, dim, value):
-    for i, (tensor, padding) in enumerate(zip(tensors, paddings)):
-        if padding > 0:
-            pad_shape = (0, 0, 0, padding) if dim == -2 else (0, padding)
-            tensors[i] = torch.nn.functional.pad(tensor, pad_shape, value=value)
-            htorch.core.mark_step()
-    return tensors
+def hpu_graph_fn(fn):
+    class FnModule(torch.nn.Module):
+        def forward(self, *args, **kwargs):
+            return fn(*args, **kwargs)
+    return wrap_in_hpu_graph(FnModule(), disable_tensor_cache=False)
 
 
 @dataclass
@@ -199,6 +250,7 @@ class CausalLMBatch(Batch):
     attention_mask: torch.Tensor
     position_ids: torch.Tensor
     past_key_values: Optional[List[Tuple]]
+    merged_kv_cache: bool
 
     # Generation helpers
     next_token_chooser: HeterogeneousNextTokenChooser
@@ -218,6 +270,103 @@ class CausalLMBatch(Batch):
             max_tokens=self.max_tokens,
         )
 
+    def detach_kv_cache(self):
+        past_keys = [past[0] for past in self.past_key_values]
+        past_values = [past[1] for past in self.past_key_values]
+        del self.past_key_values
+        return past_keys, past_values
+
+    def attach_kv_cache(self, past_keys, past_values):
+        # TODO: Add support for models that don't store kv_cache in a list
+        self.past_key_values = list(zip(past_keys, past_values))
+
+    def merge_kv_cache_if_needed(self, target_bs, offset):
+        pad_needed = self.seq_length < MAX_TOTAL_TOKENS
+        shift_needed = offset != 0
+        expand_needed = target_bs > self.batch_size
+        # Very simple heuristic to determine whether we should merge tensors
+        # this needs tuning for other models/scenarios
+        small_bs = len(self.past_key_values) > self.batch_size
+        if not self.merged_kv_cache and small_bs and (pad_needed or shift_needed or expand_needed):
+            past_keys, past_values = self.detach_kv_cache()
+            past_keys = [torch.stack(past_keys)]
+            past_values = [torch.stack(past_values)]
+            self.attach_kv_cache(past_keys, past_values)
+            self.merged_kv_cache = True
+            htorch.core.mark_step()
+
+    def split_kv_cache_if_needed(self):
+        if self.merged_kv_cache:
+            past_keys, past_values = self.detach_kv_cache()
+            past_keys = [t.clone() for t in past_keys[0]]
+            past_values = [t.clone() for t in past_values[0]]
+            self.attach_kv_cache(past_keys, past_values)
+            self.merged_kv_cache = False
+            htorch.core.mark_step()
+
+    def get_tensor_groups(self):
+        past_keys, past_values = self.detach_kv_cache()
+        seq_dim = -1
+        key_dim = -2  # TODO: Add case for Bloom and other models
+        value_dim = -2
+        tensors = [[self.input_ids], [self.attention_mask], [self.position_ids], past_keys, past_values]
+        # We don't need to align position_ids
+        seq_dims = [seq_dim, seq_dim, None, key_dim, value_dim]
+        bs_dims = [0, 0, 0] + ([1, 1] if self.merged_kv_cache else [0, 0])
+        return tensors, seq_dims, bs_dims
+
+    def set_tensor_groups(self, tensors):
+        self.input_ids = tensors.pop(0)[0]
+        self.attention_mask = tensors.pop(0)[0]
+        self.position_ids = tensors.pop(0)[0]
+        past_keys = tensors.pop(0)
+        past_values = tensors.pop(0)
+        self.attach_kv_cache(past_keys, past_values)
+
+    def realign(self, target_bs, offset, pad_token_id):
+        tensors, seq_dims, _ = self.get_tensor_groups()
+        tensors = grouped_pad(tensors, seq_dims, [pad_token_id, 0, 0, 0, 0])
+        tensors = grouped_shift(tensors, seq_dims, offset, self.merged_kv_cache)
+        self.set_tensor_groups(tensors)
+
+    def expand_bs(self, target_bs):
+        tensors, _, bs_dims = self.get_tensor_groups()
+        tensors = grouped_extend_batch(tensors, target_bs, bs_dims)
+        self.set_tensor_groups(tensors)
+
+    def used_indices(self):
+        return [req.idx for req in self.requests]
+
+    def update_indices(self, new_indices):
+        for req, new_idx in zip(self.requests, new_indices):
+            req.idx = new_idx
+        return self.used_indices()
+
+    def free_indices_generator(self):
+        used = set(req.idx for req in self.requests)
+        return (i for i in range(self.batch_size) if i not in used)
+
+    def move_data(self, src_batches):
+        dst_tensors, _, dst_dims = self.get_tensor_groups()
+        free_indices_gen = self.free_indices_generator()
+        for src_b in src_batches:
+            src_indices = to_tensor_indices(src_b.used_indices(), self.input_ids.device)
+            dst_indices = to_tensor_indices(src_b.update_indices(free_indices_gen), self.input_ids.device)
+            src_tensors, _, src_dims = src_b.get_tensor_groups()
+
+            # Instead of doing one huge grouped_move we're splitting it into 3 to improve perf and mem usage
+            # Both dst_tensors and src_tensors are lists of lists which follow this pattern:
+            # [[input_ids], [attention_mask], [position_ids], past_keys, past_values]
+
+            # move only past_keys
+            dst_tensors[3:4] = grouped_move(dst_tensors[3:4], dst_dims[3:4], dst_indices, src_tensors[3:4], src_dims[3:4], src_indices)
+            # move only past_values
+            dst_tensors[4:5] = grouped_move(dst_tensors[4:5], dst_dims[4:5], dst_indices, src_tensors[4:5], src_dims[4:5], src_indices)
+            # move only input_ids, attention_mask and position_ids
+            dst_tensors[:3] = grouped_move(dst_tensors[:3], dst_dims[:3], dst_indices, src_tensors[:3], src_dims[:3], src_indices)
+        self.set_tensor_groups(dst_tensors)
+
+
     @classmethod
     def recombine(cls, batches: List["CausalLMBatch"], pad_token_id: int) -> "CausalLMBatch":
         total_requests = sum(len(b) for b in batches)
@@ -228,127 +377,59 @@ class CausalLMBatch(Batch):
         input_lengths = [b.input_length for b in batches]
         max_input_length = max(input_lengths)
         offsets = [max_input_length - b.input_length for b in batches]
-        padding = [b.right_padding for b in batches]
+        cur_padding = [b.right_padding for b in batches]
         # For prefill there is a space allocated only for first token
         # Need to add padding to the max total tokens before first decode
-        extra_padding = [MAX_TOTAL_TOKENS - b.seq_length for b in batches]
 
         moves_needed = [total_requests - len(b) if b.batch_size == new_bs else total_requests for b in batches]
-        target_batch_idx = min(enumerate(moves_needed), key=lambda idx_val: idx_val[1])[0]
+        dst_batch_idx = min(enumerate(moves_needed), key=lambda idx_val: idx_val[1])[0]
+        reshape = (batches[dst_batch_idx].batch_size != new_bs)
 
         # TODO: Add support for changing max seq len, i.e. due to output length bucketing
         # FIXME: max_seq_len for non optimized code
         if len(batches) > 1:
             scenario = 'CONCAT'
-        elif batches[target_batch_idx].batch_size != new_bs:
+        elif reshape:
             scenario = 'RESHAPE'
-        elif padding[target_batch_idx] <= 0:
+        elif cur_padding[dst_batch_idx] <= 0:
             scenario = 'SHIFT'
-            offsets = [b.max_input_length - max_input_length for b in batches]
-            max_input_length = max(b.max_input_length for b in batches)
+            offsets = [biggest_single_chunk(b.max_input_length - max_input_length) for b in batches]
+            max_input_length = max_input_length + offsets[dst_batch_idx]
         else:
             # Nothing to do
             return batches[0]
 
-        inplace = (batches[target_batch_idx].batch_size == new_bs)
-
         dbg_trace(
             scenario, f'bs:{[b.batch_size for b in batches]}->{new_bs}'
                       f' reqs:{[len(b) for b in batches]}'
                       f' offsets:{offsets}'
                       f' input_lengths:{input_lengths}'
-                      f' cur_padding:{padding}'
-                      f' inplace:{inplace}')
+                      f' cur_padding:{cur_padding}'
+                      f' dst_batch:{dst_batch_idx}')
 
         grouped_requests = [[req for req in batch.requests] for batch in batches]
         flat_requests = list(itertools.chain(*grouped_requests))
 
-        if inplace:
-            # The data is already present in the batch. No need to move it
-            grouped_requests[target_batch_idx] = []
-            free_indices = batches[target_batch_idx].free_indices()
-        else:
-            free_indices = itertools.count(0)
-
-        def to_tensors(ind): return (torch.tensor(ind[0], device=device), torch.tensor(ind[1], device=device))
-        indices = [[to_tensors(req.update_idx(next(free_indices))) for req in batch_reqs]
-                   for batch_reqs in grouped_requests]
-
-        chunk_size = batches[0].past_key_values[0][0].size(0) // batches[0].batch_size
-        num_layers = len(batches[0].past_key_values)
-        past_key_values_type = type(batches[0].past_key_values)
-
-        seq_dim = -1
-        if batches[0].past_key_values[0][0].size(-1) != batches[0].past_key_values[0][1].size(-1):
-            # Case for Bloom
-            key_dim = -1
-        else:
-            key_dim = -2
-        value_dim = -2
-
-        for b in batches:
-            b.past_key_values = list(b.past_key_values)
-
-        src = [b.input_ids for b in batches]
-        for b in batches:
-            del b.input_ids
-        src = pad_tensors(src, extra_padding, seq_dim, pad_token_id)
-        src = shift_all(src, seq_dim, offsets)
-        input_ids = prepare_memory(new_bs, src[target_batch_idx], inplace)
-        input_ids = move_data(input_ids, 1, indices, src)
-
-        src = [b.attention_mask for b in batches]
-        for b in batches:
-            del b.attention_mask
-        src = pad_tensors(src, extra_padding, seq_dim, 0)
-        src = shift_all(src, seq_dim, offsets)
-        attention_mask = prepare_memory(new_bs, src[target_batch_idx], inplace)
-        attention_mask = move_data(attention_mask, 1, indices, src)
-
-        src = [b.position_ids for b in batches]
-        for b in batches:
-            del b.position_ids
-        position_ids = prepare_memory(new_bs, src[target_batch_idx], inplace)
-        position_ids = move_data(position_ids, 1, indices, src)
-
-        src = None
-        src_keys = [[b.past_key_values[layer_num][0] for layer_num in range(num_layers)] for b in batches]
-        src_values = [[b.past_key_values[layer_num][1] for layer_num in range(num_layers)] for b in batches]
-        for b in batches:
-            del b.past_key_values
-
-        src_keys = [torch.stack(src) for src in src_keys]
-        htorch.core.mark_step()
-        src_keys = pad_tensors(src_keys, extra_padding, key_dim, 0)
-        src_keys = shift_all(src_keys, key_dim, offsets)
-        src_keys = [[t.squeeze(0).clone() for t in torch.split(src, 1)] for src in src_keys]
-        htorch.core.mark_step()
-
-        dst_keys = [prepare_memory(new_bs * chunk_size, prev, inplace) for prev in src_keys[target_batch_idx]]
-        dst_keys = [move_data(dst_keys[layer_num], chunk_size, indices, [src[layer_num]
-                              for src in src_keys]) for layer_num in range(num_layers)]
-
-        src_values = [torch.stack(src) for src in src_values]
-        htorch.core.mark_step()
-        src_values = pad_tensors(src_values, extra_padding, value_dim, 0)
-        src_values = shift_all(src_values, value_dim, offsets)
-        src_values = [[t.squeeze(0).clone() for t in torch.split(src, 1)] for src in src_values]
-        htorch.core.mark_step()
-
-        dst_values = [prepare_memory(new_bs * chunk_size, prev, inplace) for prev in src_values[target_batch_idx]]
-        dst_values = [move_data(dst_values[layer_num], chunk_size, indices, [src[layer_num]
-                                for src in src_values]) for layer_num in range(num_layers)]
-
-        past_key_values = past_key_values_type(zip(dst_keys, dst_values))
+        for i in range(len(batches)):
+            target_bs = new_bs if i == dst_batch_idx else batches[i].batch_size
+            batches[i].merge_kv_cache_if_needed(target_bs, offsets[i])
+            batches[i].realign(target_bs, offsets[i], pad_token_id)
+        batches[dst_batch_idx].split_kv_cache_if_needed()
+        batches[dst_batch_idx].expand_bs(new_bs)
+        batches[dst_batch_idx].move_data([batches[i] for i in range(len(batches)) if i != dst_batch_idx])
 
         top_n_tokens = [r.data.top_n_tokens for r in flat_requests]
         top_n_tokens_tensor = torch.tensor(top_n_tokens, device=device, dtype=torch.int64)
 
         next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
             [r.data.parameters for r in flat_requests],
-            batches[0].next_token_chooser.dtype,
-            batches[0].next_token_chooser.device
+            batches[dst_batch_idx].next_token_chooser.dtype,
+            batches[dst_batch_idx].next_token_chooser.device
         )
 
-        max_seq_len = attention_mask.size(1)
+        input_ids = batches[dst_batch_idx].input_ids
+        attention_mask = batches[dst_batch_idx].attention_mask
+        position_ids = batches[dst_batch_idx].position_ids
+        past_key_values = batches[dst_batch_idx].past_key_values
         input_length = max_input_length
 
         htorch.core.mark_step()
 
@@ -360,6 +441,7 @@ class CausalLMBatch(Batch):
             attention_mask=attention_mask,
             position_ids=position_ids,
             past_key_values=past_key_values,
+            merged_kv_cache=False,
             next_token_chooser=next_token_chooser,
             top_n_tokens=top_n_tokens,
             top_n_tokens_tensor=top_n_tokens_tensor,
@@ -448,6 +530,7 @@ class CausalLMBatch(Batch):
             attention_mask=attention_mask,
             position_ids=position_ids,
             past_key_values=None,
+            merged_kv_cache=False,
             next_token_chooser=next_token_chooser,
             top_n_tokens=top_n_tokens,
             top_n_tokens_tensor=top_n_tokens_tensor,
@@ -491,13 +574,6 @@ class CausalLMBatch(Batch):
         max_total_tokens = self.attention_mask.size(1)
         return len(self.requests) * max_total_tokens
 
-    def free_indices(self):
-        used = set(req.idx for req in self.requests)
-        for i in range(self.batch_size):
-            if i in used:
-                continue
-            yield i
-
 
 class CausalLM(Model):
     def __init__(
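
Note (reviewer illustration, not part of the patch): the shift logic introduced above decomposes an arbitrary sequence-length offset into a short list of bucketed torch.roll offsets, presumably so that only a small, fixed set of roll sizes ever needs to be compiled as HPU graphs. The standalone sketch below mirrors calculate_chunks() and biggest_single_chunk() from the diff; the values in the asserts are hand-computed assumptions about the intended behaviour, not outputs taken from the repository's tests.

import bisect
import math

CHUNK_SIZES = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048]


def calculate_chunks(offset):
    # Greedily pick the signed chunk that brings the remaining offset closest to zero.
    result = []
    while offset != 0:
        sign = 1 if offset > 0 else -1
        best_chunk = min((abs(offset - sign * c), sign * c) for c in CHUNK_SIZES)[1]
        result.append(best_chunk)
        offset -= best_chunk
    return result


def biggest_single_chunk(offset):
    # Largest bucketed chunk not exceeding |offset|, carrying the sign of the offset.
    if offset == 0:
        return 0
    idx = bisect.bisect(CHUNK_SIZES, abs(offset))
    return int(math.copysign(CHUNK_SIZES[idx - 1], offset))


# Hand-computed expectations (assumptions, for illustration only):
assert calculate_chunks(10) == [8, 2]      # a shift by 10 becomes two torch.roll calls
assert calculate_chunks(-5) == [-4, -1]    # negative offsets roll the other way
assert biggest_single_chunk(100) == 64     # the SHIFT scenario applies at most one chunk per recombine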