From 1be2d9a8ec1835fcb36fd5b67ecced58cda7b991 Mon Sep 17 00:00:00 2001 From: Karol Damaszke Date: Fri, 22 Dec 2023 21:53:01 +0100 Subject: [PATCH] Batch size bucketing (#5) --- README.md | 1 + .../models/causal_lm.py | 754 +++++++----------- 2 files changed, 277 insertions(+), 478 deletions(-) diff --git a/README.md b/README.md index 480cc9a7..d033ffcd 100644 --- a/README.md +++ b/README.md @@ -74,6 +74,7 @@ Environment Variables Added: | PROF_STEP | interger | 5 | Control profile step | add -e in docker run command | | PROF_PATH | string | /root/text-generation-inference | Define profile folder | add -e in docker run command | | LIMIT_HPU_GRAPH | True/False | False | Skip HPU graph usage for prefill to save memory | add -e in docker run command | +| BATCH_BUCKET_SIZE | integer | 8 | Batch size will be rounded to the nearest multiple of this number. This limits the number of cached graphs | add -e in docker run command | diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py index 26be1875..8630bbd1 100644 --- a/server/text_generation_server/models/causal_lm.py +++ b/server/text_generation_server/models/causal_lm.py @@ -1,5 +1,6 @@ import os import tempfile +import itertools from text_generation_server.utils.tokens import batch_top_tokens import torch @@ -34,11 +35,98 @@ from loguru import logger tracer = trace.get_tracer(__name__) +BATCH_BUCKET_SIZE = int(os.environ.get('BATCH_BUCKET_SIZE', 8)) +TRACE_FILENAME = os.environ.get('TRACE_FILENAME') + +def trace(txt): + if TRACE_FILENAME is not None: + print(txt, flush=True, file=open(TRACE_FILENAME, 'a')) + + +def round_up(number, k): + return (number + k - 1) // k * k + + +def batch_alloc(new_bs, tensor): + return tensor.new_empty((new_bs,) + tensor.shape[1:]) + + +def to_tensors(indices, device): + def convert(idx): + return torch.tensor(idx, device=device) + return [[(convert(dst), convert(src)) for dst, src in batch_ind] for batch_ind in indices] + + +def move_data(dst_tensor, chunk_size, indices, src_tensors): + batch_dim = 0 + bs = dst_tensor.size(batch_dim) + assert bs % chunk_size == 0, 'Batch dim must be divisible by chunk size!' 
+ result = dst_tensor + if chunk_size > 1: + dst_tensor = dst_tensor.view(bs // chunk_size, chunk_size, *dst_tensor.shape[1:]) + htorch.core.mark_step() + for ind, src_t in zip(indices, src_tensors): + if chunk_size > 1: + src_t = src_t.view(bs // chunk_size, chunk_size, *src_t.shape[1:]) + for dst_idx, src_idx in ind: + src_data = torch.index_select(src_t, batch_dim, src_idx) + dst_tensor.index_copy_(batch_dim, dst_idx, src_data) + htorch.core.mark_step() + return result + + +def shift(tensor, dim, offset): + shape = tensor.shape + elements = shape[dim] + if offset == 0 or abs(offset) > elements: + return tensor + htorch.core.mark_step() + indices = torch.arange(0, elements, dtype=torch.int32, device=tensor.device) + offset = torch.tensor(offset, dtype=torch.int32, device=tensor.device) + indices = torch.clamp(indices - offset, 0, elements - 1) + target_shape = [1,] * len(tensor.shape) + target_shape[dim] = elements + indices = indices.view(target_shape).expand(shape) + result = torch.gather(tensor, dim, indices) + htorch.core.mark_step() + return result + + +def shift_all(srcs, dim, offsets): + return [shift(src, dim, offset) for src, offset in zip(srcs, offsets)] + + +@dataclass +class CausalLMRequest: + idx: int + data: generate_pb2.Request + input_length: int + prefix_offset: int + read_offset: int + stopping_criteria: StoppingCriteria + + all_input_ids: torch.Tensor + + @classmethod + def from_pb(cls, idx: int, data: generate_pb2.Request, tokenizer: PreTrainedTokenizerBase): + return cls( + idx=idx, + data=data, + input_length=None, + prefix_offset=None, + read_offset=None, + stopping_criteria=StoppingCriteria.from_pb(data.stopping_parameters, tokenizer), + all_input_ids=None,) + + def update_idx(self, new_idx): + prev = self.idx + self.idx = new_idx + return (new_idx, prev) + @dataclass class CausalLMBatch(Batch): batch_id: int - requests: List[generate_pb2.Request] - requests_idx_mapping: Dict[int, int] + requests: List[CausalLMRequest] # Decoder values input_ids: torch.Tensor @@ -46,38 +134,126 @@ class CausalLMBatch(Batch): position_ids: torch.Tensor past_key_values: Optional[List[Tuple]] - # All tokens - all_input_ids: List[torch.Tensor] - - # Lengths of all generations present in the batch - input_lengths: List[int] - prefix_offsets: List[int] - read_offsets: List[int] - # Generation helpers next_token_chooser: HeterogeneousNextTokenChooser - stopping_criterias: List[StoppingCriteria] top_n_tokens: List[int] top_n_tokens_tensor: torch.Tensor - # Metadata used for padding - max_input_length: int - padding_right_offset: int - # Maximum number of tokens this batch will grow to max_tokens: int - # Past metadata - keys_head_dim_last: bool = True + input_length: int + right_padding: int def to_pb(self) -> generate_pb2.CachedBatch: return generate_pb2.CachedBatch( id=self.batch_id, - request_ids=[r.id for r in self.requests], + request_ids=[r.data.id for r in self.requests], size=len(self), max_tokens=self.max_tokens, ) + @classmethod + def recombine(cls, batches: List["CausalLMBatch"], req_ids: List[List[int]], is_optimized_for_gaudi: bool = False) -> "CausalLMBatch": + new_bs = round_up(sum([len(reqs) for reqs in req_ids]), BATCH_BUCKET_SIZE) + batch_id = batches[0].batch_id + device = batches[0].input_ids.device + + # TODO: for now use consecutive indices. 
This could be optimized to reuse existing batch memory and only overwrite + # indices that are no longer used instead of allocating new memory + free_indices = itertools.count(0) + to_tensors = lambda ind: (torch.tensor(ind[0], device=device), torch.tensor(ind[1], device=device)) + requests = [[req for req in batch.requests if req.data.id in ids] for batch, ids in zip(batches, req_ids)] + indices = [[to_tensors(req.update_idx(next(free_indices))) for req in batch_reqs] for batch_reqs in requests] + requests = list(itertools.chain(*requests)) + + # TODO: Add support for changing max seq len, i.e. due to output length bucketing + # FIXME: max_seq_len for non optimized code + max_input_length = max(req.input_length for req in requests) + offsets = [(max_input_length - b.input_length) for b in batches] + trace(f'RECOMBINE: bs:{new_bs} requests: {len(requests)} offsets: {offsets}') + + max_seq_len = batches[0].attention_mask.size(1) + input_length = max(r.input_length for r in requests) + right_padding = max_seq_len - input_length + max_tokens = len(requests) * max_seq_len + + chunk_size = batches[0].past_key_values[0][0].size(0) // batches[0].input_ids.size(0) + num_layers = len(batches[0].past_key_values) + past_key_values_type = type(batches[0].past_key_values) + + seq_dim = 1 + if batches[0].past_key_values[0][0].size(-1) != batches[0].past_key_values[0][1].size(-1): + # Case for Bloom + key_dim = -1 + else: + key_dim = -2 + value_dim = -2 + + for b in batches: + b.past_key_values = list(b.past_key_values) + + src = [b.input_ids for b in batches] + for b in batches: + del b.input_ids + src = shift_all(src, seq_dim, offsets) + input_ids = batch_alloc(new_bs, src[0]) + input_ids = move_data(input_ids, 1, indices, src) + + src = [b.attention_mask for b in batches] + for b in batches: + del b.attention_mask + src = shift_all(src, seq_dim, offsets) + attention_mask = batch_alloc(new_bs, src[0]) + attention_mask = move_data(attention_mask, 1, indices, src) + + src = [b.position_ids for b in batches] + for b in batches: + del b.position_ids + src = shift_all(src, seq_dim, offsets) + position_ids = batch_alloc(new_bs, src[0]) + position_ids = move_data(position_ids, 1, indices, src) + + past_key_values = [] + for layer_num in range(num_layers): + src = [b.past_key_values[layer_num][0] for b in batches] + src = shift_all(src, key_dim, offsets) + updated_key = batch_alloc(new_bs * chunk_size, src[0]) + updated_key = move_data(updated_key, chunk_size, indices, src) + + src = [b.past_key_values[layer_num][1] for b in batches] + src = shift_all(src, value_dim, offsets) + updated_value = batch_alloc(new_bs * chunk_size, src[0]) + updated_value = move_data(updated_value, chunk_size, indices, src) + + past_key_values.append((updated_key, updated_value)) + for b in batches: + b.past_key_values[layer_num] = None + + past_key_values = past_key_values_type(past_key_values) + + top_n_tokens = [r.data.top_n_tokens for r in requests] + top_n_tokens_tensor = torch.tensor(top_n_tokens, device=device, dtype=torch.int64) + next_token_chooser = HeterogeneousNextTokenChooser.from_pb([r.data.parameters for r in requests], batches[0].next_token_chooser.device, batches[0].next_token_chooser.dtype) + + htorch.core.mark_step() + + return cls( + batch_id=batch_id, + requests=requests, + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + next_token_chooser=next_token_chooser, + top_n_tokens=top_n_tokens, + top_n_tokens_tensor=top_n_tokens_tensor, + 
max_tokens=max_tokens, + input_length=input_length, + right_padding=right_padding + ) + + @classmethod def from_pb( cls, @@ -87,19 +263,16 @@ class CausalLMBatch(Batch): device: torch.device, is_optimized_for_gaudi: bool = False, ) -> "CausalLMBatch": - inputs = [] - next_token_chooser_parameters = [] - stopping_criterias = [] - top_n_tokens = [] - prefix_offsets = [] - read_offsets = [] - requests_idx_mapping = {} - input_lengths = [] + trace(f'NEW BATCH: ({len(pb.requests)}){[req.id for req in pb.requests]}') + requests = [CausalLMRequest.from_pb(idx, req, tokenizer) for idx, req in enumerate(pb.requests)] - # Parse batch - max_truncation = 0 - padding_right_offset = 0 - max_decode_tokens = 0 + max_input_length = max(r.data.truncate for r in requests) + max_new_tokens = max(r.stopping_criteria.max_new_tokens for r in requests) + + # TODO: Add support for sparse batches + top_n_tokens = [r.top_n_tokens for r in pb.requests] + top_n_tokens_tensor = torch.tensor(top_n_tokens, device=device, dtype=torch.int64) + next_token_chooser = HeterogeneousNextTokenChooser.from_pb([r.parameters for r in pb.requests], dtype, device) # TODO: this should be set to rust side `max_total_tokens`, # (see https://github.com/huggingface/text-generation-inference/blob/main/launcher/src/main.rs#L177) @@ -110,442 +283,81 @@ class CausalLMBatch(Batch): max_total_tokens = int(os.getenv("MAX_TOTAL_TOKENS", "0")) logger.info("MAX_TOTAL_TOKENS = {}".format(max_total_tokens)) - for i, r in enumerate(pb.requests): - requests_idx_mapping[r.id] = i - inputs.append(r.inputs) - next_token_chooser_parameters.append(r.parameters) - stopping_criteria = StoppingCriteria.from_pb(r.stopping_parameters, tokenizer) - stopping_criterias.append(stopping_criteria) - top_n_tokens.append(r.top_n_tokens) - max_truncation = max(max_truncation, r.truncate) - max_decode_tokens += stopping_criteria.max_new_tokens - padding_right_offset = max(padding_right_offset, stopping_criteria.max_new_tokens) - - next_token_chooser = HeterogeneousNextTokenChooser.from_pb( - next_token_chooser_parameters, dtype, device - ) - + # TODO: by tokenizing all inputs at once we loose information on actual input lengths + # this means that we cannot shift inputs to the left after a long input sequence + # was filtered out + new_bs = round_up(len(requests), BATCH_BUCKET_SIZE) + dummy_inputs = ["?"] * (new_bs - len(requests)) tokenized_inputs = tokenizer( - inputs, + [r.data.inputs for r in requests] + dummy_inputs, return_tensors="pt", padding="max_length", return_token_type_ids=False, truncation=True, - max_length=max_truncation, + max_length=max_input_length, ) - for _ in pb.requests: - input_len = tokenized_inputs["input_ids"].shape[1] - input_lengths.append(input_len) - prefix_offsets.append(input_len - 5) - read_offsets.append(input_len) + input_len = tokenized_inputs["input_ids"].shape[1] + extra_padding = 0 + if is_optimized_for_gaudi and max_total_tokens > 0: + extra_padding = max(extra_padding, max_total_tokens - max_input_length - max_new_tokens) - max_input_length = max(input_lengths) - if max_total_tokens == 0: - max_total_tokens = max_input_length - max_tokens = len(inputs) * max_input_length + max_decode_tokens - if is_optimized_for_gaudi and max_total_tokens > max_input_length: - # pad to max_total_tokens in case max_new_token changes per request and triggers new hpu graph generation - padding_right_offset = max_total_tokens - max_input_length + for r in requests: + r.input_length = input_len + r.prefix_offset = input_len - 5 + r.read_offset = 
input_len + + #max_tokens = new_bs * max_total_tokens + max_tokens = len(requests) * max_total_tokens input_ids = tokenized_inputs["input_ids"] attention_mask = tokenized_inputs["attention_mask"] - # only move model inputs to device - attention_mask = attention_mask.to(device) if is_optimized_for_gaudi: - input_ids_cpu = torch.nn.functional.pad( - input_ids, (0, padding_right_offset), value=tokenizer.pad_token_id + input_ids = torch.nn.functional.pad( + input_ids, (0, max_new_tokens + extra_padding), value=tokenizer.pad_token_id ) - input_ids = input_ids_cpu.to(device) - attention_mask = torch.nn.functional.pad(attention_mask, (0, padding_right_offset), value=0) - all_input_ids = input_ids_cpu.T.split(1, dim=1) + attention_mask = torch.nn.functional.pad( + attention_mask, (0, max_new_tokens + extra_padding), value=0) + all_input_ids = input_ids.T.split(1, dim=1) else: all_input_ids = input_ids.clone().T.split(1, dim=1) - input_ids = input_ids.to(device) + for r in requests: + r.all_input_ids = all_input_ids[r.idx] + + input_ids = input_ids.to(device) + attention_mask = attention_mask.to(device) position_ids = attention_mask.long().cumsum(-1) - 1 position_ids.masked_fill_(attention_mask == 0, 1) - htorch.core.mark_step() - top_n_tokens_tensor = torch.tensor(top_n_tokens, device=device, dtype=torch.int64) + htorch.core.mark_step() return cls( batch_id=pb.id, - requests=pb.requests, - requests_idx_mapping=requests_idx_mapping, + requests=requests, input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, past_key_values=None, - all_input_ids=list(all_input_ids), - input_lengths=input_lengths, - prefix_offsets=prefix_offsets, - read_offsets=read_offsets, next_token_chooser=next_token_chooser, - stopping_criterias=stopping_criterias, top_n_tokens=top_n_tokens, top_n_tokens_tensor=top_n_tokens_tensor, - max_input_length=max_input_length, - padding_right_offset=padding_right_offset, max_tokens=max_tokens, + input_length=max_input_length, + right_padding=max_new_tokens + extra_padding if is_optimized_for_gaudi else 0 ) @tracer.start_as_current_span("filter") def filter(self, request_ids: List[int], is_optimized_for_gaudi: bool = False) -> Optional["CausalLMBatch"]: - if len(request_ids) == 0: - raise ValueError("Batch must have at least one request") - if len(request_ids) == len(self): - return self - - keep_indices = [] - - # New values after filtering - requests_idx_mapping = {} - requests = [] - input_lengths = [] - prefix_offsets = [] - read_offsets = [] - all_input_ids = [] - max_input_length = 0 - - stopping_criterias = [] - top_n_tokens = [] - - total_remaining_decode_tokens = 0 - new_padding_right_offset = 0 - - for i, request_id in enumerate(request_ids): - idx = self.requests_idx_mapping[request_id] - requests_idx_mapping[request_id] = i - keep_indices.append(idx) - - requests.append(self.requests[idx]) - prefix_offsets.append(self.prefix_offsets[idx]) - read_offsets.append(self.read_offsets[idx]) - all_input_ids.append(self.all_input_ids[idx]) - - request_input_length = self.input_lengths[idx] - input_lengths.append(request_input_length) - max_input_length = max(max_input_length, request_input_length) - - stopping_criteria = self.stopping_criterias[idx] - stopping_criterias.append(stopping_criteria) - top_n_tokens.append(self.top_n_tokens[idx]) - remaining_decode_tokens = stopping_criteria.max_new_tokens - stopping_criteria.current_tokens - total_remaining_decode_tokens += remaining_decode_tokens - new_padding_right_offset = max(new_padding_right_offset, 
remaining_decode_tokens) - - # Apply indices to input_ids, attention mask, past key values and other items that need to be cached - input_ids = self.input_ids[keep_indices] - position_ids = self.position_ids[keep_indices] - next_token_chooser = self.next_token_chooser.filter(keep_indices) - if is_optimized_for_gaudi: - self.attention_mask = self.attention_mask[keep_indices] - else: - self.attention_mask = self.attention_mask[ - keep_indices, - -(self.padding_right_offset + max_input_length) : ( - self.attention_mask.shape[1] - self.padding_right_offset - ) - + new_padding_right_offset, - ] - - # Ensure that past_key_values tensors can be updated in-place - kv_tuple = False - if type(self.past_key_values[0]) == tuple: - self.past_key_values = [list(layer) for layer in self.past_key_values] - kv_tuple = True - - # Update tensors in-place to allow incremental garbage collection - past_kv_length = max_input_length - 1 - for layer in self.past_key_values: - past_keys, past_values = layer - past_keys_dims = len(past_keys.shape) - if past_keys_dims == 3: - # Force past to be of dim [self_size, num_heads, ...] for easy indexing - past_keys = past_keys.view(len(self), -1, *past_keys.shape[-2:]) - past_values = past_values.view(len(self), -1, *past_values.shape[-2:]) - if is_optimized_for_gaudi: - layer[0] = past_keys[keep_indices] - del past_keys - layer[1] = past_values[keep_indices] - del past_values - else: - if self.keys_head_dim_last: - layer[0] = past_keys[keep_indices, :, -past_kv_length:, :] - else: - layer[0] = past_keys[keep_indices, :, :, -past_kv_length:] - del past_keys - layer[1] = past_values[keep_indices, :, -past_kv_length:, :] - del past_values - if past_keys_dims == 3: - layer[0] = layer[0].view(layer[0].shape[0] * layer[0].shape[1], *layer[0].shape[-2:]) - layer[1] = layer[1].view(layer[1].shape[0] * layer[1].shape[1], *layer[1].shape[-2:]) - - top_n_tokens_tensor = self.top_n_tokens_tensor[keep_indices] - max_tokens = len(request_ids) * max_input_length + total_remaining_decode_tokens - - if kv_tuple: - self.past_key_values = tuple([tuple(layer) for layer in self.past_key_values]) - - self.requests = requests - self.requests_idx_mapping = requests_idx_mapping - self.input_ids = input_ids - self.position_ids = position_ids - self.all_input_ids = all_input_ids - self.input_lengths = input_lengths - self.prefix_offsets = prefix_offsets - self.read_offsets = read_offsets - self.next_token_chooser = next_token_chooser - self.stopping_criterias = stopping_criterias - self.top_n_tokens = top_n_tokens - self.top_n_tokens_tensor = top_n_tokens_tensor - self.max_input_length = max_input_length - self.padding_right_offset = new_padding_right_offset - self.max_tokens = max_tokens - - return self + trace("FILTER") + return self.__class__.recombine([self], [request_ids], is_optimized_for_gaudi) @classmethod @tracer.start_as_current_span("concatenate") def concatenate(cls, batches: List["CausalLMBatch"], is_optimized_for_gaudi: bool = False) -> "CausalLMBatch": - # Used for padding - total_batch_size = 0 - max_input_length = 0 - padding_right_offset = 0 - max_total_tokens = 0 - for batch in batches: - total_batch_size += len(batch) - max_input_length = max(max_input_length, batch.max_input_length) - padding_right_offset = max(padding_right_offset, batch.padding_right_offset) - max_total_tokens = max(max_total_tokens, batch.max_input_length + batch.padding_right_offset) - - if is_optimized_for_gaudi and max_total_tokens > max_input_length: - padding_right_offset = max_total_tokens - 
max_input_length - - # Batch attributes - requests = [] - requests_idx_mapping = {} - input_lengths = [] - prefix_offsets = [] - read_offsets = [] - all_input_ids = [] - next_token_chooser_parameters = [] - stopping_criterias = [] - top_n_tokens = [] - max_tokens = 0 - - # Batch tensors - input_ids = None - attention_mask = None - position_ids = None - past_key_values = [] - top_n_tokens_tensor = None - - # Used for slicing correctly inside the tensors - # Equivalent to a cumsum on batch sizes - start_index = 0 - for i, batch in enumerate(batches): - requests.extend(batch.requests) - input_lengths.extend(batch.input_lengths) - prefix_offsets.extend(batch.prefix_offsets) - read_offsets.extend(batch.read_offsets) - all_input_ids.extend(batch.all_input_ids) - next_token_chooser_parameters.extend([r.parameters for r in batch.requests]) - stopping_criterias.extend(batch.stopping_criterias) - top_n_tokens.extend(batch.top_n_tokens) - - if i == 0: - requests_idx_mapping = batch.requests_idx_mapping - else: - # We need to offset the mapping for each batch by the cumulative batch size - for k, v in batch.requests_idx_mapping.items(): - requests_idx_mapping[k] = v + start_index - - # Slicing end index for this batch - end_index = start_index + len(batch) - - # We only concatenate batches that did at least one step - if batch.past_key_values is None: - raise ValueError("only concatenate prefilled batches") - - # Create empty tensor - # input_ids is always of shape [batch_size, 1] - # We do not need to pad it - if input_ids is None: - input_ids = batch.input_ids.new_empty((total_batch_size, max_total_tokens)) - # Copy to correct indices - input_ids[start_index:end_index] = batch.input_ids - - # Create padded tensor - if attention_mask is None: - attention_mask = batch.attention_mask.new_zeros( - (total_batch_size, max_input_length + padding_right_offset), - ) - - if top_n_tokens_tensor is None: - top_n_tokens_tensor = batches[0].top_n_tokens_tensor.new_zeros( - total_batch_size, - ) - top_n_tokens_tensor[start_index:end_index] = batch.top_n_tokens_tensor - - # We need to slice the attention mask to remove padding from previous steps - # and to remove unused allocated space - left_offset = max_input_length - batch.max_input_length - batch_left_offset = batch.attention_mask.shape[1] - batch.max_input_length - batch.padding_right_offset - attention_mask[start_index:end_index, left_offset:-padding_right_offset] = batch.attention_mask[ - :, - batch_left_offset : -batch.padding_right_offset, - ] - - # Create empty tensor - # position_ids is always of shape [batch_size, 1] - if position_ids is None: - position_ids = batch.position_ids.new_empty((total_batch_size, 1)) - position_ids[start_index:end_index] = batch.position_ids - - # Shenanigans to get dimensions because BLOOM outputs a past with a different shape - # BLOOM Keys: [batch_size * num_heads, head_dim, seq_length] - # BLOOM Values: [batch_size * num_heads, seq_length, head_dim] - # And ensure that we can update tensors in-place - kv_tuple = False - past_key_values_dims = len(batch.past_key_values[0][0].shape) - if type(batch.past_key_values[0]) == tuple: - batch.past_key_values = [ - [t.view(len(batch), -1, *t.shape[-2:]) for t in layer] for layer in batch.past_key_values - ] - kv_tuple = True - elif past_key_values_dims == 3: - for layer in batch.past_key_values: - for k, t in enumerate(layer): - layer[k] = t.view(len(batch), -1, *t.shape[-2:]) - - # Add eventual padding tokens that were added while concatenating - max_tokens += batch.max_tokens + 
(max_input_length - batch.max_input_length) * len(batch) - - start_index = end_index - - next_token_chooser = HeterogeneousNextTokenChooser.from_pb( - next_token_chooser_parameters, - dtype=batches[0].next_token_chooser.dtype, - device=batches[0].next_token_chooser.device, - ) - - first_past_kvs = batches[0].past_key_values - _, num_heads, _, head_dim = first_past_kvs[0][1].shape - padded_sequence_length = ( - max_input_length + padding_right_offset if is_optimized_for_gaudi else max_input_length - 1 - ) - padded_past_values_shape = ( - total_batch_size, - num_heads, - padded_sequence_length, - head_dim, - ) - - if batches[0].keys_head_dim_last: - padded_past_keys_shape = padded_past_values_shape - else: - # seq_length is last for BLOOM - padded_past_keys_shape = ( - total_batch_size, - num_heads, - head_dim, - padded_sequence_length, - ) - - # Iterate over attention layers - # Concatenate past key values layer by layer to allow incremental garbage collection - for j in range(len(first_past_kvs)): - padded_past_keys = first_past_kvs[j][0].new_zeros(padded_past_keys_shape) - start_index = 0 - for batch in batches: - past_keys = batch.past_key_values[j][0] - # Clear reference to the original tensor - batch.past_key_values[j][0] = None - - # Slicing end index for this batch - end_index = start_index + len(batch) - # We slice the keys to remove the padding from previous batches - past_seq_len = batch.max_input_length - 1 - # recaculate the offset - left_offset = max_input_length - batch.max_input_length - batch_left_offset = batch.attention_mask.shape[1] - batch.max_input_length - batch.padding_right_offset - - if batch.keys_head_dim_last: - padded_past_keys[ - start_index:end_index, :, left_offset : left_offset + past_seq_len, : - ] = past_keys[:, :, batch_left_offset : batch_left_offset + past_seq_len, :] - else: - # BLOOM case - padded_past_keys[ - start_index:end_index, :, :, left_offset : left_offset + past_seq_len - ] = past_keys[:, :, :, batch_left_offset : batch_left_offset + past_seq_len] - del past_keys - - start_index = end_index - - padded_past_values = first_past_kvs[j][1].new_zeros(padded_past_values_shape) - start_index = 0 - for batch in batches: - past_values = batch.past_key_values[j][1] - # Clear reference to the original tensor - batch.past_key_values[j][1] = None - - # Slicing end index for this batch - end_index = start_index + len(batch) - # We slice the past values to remove the padding from previous batches - past_seq_len = batch.max_input_length - 1 - # recaculate the offset - left_offset = max_input_length - batch.max_input_length - batch_left_offset = batch.attention_mask.shape[1] - batch.max_input_length - batch.padding_right_offset - - padded_past_values[ - start_index:end_index, :, left_offset : left_offset + past_seq_len, : - ] = past_values[:, :, batch_left_offset : batch_left_offset + past_seq_len, :] - del past_values - - # Update values - start_index = end_index - - if past_key_values_dims == 3: - padded_past_keys = padded_past_keys.view( - padded_past_keys.shape[0] * padded_past_keys.shape[1], *padded_past_keys.shape[-2:] - ) - padded_past_values = padded_past_values.view( - padded_past_values.shape[0] * padded_past_values.shape[1], *padded_past_values.shape[-2:] - ) - - if kv_tuple: - past_key_values.append((padded_past_keys, padded_past_values)) - else: - past_key_values.append([padded_past_keys, padded_past_values]) - - if kv_tuple: - past_key_values = tuple(past_key_values) - - return cls( - batch_id=batches[0].batch_id, - requests=requests, - 
requests_idx_mapping=requests_idx_mapping, - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - all_input_ids=all_input_ids, - input_lengths=input_lengths, - prefix_offsets=prefix_offsets, - read_offsets=read_offsets, - next_token_chooser=next_token_chooser, - stopping_criterias=stopping_criterias, - top_n_tokens=top_n_tokens, - top_n_tokens_tensor=top_n_tokens_tensor, - max_input_length=max_input_length, - padding_right_offset=padding_right_offset, - keys_head_dim_last=batches[0].keys_head_dim_last, - max_tokens=max_tokens, - ) + trace('CONCAT') + return cls.recombine(batches, [[req.data.id for req in b.requests] for b in batches], is_optimized_for_gaudi) def __len__(self): return len(self.requests) @@ -719,18 +531,19 @@ class CausalLM(Model): @tracer.start_as_current_span("generate_token") def generate_token(self, batch: CausalLMBatch) -> Tuple[List[Generation], Optional[CausalLMBatch]]: + trace(f'GENERATE ({len(batch.requests)}){[r.data.id for r in batch.requests]}, {batch.input_ids.shape}') self.step = self.step + 1 if self.hb_profer_started == True and self.step > self.profiling_warmup_steps + self.profiling_steps: self.hb_profer.stop() self.hb_profer_started = False if self.is_optimized_for_gaudi: - token_idx = torch.tensor(batch.attention_mask.shape[-1] - batch.padding_right_offset).to(self.device) + token_idx = torch.tensor(batch.attention_mask.shape[-1] - batch.right_padding).to(self.device) attention_mask = batch.attention_mask - else: token_idx = None # slice the attention mask to the correct shape + # TODO fix me! attention_mask = batch.attention_mask[:, : -batch.padding_right_offset] prefill = batch.past_key_values is None if batch.past_key_values: @@ -753,13 +566,13 @@ class CausalLM(Model): stopped = True # Select next token - input_length = batch.input_lengths[0] + input_length = batch.input_length if self.is_optimized_for_gaudi and logits.shape[-2] > 1: - next_input_ids, next_token_logprobs, logprobs = batch.next_token_chooser( + next_token_ids, next_token_logprobs, logprobs = batch.next_token_chooser( batch.input_ids[:, :token_idx], logits[:, input_length - 1 : input_length, :].squeeze(-2) ) else: - next_input_ids, next_token_logprobs, logprobs = batch.next_token_chooser( + next_token_ids, next_token_logprobs, logprobs = batch.next_token_chooser( batch.input_ids[:, :token_idx], logits.squeeze(-2) ) @@ -769,46 +582,26 @@ class CausalLM(Model): logprobs, ) - htorch.core.mark_step() - logits = logits.to("cpu") - next_token_logprobs = next_token_logprobs.tolist() - next_token_ids = next_input_ids + next_token_ids_cpu = next_token_ids.cpu() + htorch.core.mark_step() + + for req in batch.requests: + i = req.idx + request = req.data + input_length = req.input_length + prefix_offset = req.prefix_offset + read_offset = req.read_offset + do_sample = batch.next_token_chooser.do_sample[i] + seed = batch.next_token_chooser.seeds[i] + stopping_criteria = req.stopping_criteria + all_input_ids = req.all_input_ids + top_n_tokens = batch.top_n_tokens[i] + next_token_id = next_token_ids_cpu[i] + next_token_logprob = next_token_logprobs[i] + top_token_ids = batch_top_token_ids[i] + top_token_logprobs = batch_top_token_logprobs[i] - # Zipped iterator - iterator = zip( - batch.requests, - batch.input_lengths, - batch.prefix_offsets, - batch.read_offsets, - logits, - batch.next_token_chooser.do_sample, - batch.next_token_chooser.seeds, - batch.stopping_criterias, - batch.all_input_ids, - batch.top_n_tokens, - next_token_ids, - 
next_token_logprobs, - batch_top_token_ids, - batch_top_token_logprobs, - ) - # For each member of the batch - for i, ( - request, - input_length, - prefix_offset, - read_offset, - logits, - do_sample, - seed, - stopping_criteria, - all_input_ids, - top_n_tokens, - next_token_id, - next_token_logprob, - top_token_ids, - top_token_logprobs, - ) in enumerate(iterator): # Append next token to all tokens if self.is_optimized_for_gaudi: all_input_ids[input_length] = next_token_id @@ -890,16 +683,17 @@ class CausalLM(Model): generations.append(generation) - batch.all_input_ids[i] = all_input_ids - batch.input_lengths[i] = new_input_length - batch.prefix_offsets[i] = prefix_offset - batch.read_offsets[i] = read_offset - batch.max_input_length = max(batch.max_input_length, new_input_length) + req.all_input_ids = all_input_ids + req.input_length = new_input_length + req.prefix_offset = prefix_offset + req.read_offset = read_offset + htorch.core.mark_step() if token_idx is None: batch.input_ids[:, 0] = next_token_ids[:, 0] else: - batch.input_ids[:, token_idx] = next_token_ids + batch.input_ids.index_copy_(1, token_idx.cpu(), next_token_ids.unsqueeze(1)) + # We finished all generations in the batch; there is no next batch if stopped: if self.hb_profer_started == True: @@ -915,8 +709,11 @@ class CausalLM(Model): batch.attention_mask.index_fill_(1, token_idx, 1) else: batch.attention_mask[:, -batch.padding_right_offset] = 1 - # Decrease right offset - batch.padding_right_offset -= 1 + + # Adjust lengths + batch.input_length += 1 + if batch.right_padding > 0: + batch.right_padding -= 1 # Update position_ids if prefill: @@ -927,5 +724,6 @@ class CausalLM(Model): batch.past_key_values = past if self.hb_profer_started == True: self.hb_profer.step() + htorch.core.mark_step() return generations, batch
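The core of the change is the bucketing described in the new README row: before tokenization the batch is padded with dummy "?" requests until its size reaches the next multiple of BATCH_BUCKET_SIZE, so the HPU graph cache only ever sees a small, fixed set of batch shapes. A minimal CPU-only sketch of that rounding, assuming a standard `transformers` tokenizer; `round_up` and the "?" dummy prompt mirror the patch, while `bucketize_inputs` and the gpt2 checkpoint are illustrative names, not part of the PR:

```python
from transformers import AutoTokenizer

BATCH_BUCKET_SIZE = 8  # same default as the BATCH_BUCKET_SIZE env var documented above


def round_up(number: int, k: int) -> int:
    # 3 -> 8, 8 -> 8, 9 -> 16: batch sizes only ever grow to a bucket boundary
    return (number + k - 1) // k * k


def bucketize_inputs(inputs, tokenizer, max_length):
    # Pad the request list with dummy "?" prompts so the batch dimension
    # lands exactly on a bucket boundary (mirrors CausalLMBatch.from_pb).
    new_bs = round_up(len(inputs), BATCH_BUCKET_SIZE)
    dummy_inputs = ["?"] * (new_bs - len(inputs))
    return tokenizer(
        inputs + dummy_inputs,
        return_tensors="pt",
        padding="max_length",
        truncation=True,
        max_length=max_length,
    )


if __name__ == "__main__":
    tok = AutoTokenizer.from_pretrained("gpt2")
    tok.pad_token = tok.eos_token
    batch = bucketize_inputs(["Hello", "How are you doing?", "Hi there"], tok, max_length=16)
    print(batch["input_ids"].shape)  # torch.Size([8, 16]): 3 requests rounded up to a bucket of 8
```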
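The other central piece is `recombine`, which now backs both `filter` and `concatenate`. Instead of slicing tensors in place, it right-aligns each source batch by shifting along the sequence dimension, allocates a destination whose batch dimension is already rounded to the bucket size, and index-copies the surviving rows into it. A simplified CPU-only sketch of that idea; the `shift` gather trick follows the helper added in this patch, while `recombine_rows` is an illustrative stand-in for `move_data` that omits the KV-cache chunking and the `htorch.core.mark_step()` synchronization the real code needs on Gaudi:

```python
import torch


def shift(tensor: torch.Tensor, dim: int, offset: int) -> torch.Tensor:
    # Shift `tensor` by `offset` positions along `dim`, clamping at the edges;
    # this is how shorter batches get aligned to the longest input before merging.
    elements = tensor.size(dim)
    if offset == 0 or abs(offset) > elements:
        return tensor
    indices = torch.arange(elements, dtype=torch.int64, device=tensor.device)
    indices = torch.clamp(indices - offset, 0, elements - 1)
    target_shape = [1] * tensor.dim()
    target_shape[dim] = elements
    return torch.gather(tensor, dim, indices.view(target_shape).expand(tensor.shape))


def recombine_rows(new_bs, sources, row_moves):
    # Copy selected rows of the (already shifted) source tensors into a freshly
    # allocated destination whose batch dim is the bucketed size `new_bs`.
    dst = sources[0].new_zeros((new_bs,) + sources[0].shape[1:])
    for src, moves in zip(sources, row_moves):
        for dst_idx, src_idx in moves:
            dst[dst_idx] = src[src_idx]
    return dst


if __name__ == "__main__":
    a = torch.arange(12).reshape(2, 6)          # batch of 2, seq len 6
    b = torch.arange(100, 118).reshape(3, 6)    # batch of 3, seq len 6
    b = shift(b, dim=1, offset=2)               # align b as if its inputs were 2 tokens shorter
    merged = recombine_rows(8, [a, b], [[(0, 0), (1, 1)], [(2, 0), (3, 2)]])
    print(merged.shape)  # torch.Size([8, 6]): 4 live rows kept inside a bucket of 8
```

Because filtering and concatenation both reduce to this single recombination path, the tensor shapes the graph compiler ever encounters stay limited to the bucketed batch sizes, which is the point of the change.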
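Finally, the parallel lists the old batch carried (`input_lengths`, `prefix_offsets`, `read_offsets`, `stopping_criterias`, `all_input_ids`) plus the `requests_idx_mapping` dict are collapsed into a per-request `CausalLMRequest` that knows its own row index. A trimmed-down sketch of that bookkeeping pattern; the field set and the `filter_requests` helper are illustrative, while `update_idx` mirrors the method added in the patch:

```python
from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class Request:
    idx: int          # row of this request inside the padded batch tensors
    request_id: int
    input_length: int

    def update_idx(self, new_idx: int) -> Tuple[int, int]:
        # Move the request to a new row and report (dst, src) for the tensor copies.
        prev = self.idx
        self.idx = new_idx
        return (new_idx, prev)


def filter_requests(requests: List[Request], keep_ids: List[int]):
    # Keep only `keep_ids`, re-index survivors consecutively, and return the
    # (dst_row, src_row) moves needed to compact the batch tensors.
    kept = [r for r in requests if r.request_id in keep_ids]
    moves = [r.update_idx(new_idx) for new_idx, r in enumerate(kept)]
    return kept, moves


if __name__ == "__main__":
    reqs = [Request(idx=i, request_id=100 + i, input_length=16) for i in range(4)]
    kept, moves = filter_requests(reqs, keep_ids=[101, 103])
    print([r.request_id for r in kept], moves)  # [101, 103] [(0, 1), (1, 3)]
```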