From 381ec38cadc0600a7d51ad007d19d6245b292be0 Mon Sep 17 00:00:00 2001
From: madamczykhabana <110973826+madamczykhabana@users.noreply.github.com>
Date: Wed, 17 Jan 2024 10:09:27 +0100
Subject: [PATCH] Batch bucketing improvements (#15)

---
 .../models/causal_lm.py | 135 ++++++++++++------
 1 file changed, 92 insertions(+), 43 deletions(-)

diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py
index 61376640..ec0793f7 100644
--- a/server/text_generation_server/models/causal_lm.py
+++ b/server/text_generation_server/models/causal_lm.py
@@ -68,8 +68,11 @@ def round_up(number, k):
     return (number + k - 1) // k * k
 
 
-def batch_alloc(new_bs, tensor):
-    return tensor.new_empty((new_bs,) + tensor.shape[1:])
+def prepare_memory(new_bs, tensor, inplace):
+    if inplace:
+        return tensor
+    else:
+        return tensor.new_empty((new_bs,) + tensor.shape[1:])
 
 
 def move_data(dst_tensor, chunk_size, indices, src_tensors):
@@ -154,9 +157,6 @@ class CausalLMBatch(Batch):
     top_n_tokens: List[int]
     top_n_tokens_tensor: torch.Tensor
 
-    # Maximum number of tokens this batch will grow to
-    max_tokens: int
-
     input_length: int
     right_padding: int
 
@@ -169,32 +169,53 @@ class CausalLMBatch(Batch):
         )
 
     @classmethod
-    def recombine(cls, batches: List["CausalLMBatch"], req_ids: List[List[int]], is_optimized_for_gaudi: bool = False) -> "CausalLMBatch":
-        new_bs = round_up(sum([len(reqs) for reqs in req_ids]), BATCH_BUCKET_SIZE)
+    def recombine(cls, batches: List["CausalLMBatch"], is_optimized_for_gaudi: bool = False) -> "CausalLMBatch":
+        total_requests = sum(len(b) for b in batches)
+        new_bs = round_up(total_requests, BATCH_BUCKET_SIZE)
         batch_id = batches[0].batch_id
         device = batches[0].input_ids.device
 
-        # TODO: for now use consecutive indices. This could be optimized to reuse existing batch memory and only overwrite
-        # indices that are no longer used instead of allocating new memory
-        free_indices = itertools.count(0)
-        to_tensors = lambda ind: (torch.tensor(ind[0], device=device), torch.tensor(ind[1], device=device))
-        requests = [[req for req in batch.requests if req.data.id in ids] for batch, ids in zip(batches, req_ids)]
-        indices = [[to_tensors(req.update_idx(next(free_indices))) for req in batch_reqs] for batch_reqs in requests]
-        requests = list(itertools.chain(*requests))
+        max_input_length = max(b.input_length for b in batches)
+        offsets = [max_input_length - b.input_length for b in batches]
+        padding = [b.right_padding for b in batches]
+
+        moves_needed = [total_requests - len(b) if b.batch_size == new_bs else total_requests for b in batches]
+        target_batch_idx = min(enumerate(moves_needed), key=lambda idx_val: idx_val[1])[0]
 
         # TODO: Add support for changing max seq len, i.e. due to output length bucketing
         # FIXME: max_seq_len for non optimized code
-        max_input_length = max(req.input_length for req in requests)
-        offsets = [(max_input_length - b.input_length) for b in batches]
-        scenario = 'CONCAT' if len(batches) > 1 else 'FILTER'
-        dbg_trace(scenario, f'bs:{[b.input_ids.size(0) for b in batches]}->{new_bs} num_reqs:{[len(b.requests) for b in batches]}->{len(requests)} offsets:{offsets}')
+        if len(batches) > 1:
+            scenario = 'CONCAT'
+        elif batches[0].batch_size != new_bs:
+            scenario = 'RESHAPE'
+        elif padding[0] <= 1:
+            scenario = 'SHIFT'
+            offsets = [b.max_input_length - max_input_length for b in batches]
+            max_input_length = max(b.max_input_length for b in batches)
+        else:
+            # Nothing to do
+            return batches[0]
+
+        inplace = batches[target_batch_idx].batch_size == new_bs
+        dbg_trace(scenario, f'bs:{[b.batch_size for b in batches]}->{new_bs} reqs:{[len(b) for b in batches]} offsets:{offsets} padding:{padding} moves_needed:{moves_needed} inplace:{inplace}')
+
+        grouped_requests = [[req for req in batch.requests] for batch in batches]
+        flat_requests = list(itertools.chain(*grouped_requests))
+        if inplace and scenario != 'SHIFT':
+            # The data is already present in the batch. No need to move it
+            grouped_requests[target_batch_idx] = []
+            free_indices = batches[target_batch_idx].free_indices()
+        else:
+            free_indices = itertools.count(0)
+
+        to_tensors = lambda ind: (torch.tensor(ind[0], device=device), torch.tensor(ind[1], device=device))
+        indices = [[to_tensors(req.update_idx(next(free_indices))) for req in batch_reqs] for batch_reqs in grouped_requests]
 
         max_seq_len = batches[0].attention_mask.size(1)
-        input_length = max(r.input_length for r in requests)
+        input_length = max_input_length
         right_padding = max_seq_len - input_length
 
-        max_tokens = len(requests) * max_seq_len
-        chunk_size = batches[0].past_key_values[0][0].size(0) // batches[0].input_ids.size(0)
+        chunk_size = batches[0].past_key_values[0][0].size(0) // batches[0].batch_size
         num_layers = len(batches[0].past_key_values)
         past_key_values_type = type(batches[0].past_key_values)
 
@@ -213,33 +234,33 @@ class CausalLMBatch(Batch):
         for b in batches:
             del b.input_ids
         src = shift_all(src, seq_dim, offsets)
-        input_ids = batch_alloc(new_bs, src[0])
+        input_ids = prepare_memory(new_bs, src[target_batch_idx], inplace)
         input_ids = move_data(input_ids, 1, indices, src)
 
         src = [b.attention_mask for b in batches]
         for b in batches:
             del b.attention_mask
         src = shift_all(src, seq_dim, offsets)
-        attention_mask = batch_alloc(new_bs, src[0])
+        attention_mask = prepare_memory(new_bs, src[target_batch_idx], inplace)
        attention_mask = move_data(attention_mask, 1, indices, src)
 
         src = [b.position_ids for b in batches]
         for b in batches:
             del b.position_ids
         src = shift_all(src, seq_dim, offsets)
-        position_ids = batch_alloc(new_bs, src[0])
+        position_ids = prepare_memory(new_bs, src[target_batch_idx], inplace)
         position_ids = move_data(position_ids, 1, indices, src)
 
         past_key_values = []
         for layer_num in range(num_layers):
             src = [b.past_key_values[layer_num][0] for b in batches]
             src = shift_all(src, key_dim, offsets)
-            updated_key = batch_alloc(new_bs * chunk_size, src[0])
+            updated_key = prepare_memory(new_bs * chunk_size, src[target_batch_idx], inplace)
             updated_key = move_data(updated_key, chunk_size, indices, src)
 
             src = [b.past_key_values[layer_num][1] for b in batches]
             src = shift_all(src, value_dim, offsets)
-            updated_value = batch_alloc(new_bs * chunk_size, src[0])
+            updated_value = prepare_memory(new_bs * chunk_size, src[target_batch_idx], inplace)
             updated_value = move_data(updated_value, chunk_size, indices, src)
 
             past_key_values.append((updated_key, updated_value))
@@ -248,10 +269,10 @@ class CausalLMBatch(Batch):
 
         past_key_values = past_key_values_type(past_key_values)
 
-        top_n_tokens = [r.data.top_n_tokens for r in requests]
+        top_n_tokens = [r.data.top_n_tokens for r in flat_requests]
         top_n_tokens_tensor = torch.tensor(top_n_tokens, device=device, dtype=torch.int64)
         next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
-            [r.data.parameters for r in requests],
+            [r.data.parameters for r in flat_requests],
             batches[0].next_token_chooser.device,
             batches[0].next_token_chooser.dtype
         )
@@ -260,7 +281,7 @@ class CausalLMBatch(Batch):
 
         return cls(
             batch_id=batch_id,
-            requests=requests,
+            requests=flat_requests,
             input_ids=input_ids,
             attention_mask=attention_mask,
             position_ids=position_ids,
@@ -268,7 +289,6 @@ class CausalLMBatch(Batch):
             next_token_chooser=next_token_chooser,
             top_n_tokens=top_n_tokens,
             top_n_tokens_tensor=top_n_tokens_tensor,
-            max_tokens=max_tokens,
             input_length=input_length,
             right_padding=right_padding
         )
@@ -327,9 +347,6 @@ class CausalLMBatch(Batch):
             r.prefix_offset = input_len - 5
             r.read_offset = input_len
 
-        #max_tokens = new_bs * max_total_tokens
-        max_tokens = len(requests) * max_total_tokens
-
         input_ids = tokenized_inputs["input_ids"]
         attention_mask = tokenized_inputs["attention_mask"]
 
@@ -363,23 +380,50 @@ class CausalLMBatch(Batch):
             next_token_chooser=next_token_chooser,
             top_n_tokens=top_n_tokens,
             top_n_tokens_tensor=top_n_tokens_tensor,
-            max_tokens=max_tokens,
             input_length=max_input_length,
             right_padding=max_new_tokens + extra_padding if is_optimized_for_gaudi else 0
         )
 
     @tracer.start_as_current_span("filter")
     def filter(self, request_ids: List[int], is_optimized_for_gaudi: bool = False) -> Optional["CausalLMBatch"]:
-        return self.__class__.recombine([self], [request_ids], is_optimized_for_gaudi)
+        dbg_trace('FILTER', f'num_reqs:{len(self.requests)} -> {len(request_ids)}')
+        request_ids = set(request_ids)
+        self.requests = [req for req in self.requests if req.data.id in request_ids]
+        return self.__class__.recombine([self], is_optimized_for_gaudi)
 
     @classmethod
     @tracer.start_as_current_span("concatenate")
     def concatenate(cls, batches: List["CausalLMBatch"], is_optimized_for_gaudi: bool = False) -> "CausalLMBatch":
-        return cls.recombine(batches, [[req.data.id for req in b.requests] for b in batches], is_optimized_for_gaudi)
+        return cls.recombine(batches, is_optimized_for_gaudi)
 
     def __len__(self):
         return len(self.requests)
 
+    @property
+    def max_input_length(self):
+        return max(req.input_length for req in self.requests)
+
+    @property
+    def batch_size(self):
+        return self.attention_mask.size(0)
+
+    @property
+    def seq_length(self):
+        return self.attention_mask.size(1)
+
+    # Maximum number of tokens this batch will grow to
+    @property
+    def max_tokens(self):
+        max_total_tokens = self.attention_mask.size(1)
+        return len(self.requests) * max_total_tokens
+
+    def free_indices(self):
+        used = set(req.idx for req in self.requests)
+        for i in range(self.batch_size):
+            if i in used:
+                continue
+            yield i
+
 
 class CausalLM(Model):
     def __init__(
@@ -550,8 +594,12 @@ class CausalLM(Model):
     @tracer.start_as_current_span("generate_token")
     def generate_token(self, batch: CausalLMBatch) -> Tuple[List[Generation], Optional[CausalLMBatch]]:
         prefill = batch.past_key_values is None
+        # Check if we need to do any bookkeeping first
+        if not prefill:
+            batch = batch.__class__.recombine([batch], self.is_optimized_for_gaudi)
+
         scenario = 'PREFILL' if prefill else 'GENERATE'
-        dbg_trace(scenario, f'bs:{batch.input_ids.size(0)} num_reqs:{len(batch.requests)} seq_len:{batch.input_ids.shape[1]}')
+        dbg_trace(scenario, f'bs:{batch.batch_size} num_reqs:{len(batch.requests)} seq_len:{batch.seq_length}')
         self.step = self.step + 1
         if self.hb_profer_started == True and self.step > self.profiling_warmup_steps + self.profiling_steps:
             self.hb_profer.stop()
@@ -605,21 +653,21 @@ class CausalLM(Model):
             next_token_ids_cpu = next_token_ids.cpu()
         htorch.core.mark_step()
 
-        for req in batch.requests:
+        for req_idx, req in enumerate(batch.requests):
             i = req.idx
             request = req.data
             input_length = req.input_length
             prefix_offset = req.prefix_offset
             read_offset = req.read_offset
-            do_sample = batch.next_token_chooser.do_sample[i]
-            seed = batch.next_token_chooser.seeds[i]
+            do_sample = batch.next_token_chooser.do_sample[req_idx]
+            seed = batch.next_token_chooser.seeds[req_idx]
             stopping_criteria = req.stopping_criteria
             all_input_ids = req.all_input_ids
-            top_n_tokens = batch.top_n_tokens[i]
+            top_n_tokens = batch.top_n_tokens[req_idx]
            next_token_id = next_token_ids_cpu[i]
             next_token_logprob = next_token_logprobs[i]
-            top_token_ids = batch_top_token_ids[i]
-            top_token_logprobs = batch_top_token_logprobs[i]
+            top_token_ids = batch_top_token_ids[req_idx]
+            top_token_logprobs = batch_top_token_logprobs[req_idx]
 
             # Append next token to all tokens
             if self.is_optimized_for_gaudi:
@@ -717,6 +765,7 @@ class CausalLM(Model):
         if stopped:
             if self.hb_profer_started == True:
                 self.hb_profer.step()
+            htorch.core.mark_step()
             return generations, None
 
         # Slice unused values from prefill, use it to store next token
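
Note (not part of the patch): the sketch below illustrates the bucketing and target-batch selection that the new recombine() performs. round_up and BATCH_BUCKET_SIZE mirror names from the diff; the bucket value 8, the helper pick_target_batch, and the plain (allocated_batch_size, live_request_count) pairs are hypothetical stand-ins for the real CausalLMBatch objects, used only to show the arithmetic.

# Illustrative sketch only, assuming a bucket size of 8.
BATCH_BUCKET_SIZE = 8


def round_up(number, k):
    # Same rounding helper as in the diff: round up to a multiple of k.
    return (number + k - 1) // k * k


def pick_target_batch(batch_shapes):
    # batch_shapes: list of (allocated_batch_size, live_request_count) pairs.
    total_requests = sum(n for _, n in batch_shapes)
    new_bs = round_up(total_requests, BATCH_BUCKET_SIZE)
    # A batch whose allocation already matches the bucketed size only needs the
    # other batches' rows moved into it; any other choice moves every request.
    moves_needed = [total_requests - n if bs == new_bs else total_requests
                    for bs, n in batch_shapes]
    target_idx = min(enumerate(moves_needed), key=lambda kv: kv[1])[0]
    inplace = batch_shapes[target_idx][0] == new_bs
    return new_bs, target_idx, inplace


# Merging a 16-slot batch holding 10 requests with an 8-slot batch holding 3:
# the 13 requests still fit a 16 bucket, so the first batch keeps its tensors
# and only the 3 extra rows are moved in.
print(pick_target_batch([(16, 10), (8, 3)]))   # (16, 0, True)

# Filtering a 16-slot batch down to 5 requests shrinks the bucket to 8, so
# fresh tensors are allocated and all 5 rows are moved.
print(pick_target_batch([(16, 5)]))            # (8, 0, False)

This mirrors the moves_needed / target_batch_idx / inplace computation in the new recombine(); the real code additionally distinguishes the CONCAT, RESHAPE, and SHIFT scenarios, reuses free slot indices via free_indices() when working in place, and returns the batch unchanged when there is nothing to do.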