Bulk shifting (#40) (#70)

Co-authored-by: madamczykhabana <110973826+madamczykhabana@users.noreply.github.com>
This commit is contained in:
jkaniecki 2024-02-26 17:29:56 +01:00 committed by GitHub
parent 8f4aba6ad3
commit 83b059bd27
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -45,6 +45,7 @@ if 'GRAPH_VISUALIZATION' in os.environ:
for f in glob.glob('.graph_dumps/*'): for f in glob.glob('.graph_dumps/*'):
os.remove(f) os.remove(f)
MAX_TOTAL_TOKENS = int(os.getenv("MAX_TOTAL_TOKENS", "0"))
BATCH_BUCKET_SIZE = int(os.environ.get('BATCH_BUCKET_SIZE', 8)) BATCH_BUCKET_SIZE = int(os.environ.get('BATCH_BUCKET_SIZE', 8))
PAD_SEQUENCE_TO_MULTIPLE_OF = int(os.environ.get('PAD_SEQUENCE_TO_MULTIPLE_OF', 128)) PAD_SEQUENCE_TO_MULTIPLE_OF = int(os.environ.get('PAD_SEQUENCE_TO_MULTIPLE_OF', 128))
PREFILL_BATCH_BUCKET_SIZE = int(os.environ.get('PREFILL_BATCH_BUCKET_SIZE', 4)) PREFILL_BATCH_BUCKET_SIZE = int(os.environ.get('PREFILL_BATCH_BUCKET_SIZE', 4))
@ -98,26 +99,34 @@ def move_data(dst_tensor, chunk_size, indices, src_tensors):
return result return result
def shift(tensor, dim, offset): def generate_shift_chunks(offset):
shape = tensor.shape chunk_sizes = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048]
elements = shape[dim] result = []
if offset == 0 or abs(offset) > elements: while offset != 0:
return tensor sign = 1 if offset > 0 else -1
htorch.core.mark_step() best_chunk = min((abs(offset - sign * c), sign * c) for c in chunk_sizes)[1]
# We generate indices from (0 - offset + elements) to (elements - offset + elements) result.append(best_chunk)
# so that next modulo operation operates on positive values offset = offset - best_chunk
indices = torch.arange(0, elements, dtype=torch.int32, device=tensor.device)
offset = torch.tensor(-offset + elements, dtype=torch.int32, device=tensor.device)
indices.add_(offset)
indices.remainder_(elements)
target_shape = [1,] * len(tensor.shape)
target_shape[dim] = elements
indices = indices.view(target_shape).expand(shape)
result = torch.gather(tensor, dim, indices)
htorch.core.mark_step()
return result return result
def roll(tensor, dim, chunks):
dbg_trace('ROLL', f'shape:{list(tensor.shape)} dim:{dim} chunks:{chunks}')
for c in chunks:
tensor = torch.roll(tensor, c, dim)
htorch.core.mark_step()
return tensor
def shift(tensor, dim, offset):
assert dim < 0, 'Only negative dims are supported'
if offset == 0:
return tensor
chunks = generate_shift_chunks(offset)
tensor = roll(tensor, dim, chunks)
return tensor
def shift_all(srcs, dim, offsets): def shift_all(srcs, dim, offsets):
return [shift(src, dim, offset) for src, offset in zip(srcs, offsets)] return [shift(src, dim, offset) for src, offset in zip(srcs, offsets)]
@ -197,7 +206,6 @@ class CausalLMBatch(Batch):
top_n_tokens_tensor: torch.Tensor top_n_tokens_tensor: torch.Tensor
input_length: int input_length: int
right_padding: int
def to_pb(self) -> generate_pb2.CachedBatch: def to_pb(self) -> generate_pb2.CachedBatch:
return generate_pb2.CachedBatch( return generate_pb2.CachedBatch(
@ -214,9 +222,13 @@ class CausalLMBatch(Batch):
batch_id = batches[0].batch_id batch_id = batches[0].batch_id
device = batches[0].input_ids.device device = batches[0].input_ids.device
max_input_length = max(b.input_length for b in batches) input_lengths = [b.input_length for b in batches]
max_input_length = max(input_lengths)
offsets = [max_input_length - b.input_length for b in batches] offsets = [max_input_length - b.input_length for b in batches]
padding = [b.right_padding for b in batches] padding = [b.right_padding for b in batches]
# For prefill there is a space allocated only for first token
# Need to add padding to the max total tokens before first decode
extra_padding = [MAX_TOTAL_TOKENS - b.seq_length for b in batches]
moves_needed = [total_requests - len(b) if b.batch_size == new_bs else total_requests for b in batches] moves_needed = [total_requests - len(b) if b.batch_size == new_bs else total_requests for b in batches]
target_batch_idx = min(enumerate(moves_needed), key=lambda idx_val: idx_val[1])[0] target_batch_idx = min(enumerate(moves_needed), key=lambda idx_val: idx_val[1])[0]
@ -225,9 +237,9 @@ class CausalLMBatch(Batch):
# FIXME: max_seq_len for non optimized code # FIXME: max_seq_len for non optimized code
if len(batches) > 1: if len(batches) > 1:
scenario = 'CONCAT' scenario = 'CONCAT'
elif batches[0].batch_size != new_bs: elif batches[target_batch_idx].batch_size != new_bs:
scenario = 'RESHAPE' scenario = 'RESHAPE'
elif padding[0] <= 0: elif padding[target_batch_idx] <= 0:
scenario = 'SHIFT' scenario = 'SHIFT'
offsets = [b.max_input_length - max_input_length for b in batches] offsets = [b.max_input_length - max_input_length for b in batches]
max_input_length = max(b.max_input_length for b in batches) max_input_length = max(b.max_input_length for b in batches)
@ -235,9 +247,15 @@ class CausalLMBatch(Batch):
# Nothing to do # Nothing to do
return batches[0] return batches[0]
inplace = batches[target_batch_idx].batch_size == new_bs inplace = (batches[target_batch_idx].batch_size == new_bs)
dbg_trace( dbg_trace(
scenario, f'bs:{[b.batch_size for b in batches]}->{new_bs} reqs:{[len(b) for b in batches]} offsets:{offsets} padding:{padding} moves_needed:{moves_needed} inplace:{inplace}') scenario, f'bs:{[b.batch_size for b in batches]}->{new_bs}'
f' reqs:{[len(b) for b in batches]}'
f' offsets:{offsets}'
f' input_lengths:{input_lengths}'
f' cur_padding:{padding}'
f' inplace:{inplace}')
grouped_requests = [[req for req in batch.requests] for batch in batches] grouped_requests = [[req for req in batch.requests] for batch in batches]
flat_requests = list(itertools.chain(*grouped_requests)) flat_requests = list(itertools.chain(*grouped_requests))
@ -256,7 +274,7 @@ class CausalLMBatch(Batch):
num_layers = len(batches[0].past_key_values) num_layers = len(batches[0].past_key_values)
past_key_values_type = type(batches[0].past_key_values) past_key_values_type = type(batches[0].past_key_values)
seq_dim = 1 seq_dim = -1
if batches[0].past_key_values[0][0].size(-1) != batches[0].past_key_values[0][1].size(-1): if batches[0].past_key_values[0][0].size(-1) != batches[0].past_key_values[0][1].size(-1):
# Case for Bloom # Case for Bloom
key_dim = -1 key_dim = -1
@ -267,14 +285,10 @@ class CausalLMBatch(Batch):
for b in batches: for b in batches:
b.past_key_values = list(b.past_key_values) b.past_key_values = list(b.past_key_values)
# For prefill there is a space allocated only for first token
# Need to add padding to the max total tokens before first decode
paddings = [(batch.input_length + batch.right_padding) - batch.seq_length for batch in batches]
src = [b.input_ids for b in batches] src = [b.input_ids for b in batches]
for b in batches: for b in batches:
del b.input_ids del b.input_ids
src = pad_tensors(src, paddings, seq_dim, pad_token_id) src = pad_tensors(src, extra_padding, seq_dim, pad_token_id)
src = shift_all(src, seq_dim, offsets) src = shift_all(src, seq_dim, offsets)
input_ids = prepare_memory(new_bs, src[target_batch_idx], inplace) input_ids = prepare_memory(new_bs, src[target_batch_idx], inplace)
input_ids = move_data(input_ids, 1, indices, src) input_ids = move_data(input_ids, 1, indices, src)
@ -282,7 +296,7 @@ class CausalLMBatch(Batch):
src = [b.attention_mask for b in batches] src = [b.attention_mask for b in batches]
for b in batches: for b in batches:
del b.attention_mask del b.attention_mask
src = pad_tensors(src, paddings, seq_dim, 0) src = pad_tensors(src, extra_padding, seq_dim, 0)
src = shift_all(src, seq_dim, offsets) src = shift_all(src, seq_dim, offsets)
attention_mask = prepare_memory(new_bs, src[target_batch_idx], inplace) attention_mask = prepare_memory(new_bs, src[target_batch_idx], inplace)
attention_mask = move_data(attention_mask, 1, indices, src) attention_mask = move_data(attention_mask, 1, indices, src)
@ -290,29 +304,36 @@ class CausalLMBatch(Batch):
src = [b.position_ids for b in batches] src = [b.position_ids for b in batches]
for b in batches: for b in batches:
del b.position_ids del b.position_ids
src = shift_all(src, seq_dim, offsets)
position_ids = prepare_memory(new_bs, src[target_batch_idx], inplace) position_ids = prepare_memory(new_bs, src[target_batch_idx], inplace)
position_ids = move_data(position_ids, 1, indices, src) position_ids = move_data(position_ids, 1, indices, src)
past_key_values = [] src = None
for layer_num in range(num_layers): src_keys = [[b.past_key_values[layer_num][0] for layer_num in range(num_layers)] for b in batches]
src = [b.past_key_values[layer_num][0] for b in batches] src_values = [[b.past_key_values[layer_num][1] for layer_num in range(num_layers)] for b in batches]
src = pad_tensors(src, paddings, key_dim, 0) for b in batches:
src = shift_all(src, key_dim, offsets) del b.past_key_values
updated_key = prepare_memory(new_bs * chunk_size, src[target_batch_idx], inplace)
updated_key = move_data(updated_key, chunk_size, indices, src)
src = [b.past_key_values[layer_num][1] for b in batches] src_keys = [torch.stack(src) for src in src_keys]
src = pad_tensors(src, paddings, value_dim, 0) htorch.core.mark_step()
src = shift_all(src, value_dim, offsets) src_keys = pad_tensors(src_keys, extra_padding, key_dim, 0)
updated_value = prepare_memory(new_bs * chunk_size, src[target_batch_idx], inplace) src_keys = shift_all(src_keys, key_dim, offsets)
updated_value = move_data(updated_value, chunk_size, indices, src) src_keys = [[t.squeeze(0).clone() for t in torch.split(src, 1)] for src in src_keys]
htorch.core.mark_step()
past_key_values.append((updated_key, updated_value)) dst_keys = [prepare_memory(new_bs * chunk_size, prev, inplace) for prev in src_keys[target_batch_idx]]
for b in batches: dst_keys = [move_data(dst_keys[layer_num], chunk_size, indices, [src[layer_num] for src in src_keys]) for layer_num in range(num_layers)]
b.past_key_values[layer_num] = None
past_key_values = past_key_values_type(past_key_values) src_values = [torch.stack(src) for src in src_values]
htorch.core.mark_step()
src_values = pad_tensors(src_values, extra_padding, value_dim, 0)
src_values = shift_all(src_values, value_dim, offsets)
src_values = [[t.squeeze(0).clone() for t in torch.split(src, 1)] for src in src_values]
htorch.core.mark_step()
dst_values = [prepare_memory(new_bs * chunk_size, prev, inplace) for prev in src_values[target_batch_idx]]
dst_values = [move_data(dst_values[layer_num], chunk_size, indices, [src[layer_num] for src in src_values]) for layer_num in range(num_layers)]
past_key_values = past_key_values_type(zip(dst_keys, dst_values))
top_n_tokens = [r.data.top_n_tokens for r in flat_requests] top_n_tokens = [r.data.top_n_tokens for r in flat_requests]
top_n_tokens_tensor = torch.tensor(top_n_tokens, device=device, dtype=torch.int64) top_n_tokens_tensor = torch.tensor(top_n_tokens, device=device, dtype=torch.int64)
@ -324,7 +345,6 @@ class CausalLMBatch(Batch):
max_seq_len = attention_mask.size(1) max_seq_len = attention_mask.size(1)
input_length = max_input_length input_length = max_input_length
right_padding = max_seq_len - input_length
htorch.core.mark_step() htorch.core.mark_step()
@ -339,7 +359,6 @@ class CausalLMBatch(Batch):
top_n_tokens=top_n_tokens, top_n_tokens=top_n_tokens,
top_n_tokens_tensor=top_n_tokens_tensor, top_n_tokens_tensor=top_n_tokens_tensor,
input_length=input_length, input_length=input_length,
right_padding=right_padding
) )
@classmethod @classmethod
@ -362,15 +381,6 @@ class CausalLMBatch(Batch):
top_n_tokens_tensor = torch.tensor(top_n_tokens, device=device, dtype=torch.int64) top_n_tokens_tensor = torch.tensor(top_n_tokens, device=device, dtype=torch.int64)
next_token_chooser = HeterogeneousNextTokenChooser.from_pb([r.parameters for r in pb.requests], dtype, device) next_token_chooser = HeterogeneousNextTokenChooser.from_pb([r.parameters for r in pb.requests], dtype, device)
# TODO: this should be set to rust side `max_total_tokens`,
# (see https://github.com/huggingface/text-generation-inference/blob/main/launcher/src/main.rs#L177)
# but TGI does not offer an API to expose this variable to python, as this variable
# is handled by the client but it appears the model is initialized by the server.
# An alternative could be to initialize the buffers during warmup.
# Dummy
max_total_tokens = int(os.getenv("MAX_TOTAL_TOKENS", "0"))
logger.info("MAX_TOTAL_TOKENS = {}".format(max_total_tokens))
# TODO: by tokenizing all inputs at once we loose information on actual input lengths # TODO: by tokenizing all inputs at once we loose information on actual input lengths
# this means that we cannot shift inputs to the left after a long input sequence # this means that we cannot shift inputs to the left after a long input sequence
# was filtered out # was filtered out
@ -394,10 +404,6 @@ class CausalLMBatch(Batch):
bucket_size = round_up(input_len + 1, PAD_SEQUENCE_TO_MULTIPLE_OF) - 1 bucket_size = round_up(input_len + 1, PAD_SEQUENCE_TO_MULTIPLE_OF) - 1
left_padding = bucket_size - input_len left_padding = bucket_size - input_len
extra_padding = 0
if is_optimized_for_gaudi and max_total_tokens > 0:
extra_padding = max(extra_padding, max_total_tokens - (bucket_size + 1) - max_new_tokens)
input_ids = tokenized_inputs["input_ids"] input_ids = tokenized_inputs["input_ids"]
attention_mask = tokenized_inputs["attention_mask"] attention_mask = tokenized_inputs["attention_mask"]
@ -410,7 +416,7 @@ class CausalLMBatch(Batch):
attention_mask, (left_padding, 1), value=0 attention_mask, (left_padding, 1), value=0
) )
all_input_ids = torch.nn.functional.pad( all_input_ids = torch.nn.functional.pad(
input_ids, (0, max_new_tokens + extra_padding), value=tokenizer.pad_token_id input_ids, (0, max_new_tokens), value=tokenizer.pad_token_id
).T.split(1, dim=1) ).T.split(1, dim=1)
else: else:
all_input_ids = input_ids.clone().T.split(1, dim=1) all_input_ids = input_ids.clone().T.split(1, dim=1)
@ -441,7 +447,6 @@ class CausalLMBatch(Batch):
top_n_tokens=top_n_tokens, top_n_tokens=top_n_tokens,
top_n_tokens_tensor=top_n_tokens_tensor, top_n_tokens_tensor=top_n_tokens_tensor,
input_length=input_len, input_length=input_len,
right_padding=max_new_tokens + extra_padding if is_optimized_for_gaudi else 0
) )
@tracer.start_as_current_span("filter") @tracer.start_as_current_span("filter")
@ -471,6 +476,10 @@ class CausalLMBatch(Batch):
def seq_length(self): def seq_length(self):
return self.attention_mask.size(1) return self.attention_mask.size(1)
@property
def right_padding(self):
return self.seq_length - self.input_length
# Maximum number of tokens this batch will grow to # Maximum number of tokens this batch will grow to
@property @property
def max_tokens(self): def max_tokens(self):
@ -914,8 +923,6 @@ class CausalLM(Model):
# Adjust lengths # Adjust lengths
batch.input_length += 1 batch.input_length += 1
if batch.right_padding > 0:
batch.right_padding -= 1
# Update position_ids # Update position_ids
if prefill: if prefill: