diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py
index 42c9356f..ab6b034d 100644
--- a/server/text_generation_server/models/causal_lm.py
+++ b/server/text_generation_server/models/causal_lm.py
@@ -45,6 +45,7 @@ if 'GRAPH_VISUALIZATION' in os.environ:
     for f in glob.glob('.graph_dumps/*'):
         os.remove(f)
 
+MAX_TOTAL_TOKENS = int(os.getenv("MAX_TOTAL_TOKENS", "0"))
 BATCH_BUCKET_SIZE = int(os.environ.get('BATCH_BUCKET_SIZE', 8))
 PAD_SEQUENCE_TO_MULTIPLE_OF = int(os.environ.get('PAD_SEQUENCE_TO_MULTIPLE_OF', 128))
 PREFILL_BATCH_BUCKET_SIZE = int(os.environ.get('PREFILL_BATCH_BUCKET_SIZE', 4))
@@ -98,26 +99,34 @@ def move_data(dst_tensor, chunk_size, indices, src_tensors):
     return result
 
 
-def shift(tensor, dim, offset):
-    shape = tensor.shape
-    elements = shape[dim]
-    if offset == 0 or abs(offset) > elements:
-        return tensor
-    htorch.core.mark_step()
-    # We generate indices from (0 - offset + elements) to (elements - offset + elements)
-    # so that next modulo operation operates on positive values
-    indices = torch.arange(0, elements, dtype=torch.int32, device=tensor.device)
-    offset = torch.tensor(-offset + elements, dtype=torch.int32, device=tensor.device)
-    indices.add_(offset)
-    indices.remainder_(elements)
-    target_shape = [1,] * len(tensor.shape)
-    target_shape[dim] = elements
-    indices = indices.view(target_shape).expand(shape)
-    result = torch.gather(tensor, dim, indices)
-    htorch.core.mark_step()
+def generate_shift_chunks(offset):
+    chunk_sizes = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048]
+    result = []
+    while offset != 0:
+        sign = 1 if offset > 0 else -1
+        best_chunk = min((abs(offset - sign * c), sign * c) for c in chunk_sizes)[1]
+        result.append(best_chunk)
+        offset = offset - best_chunk
     return result
 
 
+def roll(tensor, dim, chunks):
+    dbg_trace('ROLL', f'shape:{list(tensor.shape)} dim:{dim} chunks:{chunks}')
+    for c in chunks:
+        tensor = torch.roll(tensor, c, dim)
+        htorch.core.mark_step()
+    return tensor
+
+
+def shift(tensor, dim, offset):
+    assert dim < 0, 'Only negative dims are supported'
+    if offset == 0:
+        return tensor
+    chunks = generate_shift_chunks(offset)
+    tensor = roll(tensor, dim, chunks)
+    return tensor
+
+
 def shift_all(srcs, dim, offsets):
     return [shift(src, dim, offset) for src, offset in zip(srcs, offsets)]
@@ -197,7 +206,6 @@ class CausalLMBatch(Batch):
     top_n_tokens_tensor: torch.Tensor
     input_length: int
-    right_padding: int
 
     def to_pb(self) -> generate_pb2.CachedBatch:
         return generate_pb2.CachedBatch(
@@ -214,9 +222,13 @@
         batch_id = batches[0].batch_id
         device = batches[0].input_ids.device
 
-        max_input_length = max(b.input_length for b in batches)
+        input_lengths = [b.input_length for b in batches]
+        max_input_length = max(input_lengths)
         offsets = [max_input_length - b.input_length for b in batches]
         padding = [b.right_padding for b in batches]
+        # For prefill there is a space allocated only for first token
+        # Need to add padding to the max total tokens before first decode
+        extra_padding = [MAX_TOTAL_TOKENS - b.seq_length for b in batches]
 
         moves_needed = [total_requests - len(b) if b.batch_size == new_bs else total_requests for b in batches]
         target_batch_idx = min(enumerate(moves_needed), key=lambda idx_val: idx_val[1])[0]
@@ -225,9 +237,9 @@
         # FIXME: max_seq_len for non optimized code
         if len(batches) > 1:
             scenario = 'CONCAT'
-        elif batches[0].batch_size != new_bs:
+        elif batches[target_batch_idx].batch_size != new_bs:
             scenario = 'RESHAPE'
-        elif padding[0] <= 0:
+        elif padding[target_batch_idx] <= 0:
             scenario = 'SHIFT'
             offsets = [b.max_input_length - max_input_length for b in batches]
             max_input_length = max(b.max_input_length for b in batches)
@@ -235,9 +247,15 @@
             # Nothing to do
             return batches[0]
 
-        inplace = batches[target_batch_idx].batch_size == new_bs
+        inplace = (batches[target_batch_idx].batch_size == new_bs)
+
         dbg_trace(
-            scenario, f'bs:{[b.batch_size for b in batches]}->{new_bs} reqs:{[len(b) for b in batches]} offsets:{offsets} padding:{padding} moves_needed:{moves_needed} inplace:{inplace}')
+            scenario, f'bs:{[b.batch_size for b in batches]}->{new_bs}'
+            f' reqs:{[len(b) for b in batches]}'
+            f' offsets:{offsets}'
+            f' input_lengths:{input_lengths}'
+            f' cur_padding:{padding}'
+            f' inplace:{inplace}')
 
         grouped_requests = [[req for req in batch.requests] for batch in batches]
         flat_requests = list(itertools.chain(*grouped_requests))
@@ -256,7 +274,7 @@
         num_layers = len(batches[0].past_key_values)
         past_key_values_type = type(batches[0].past_key_values)
 
-        seq_dim = 1
+        seq_dim = -1
         if batches[0].past_key_values[0][0].size(-1) != batches[0].past_key_values[0][1].size(-1):
             # Case for Bloom
             key_dim = -1
@@ -267,14 +285,10 @@
         for b in batches:
             b.past_key_values = list(b.past_key_values)
 
-        # For prefill there is a space allocated only for first token
-        # Need to add padding to the max total tokens before first decode
-        paddings = [(batch.input_length + batch.right_padding) - batch.seq_length for batch in batches]
-
         src = [b.input_ids for b in batches]
         for b in batches:
             del b.input_ids
-        src = pad_tensors(src, paddings, seq_dim, pad_token_id)
+        src = pad_tensors(src, extra_padding, seq_dim, pad_token_id)
         src = shift_all(src, seq_dim, offsets)
         input_ids = prepare_memory(new_bs, src[target_batch_idx], inplace)
         input_ids = move_data(input_ids, 1, indices, src)
@@ -282,7 +296,7 @@
         src = [b.attention_mask for b in batches]
         for b in batches:
             del b.attention_mask
-        src = pad_tensors(src, paddings, seq_dim, 0)
+        src = pad_tensors(src, extra_padding, seq_dim, 0)
         src = shift_all(src, seq_dim, offsets)
         attention_mask = prepare_memory(new_bs, src[target_batch_idx], inplace)
         attention_mask = move_data(attention_mask, 1, indices, src)
@@ -290,29 +304,36 @@
         src = [b.position_ids for b in batches]
         for b in batches:
             del b.position_ids
-        src = shift_all(src, seq_dim, offsets)
         position_ids = prepare_memory(new_bs, src[target_batch_idx], inplace)
         position_ids = move_data(position_ids, 1, indices, src)
 
-        past_key_values = []
-        for layer_num in range(num_layers):
-            src = [b.past_key_values[layer_num][0] for b in batches]
-            src = pad_tensors(src, paddings, key_dim, 0)
-            src = shift_all(src, key_dim, offsets)
-            updated_key = prepare_memory(new_bs * chunk_size, src[target_batch_idx], inplace)
-            updated_key = move_data(updated_key, chunk_size, indices, src)
+        src = None
+        src_keys = [[b.past_key_values[layer_num][0] for layer_num in range(num_layers)] for b in batches]
+        src_values = [[b.past_key_values[layer_num][1] for layer_num in range(num_layers)] for b in batches]
+        for b in batches:
+            del b.past_key_values
 
-            src = [b.past_key_values[layer_num][1] for b in batches]
-            src = pad_tensors(src, paddings, value_dim, 0)
-            src = shift_all(src, value_dim, offsets)
-            updated_value = prepare_memory(new_bs * chunk_size, src[target_batch_idx], inplace)
-            updated_value = move_data(updated_value, chunk_size, indices, src)
 
+        src_keys = [torch.stack(src) for src in src_keys]
+        htorch.core.mark_step()
+        src_keys = pad_tensors(src_keys, extra_padding, key_dim, 0)
+        src_keys = shift_all(src_keys, key_dim, offsets)
+        src_keys = [[t.squeeze(0).clone() for t in torch.split(src, 1)] for src in src_keys]
+        htorch.core.mark_step()
 
-            past_key_values.append((updated_key, updated_value))
-            for b in batches:
-                b.past_key_values[layer_num] = None
+        dst_keys = [prepare_memory(new_bs * chunk_size, prev, inplace) for prev in src_keys[target_batch_idx]]
+        dst_keys = [move_data(dst_keys[layer_num], chunk_size, indices, [src[layer_num] for src in src_keys]) for layer_num in range(num_layers)]
 
-        past_key_values = past_key_values_type(past_key_values)
+        src_values = [torch.stack(src) for src in src_values]
+        htorch.core.mark_step()
+        src_values = pad_tensors(src_values, extra_padding, value_dim, 0)
+        src_values = shift_all(src_values, value_dim, offsets)
+        src_values = [[t.squeeze(0).clone() for t in torch.split(src, 1)] for src in src_values]
+        htorch.core.mark_step()
+
+        dst_values = [prepare_memory(new_bs * chunk_size, prev, inplace) for prev in src_values[target_batch_idx]]
+        dst_values = [move_data(dst_values[layer_num], chunk_size, indices, [src[layer_num] for src in src_values]) for layer_num in range(num_layers)]
+
+        past_key_values = past_key_values_type(zip(dst_keys, dst_values))
 
         top_n_tokens = [r.data.top_n_tokens for r in flat_requests]
         top_n_tokens_tensor = torch.tensor(top_n_tokens, device=device, dtype=torch.int64)
@@ -324,7 +345,6 @@ class CausalLMBatch(Batch):
         max_seq_len = attention_mask.size(1)
         input_length = max_input_length
-        right_padding = max_seq_len - input_length
 
         htorch.core.mark_step()
@@ -339,7 +359,6 @@
             top_n_tokens=top_n_tokens,
             top_n_tokens_tensor=top_n_tokens_tensor,
             input_length=input_length,
-            right_padding=right_padding
         )
 
     @classmethod
@@ -362,15 +381,6 @@
        top_n_tokens_tensor = torch.tensor(top_n_tokens, device=device, dtype=torch.int64)
        next_token_chooser = HeterogeneousNextTokenChooser.from_pb([r.parameters for r in pb.requests], dtype, device)
 
-        # TODO: this should be set to rust side `max_total_tokens`,
-        # (see https://github.com/huggingface/text-generation-inference/blob/main/launcher/src/main.rs#L177)
-        # but TGI does not offer an API to expose this variable to python, as this variable
-        # is handled by the client but it appears the model is initialized by the server.
-        # An alternative could be to initialize the buffers during warmup.
-        # Dummy
-        max_total_tokens = int(os.getenv("MAX_TOTAL_TOKENS", "0"))
-        logger.info("MAX_TOTAL_TOKENS = {}".format(max_total_tokens))
-
         # TODO: by tokenizing all inputs at once we loose information on actual input lengths
         # this means that we cannot shift inputs to the left after a long input sequence
         # was filtered out
@@ -394,10 +404,6 @@
             bucket_size = round_up(input_len + 1, PAD_SEQUENCE_TO_MULTIPLE_OF) - 1
             left_padding = bucket_size - input_len
 
-            extra_padding = 0
-            if is_optimized_for_gaudi and max_total_tokens > 0:
-                extra_padding = max(extra_padding, max_total_tokens - (bucket_size + 1) - max_new_tokens)
-
         input_ids = tokenized_inputs["input_ids"]
         attention_mask = tokenized_inputs["attention_mask"]
@@ -410,7 +416,7 @@
                 attention_mask, (left_padding, 1), value=0
             )
             all_input_ids = torch.nn.functional.pad(
-                input_ids, (0, max_new_tokens + extra_padding), value=tokenizer.pad_token_id
+                input_ids, (0, max_new_tokens), value=tokenizer.pad_token_id
             ).T.split(1, dim=1)
         else:
             all_input_ids = input_ids.clone().T.split(1, dim=1)
@@ -441,7 +447,6 @@
             top_n_tokens=top_n_tokens,
             top_n_tokens_tensor=top_n_tokens_tensor,
             input_length=input_len,
-            right_padding=max_new_tokens + extra_padding if is_optimized_for_gaudi else 0
        )
 
     @tracer.start_as_current_span("filter")
@@ -471,6 +476,10 @@
     def seq_length(self):
         return self.attention_mask.size(1)
 
+    @property
+    def right_padding(self):
+        return self.seq_length - self.input_length
+
     # Maximum number of tokens this batch will grow to
     @property
     def max_tokens(self):
@@ -914,8 +923,6 @@ class CausalLM(Model):
         # Adjust lengths
         batch.input_length += 1
-        if batch.right_padding > 0:
-            batch.right_padding -= 1
 
         # Update position_ids
         if prefill:
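Note on the reworked shift helpers: the old indices/gather based `shift` is replaced by repeated `torch.roll` calls whose step sizes come from a fixed set of chunk sizes, presumably to bound the number of distinct roll shapes the HPU graph compiler sees. Below is a minimal CPU-only sketch of the same decomposition, with the Habana-specific `dbg_trace` and `htorch.core.mark_step()` calls omitted; it mirrors the functions added in this diff but is not the diff's code verbatim.

```python
# Minimal sketch of the chunked-shift logic, runnable with plain PyTorch on CPU.
import torch


def generate_shift_chunks(offset):
    # Decompose an arbitrary offset into steps drawn from a fixed set of sizes,
    # greedily picking the chunk that brings the remaining offset closest to zero.
    chunk_sizes = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048]
    result = []
    while offset != 0:
        sign = 1 if offset > 0 else -1
        best_chunk = min((abs(offset - sign * c), sign * c) for c in chunk_sizes)[1]
        result.append(best_chunk)
        offset = offset - best_chunk
    return result


def shift(tensor, dim, offset):
    # Matches the diff's assertion; on HPU a mark_step() would follow each roll.
    assert dim < 0, 'Only negative dims are supported'
    if offset == 0:
        return tensor
    for c in generate_shift_chunks(offset):
        tensor = torch.roll(tensor, c, dim)
    return tensor


if __name__ == "__main__":
    x = torch.arange(10).unsqueeze(0)    # shape (1, 10)
    print(generate_shift_chunks(5))      # [4, 1]
    print(shift(x, -1, 3))               # last dim rolled right by a total of 3
```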
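The `right_padding` field is no longer stored on the batch and decremented every decode step; it is now derived from the attention mask, and each batch is padded out to `MAX_TOTAL_TOKENS` once, during the first recombination after prefill. A small worked example of the bookkeeping (the numbers are illustrative, not taken from this PR):

```python
# Illustrative only: how the derived right_padding and the MAX_TOTAL_TOKENS-based
# extra_padding relate for a single batch. All values below are made up.
MAX_TOTAL_TOKENS = 2048

seq_length = 256     # attention_mask.size(1): slots currently allocated for the batch
input_length = 200   # tokens already filled (prompt plus generated so far)

# Property added in this diff: padding is computed, not tracked manually.
right_padding = seq_length - input_length        # 56 free slots remain on the right

# Before the first decode, the batch is padded up to the configured maximum.
extra_padding = MAX_TOTAL_TOKENS - seq_length    # 1792 additional slots to allocate
```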
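The per-layer loop over `past_key_values` is also replaced by stacking all layers of a batch into one tensor, so `pad_tensors` and `shift_all` run once per batch rather than once per layer. A rough stand-alone sketch of the stack, pad, split-back pattern follows; the tensor shapes, the padded dimension, and the use of `torch.nn.functional.pad` are assumptions for illustration, not the helper functions from the diff.

```python
# Rough sketch: pad the sequence dim of all layers' key caches in one operation.
import torch
import torch.nn.functional as F

num_layers, bs, heads, seq, head_dim = 4, 2, 8, 16, 64
per_layer_keys = [torch.zeros(bs, heads, seq, head_dim) for _ in range(num_layers)]

extra_padding = 8                        # assumed extra slots on the sequence dim (dim -2 here)
stacked = torch.stack(per_layer_keys)    # (num_layers, bs, heads, seq, head_dim)
stacked = F.pad(stacked, (0, 0, 0, extra_padding))  # one pad covers every layer
per_layer_keys = [t.squeeze(0) for t in torch.split(stacked, 1)]  # back to a per-layer list
print(per_layer_keys[0].shape)           # torch.Size([2, 8, 24, 64])
```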