Mirror of https://github.com/huggingface/text-generation-inference.git
Co-authored-by: madamczykhabana <110973826+madamczykhabana@users.noreply.github.com>
parent 8f4aba6ad3
commit 83b059bd27
@@ -45,6 +45,7 @@ if 'GRAPH_VISUALIZATION' in os.environ:
     for f in glob.glob('.graph_dumps/*'):
         os.remove(f)
 
+MAX_TOTAL_TOKENS = int(os.getenv("MAX_TOTAL_TOKENS", "0"))
 BATCH_BUCKET_SIZE = int(os.environ.get('BATCH_BUCKET_SIZE', 8))
 PAD_SEQUENCE_TO_MULTIPLE_OF = int(os.environ.get('PAD_SEQUENCE_TO_MULTIPLE_OF', 128))
 PREFILL_BATCH_BUCKET_SIZE = int(os.environ.get('PREFILL_BATCH_BUCKET_SIZE', 4))
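These module-level constants are read once at import time, so the environment must be configured before the server process starts. A minimal sketch of that behaviour follows; the example values (2048, 256) are illustrative only and are not defaults taken from this commit.

import os

# Illustrative values only; in practice these are exported before launching the server.
os.environ["MAX_TOTAL_TOKENS"] = "2048"
os.environ["PAD_SEQUENCE_TO_MULTIPLE_OF"] = "256"

# Mirrors the module-level reads above: missing variables fall back to the shown defaults.
MAX_TOTAL_TOKENS = int(os.getenv("MAX_TOTAL_TOKENS", "0"))
PAD_SEQUENCE_TO_MULTIPLE_OF = int(os.environ.get('PAD_SEQUENCE_TO_MULTIPLE_OF', 128))
print(MAX_TOTAL_TOKENS, PAD_SEQUENCE_TO_MULTIPLE_OF)  # 2048 256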
@@ -98,26 +99,34 @@ def move_data(dst_tensor, chunk_size, indices, src_tensors):
     return result
 
 
-def shift(tensor, dim, offset):
-    shape = tensor.shape
-    elements = shape[dim]
-    if offset == 0 or abs(offset) > elements:
-        return tensor
-    htorch.core.mark_step()
-    # We generate indices from (0 - offset + elements) to (elements - offset + elements)
-    # so that next modulo operation operates on positive values
-    indices = torch.arange(0, elements, dtype=torch.int32, device=tensor.device)
-    offset = torch.tensor(-offset + elements, dtype=torch.int32, device=tensor.device)
-    indices.add_(offset)
-    indices.remainder_(elements)
-    target_shape = [1,] * len(tensor.shape)
-    target_shape[dim] = elements
-    indices = indices.view(target_shape).expand(shape)
-    result = torch.gather(tensor, dim, indices)
-    htorch.core.mark_step()
+def generate_shift_chunks(offset):
+    chunk_sizes = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048]
+    result = []
+    while offset != 0:
+        sign = 1 if offset > 0 else -1
+        best_chunk = min((abs(offset - sign * c), sign * c) for c in chunk_sizes)[1]
+        result.append(best_chunk)
+        offset = offset - best_chunk
+    return result
+
+
+def roll(tensor, dim, chunks):
+    dbg_trace('ROLL', f'shape:{list(tensor.shape)} dim:{dim} chunks:{chunks}')
+    for c in chunks:
+        tensor = torch.roll(tensor, c, dim)
+        htorch.core.mark_step()
+    return tensor
+
+
+def shift(tensor, dim, offset):
+    assert dim < 0, 'Only negative dims are supported'
+    if offset == 0:
+        return tensor
+    chunks = generate_shift_chunks(offset)
+    tensor = roll(tensor, dim, chunks)
+    return tensor
 
 
 def shift_all(srcs, dim, offsets):
     return [shift(src, dim, offset) for src, offset in zip(srcs, offsets)]
 
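To see how the chunked shift behaves, here is a rough, self-contained sketch that runs on plain PyTorch; `htorch.core.mark_step()` and `dbg_trace` are replaced with no-ops, an assumption made purely so the snippet is runnable outside an HPU environment.

import torch

def mark_step():
    # Stand-in for htorch.core.mark_step(); a no-op on CPU.
    pass

def generate_shift_chunks(offset):
    # Decompose an arbitrary offset into a handful of fixed-size rolls so that
    # only a small, reusable set of roll shapes ever needs to be compiled.
    chunk_sizes = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048]
    result = []
    while offset != 0:
        sign = 1 if offset > 0 else -1
        best_chunk = min((abs(offset - sign * c), sign * c) for c in chunk_sizes)[1]
        result.append(best_chunk)
        offset = offset - best_chunk
    return result

def shift(tensor, dim, offset):
    assert dim < 0, 'Only negative dims are supported'
    if offset == 0:
        return tensor
    for c in generate_shift_chunks(offset):
        tensor = torch.roll(tensor, c, dim)
        mark_step()
    return tensor

x = torch.arange(10).unsqueeze(0)                            # shape [1, 10]
print(generate_shift_chunks(37))                             # [32, 4, 1]
print(torch.equal(shift(x, -1, 3), torch.roll(x, 3, -1)))    # True

Composing a few power-of-two rolls gives the same result as one arbitrary roll while keeping the set of distinct graph shapes small.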
@@ -197,7 +206,6 @@ class CausalLMBatch(Batch):
     top_n_tokens_tensor: torch.Tensor
 
     input_length: int
-    right_padding: int
 
     def to_pb(self) -> generate_pb2.CachedBatch:
         return generate_pb2.CachedBatch(
@@ -214,9 +222,13 @@ class CausalLMBatch(Batch):
         batch_id = batches[0].batch_id
         device = batches[0].input_ids.device
 
-        max_input_length = max(b.input_length for b in batches)
+        input_lengths = [b.input_length for b in batches]
+        max_input_length = max(input_lengths)
         offsets = [max_input_length - b.input_length for b in batches]
         padding = [b.right_padding for b in batches]
+        # For prefill there is a space allocated only for first token
+        # Need to add padding to the max total tokens before first decode
+        extra_padding = [MAX_TOTAL_TOKENS - b.seq_length for b in batches]
 
         moves_needed = [total_requests - len(b) if b.batch_size == new_bs else total_requests for b in batches]
         target_batch_idx = min(enumerate(moves_needed), key=lambda idx_val: idx_val[1])[0]
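A small worked example (made-up numbers) of the new bookkeeping above: offsets align every batch to the longest input, while extra_padding grows each batch toward MAX_TOTAL_TOKENS before the first decode.

# Hypothetical values for two batches being recombined; MAX_TOTAL_TOKENS = 2048 assumed.
MAX_TOTAL_TOKENS = 2048
input_lengths = [17, 33]                                     # b.input_length per batch
seq_lengths = [18, 34]                                       # b.seq_length per batch (allocated so far)

max_input_length = max(input_lengths)                        # 33
offsets = [max_input_length - l for l in input_lengths]      # [16, 0]
extra_padding = [MAX_TOTAL_TOKENS - s for s in seq_lengths]  # [2030, 2014]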
@@ -225,9 +237,9 @@ class CausalLMBatch(Batch):
         # FIXME: max_seq_len for non optimized code
         if len(batches) > 1:
             scenario = 'CONCAT'
-        elif batches[0].batch_size != new_bs:
+        elif batches[target_batch_idx].batch_size != new_bs:
             scenario = 'RESHAPE'
-        elif padding[0] <= 0:
+        elif padding[target_batch_idx] <= 0:
             scenario = 'SHIFT'
             offsets = [b.max_input_length - max_input_length for b in batches]
             max_input_length = max(b.max_input_length for b in batches)
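Paraphrased as a standalone function (names reuse the diff's variables; the final 'SKIPPED' label is a stand-in for the "nothing to do" early return in the next hunk), the selection order is:

def pick_scenario(batches, new_bs, padding, target_batch_idx):
    if len(batches) > 1:
        return 'CONCAT'    # several batches are merged into one
    if batches[target_batch_idx].batch_size != new_bs:
        return 'RESHAPE'   # same requests, different batch bucket size
    if padding[target_batch_idx] <= 0:
        return 'SHIFT'     # right padding exhausted, cache must be shifted
    return 'SKIPPED'       # stand-in: the real code simply returns batches[0]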
@@ -235,9 +247,15 @@ class CausalLMBatch(Batch):
             # Nothing to do
             return batches[0]
 
-        inplace = batches[target_batch_idx].batch_size == new_bs
+        inplace = (batches[target_batch_idx].batch_size == new_bs)
 
         dbg_trace(
-            scenario, f'bs:{[b.batch_size for b in batches]}->{new_bs} reqs:{[len(b) for b in batches]} offsets:{offsets} padding:{padding} moves_needed:{moves_needed} inplace:{inplace}')
+            scenario, f'bs:{[b.batch_size for b in batches]}->{new_bs}'
+            f' reqs:{[len(b) for b in batches]}'
+            f' offsets:{offsets}'
+            f' input_lengths:{input_lengths}'
+            f' cur_padding:{padding}'
+            f' inplace:{inplace}')
 
         grouped_requests = [[req for req in batch.requests] for batch in batches]
         flat_requests = list(itertools.chain(*grouped_requests))
@@ -256,7 +274,7 @@ class CausalLMBatch(Batch):
         num_layers = len(batches[0].past_key_values)
         past_key_values_type = type(batches[0].past_key_values)
 
-        seq_dim = 1
+        seq_dim = -1
         if batches[0].past_key_values[0][0].size(-1) != batches[0].past_key_values[0][1].size(-1):
             # Case for Bloom
             key_dim = -1
@@ -267,14 +285,10 @@ class CausalLMBatch(Batch):
         for b in batches:
             b.past_key_values = list(b.past_key_values)
 
-        # For prefill there is a space allocated only for first token
-        # Need to add padding to the max total tokens before first decode
-        paddings = [(batch.input_length + batch.right_padding) - batch.seq_length for batch in batches]
-
         src = [b.input_ids for b in batches]
         for b in batches:
             del b.input_ids
-        src = pad_tensors(src, paddings, seq_dim, pad_token_id)
+        src = pad_tensors(src, extra_padding, seq_dim, pad_token_id)
         src = shift_all(src, seq_dim, offsets)
         input_ids = prepare_memory(new_bs, src[target_batch_idx], inplace)
         input_ids = move_data(input_ids, 1, indices, src)
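pad_tensors and prepare_memory are defined elsewhere in this file and are not part of the diff; for orientation, a plausible stand-in for pad_tensors is sketched below under the assumption that it right-pads each tensor along a negative dim by its matching per-batch amount. The real implementation may differ.

import torch

def pad_tensors(tensors, paddings, dim, value):
    # Assumed behaviour: right-pad each tensor along negative `dim` by its per-batch amount.
    padded = []
    for t, p in zip(tensors, paddings):
        if p <= 0:
            padded.append(t)
            continue
        pad = [0] * (2 * t.dim())
        pad[2 * (-dim - 1) + 1] = p   # F.pad lists (left, right) pairs starting from the last dim
        padded.append(torch.nn.functional.pad(t, pad, value=value))
    return padded

ids = [torch.ones(2, 3, dtype=torch.long), torch.ones(2, 5, dtype=torch.long)]
print([t.shape for t in pad_tensors(ids, [2, 0], -1, 0)])  # [torch.Size([2, 5]), torch.Size([2, 5])]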
@@ -282,7 +296,7 @@ class CausalLMBatch(Batch):
         src = [b.attention_mask for b in batches]
         for b in batches:
             del b.attention_mask
-        src = pad_tensors(src, paddings, seq_dim, 0)
+        src = pad_tensors(src, extra_padding, seq_dim, 0)
        src = shift_all(src, seq_dim, offsets)
         attention_mask = prepare_memory(new_bs, src[target_batch_idx], inplace)
         attention_mask = move_data(attention_mask, 1, indices, src)
@@ -290,29 +304,36 @@ class CausalLMBatch(Batch):
         src = [b.position_ids for b in batches]
         for b in batches:
             del b.position_ids
         src = shift_all(src, seq_dim, offsets)
         position_ids = prepare_memory(new_bs, src[target_batch_idx], inplace)
         position_ids = move_data(position_ids, 1, indices, src)
 
-        past_key_values = []
-        for layer_num in range(num_layers):
-            src = [b.past_key_values[layer_num][0] for b in batches]
-            src = pad_tensors(src, paddings, key_dim, 0)
-            src = shift_all(src, key_dim, offsets)
-            updated_key = prepare_memory(new_bs * chunk_size, src[target_batch_idx], inplace)
-            updated_key = move_data(updated_key, chunk_size, indices, src)
-
-            src = [b.past_key_values[layer_num][1] for b in batches]
-            src = pad_tensors(src, paddings, value_dim, 0)
-            src = shift_all(src, value_dim, offsets)
-            updated_value = prepare_memory(new_bs * chunk_size, src[target_batch_idx], inplace)
-            updated_value = move_data(updated_value, chunk_size, indices, src)
-
-            past_key_values.append((updated_key, updated_value))
-            src = None
-            for b in batches:
-                b.past_key_values[layer_num] = None
-
-        past_key_values = past_key_values_type(past_key_values)
+        src_keys = [[b.past_key_values[layer_num][0] for layer_num in range(num_layers)] for b in batches]
+        src_values = [[b.past_key_values[layer_num][1] for layer_num in range(num_layers)] for b in batches]
+        for b in batches:
+            del b.past_key_values
+
+        src_keys = [torch.stack(src) for src in src_keys]
+        htorch.core.mark_step()
+        src_keys = pad_tensors(src_keys, extra_padding, key_dim, 0)
+        src_keys = shift_all(src_keys, key_dim, offsets)
+        src_keys = [[t.squeeze(0).clone() for t in torch.split(src, 1)] for src in src_keys]
+        htorch.core.mark_step()
+
+        dst_keys = [prepare_memory(new_bs * chunk_size, prev, inplace) for prev in src_keys[target_batch_idx]]
+        dst_keys = [move_data(dst_keys[layer_num], chunk_size, indices, [src[layer_num] for src in src_keys]) for layer_num in range(num_layers)]
+
+        src_values = [torch.stack(src) for src in src_values]
+        htorch.core.mark_step()
+        src_values = pad_tensors(src_values, extra_padding, value_dim, 0)
+        src_values = shift_all(src_values, value_dim, offsets)
+        src_values = [[t.squeeze(0).clone() for t in torch.split(src, 1)] for src in src_values]
+        htorch.core.mark_step()
+
+        dst_values = [prepare_memory(new_bs * chunk_size, prev, inplace) for prev in src_values[target_batch_idx]]
+        dst_values = [move_data(dst_values[layer_num], chunk_size, indices, [src[layer_num] for src in src_values]) for layer_num in range(num_layers)]
+
+        past_key_values = past_key_values_type(zip(dst_keys, dst_values))
 
         top_n_tokens = [r.data.top_n_tokens for r in flat_requests]
         top_n_tokens_tensor = torch.tensor(top_n_tokens, device=device, dtype=torch.int64)
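The rewritten key/value handling stacks every layer of a batch into one tensor so padding and shifting happen once per batch instead of once per layer, then splits back into per-layer tensors. A toy illustration of that stack/pad/split pattern (shapes are invented, no HPU calls):

import torch

num_layers, seq = 2, 6
batch_keys = [torch.randn(4, 8, seq) for _ in range(num_layers)]   # one batch's per-layer keys

stacked = torch.stack(batch_keys)                    # [num_layers, 4, 8, seq]
stacked = torch.nn.functional.pad(stacked, (0, 2))   # pad the seq dim once for all layers
per_layer = [t.squeeze(0).clone() for t in torch.split(stacked, 1)]
print([tuple(t.shape) for t in per_layer])           # [(4, 8, 8), (4, 8, 8)]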
@@ -324,7 +345,6 @@ class CausalLMBatch(Batch):
 
         max_seq_len = attention_mask.size(1)
         input_length = max_input_length
-        right_padding = max_seq_len - input_length
 
         htorch.core.mark_step()
 
@@ -339,7 +359,6 @@ class CausalLMBatch(Batch):
             top_n_tokens=top_n_tokens,
             top_n_tokens_tensor=top_n_tokens_tensor,
             input_length=input_length,
-            right_padding=right_padding
         )
 
     @classmethod
@@ -362,15 +381,6 @@ class CausalLMBatch(Batch):
         top_n_tokens_tensor = torch.tensor(top_n_tokens, device=device, dtype=torch.int64)
         next_token_chooser = HeterogeneousNextTokenChooser.from_pb([r.parameters for r in pb.requests], dtype, device)
 
-        # TODO: this should be set to rust side `max_total_tokens`,
-        # (see https://github.com/huggingface/text-generation-inference/blob/main/launcher/src/main.rs#L177)
-        # but TGI does not offer an API to expose this variable to python, as this variable
-        # is handled by the client but it appears the model is initialized by the server.
-        # An alternative could be to initialize the buffers during warmup.
-        # Dummy
-        max_total_tokens = int(os.getenv("MAX_TOTAL_TOKENS", "0"))
-        logger.info("MAX_TOTAL_TOKENS = {}".format(max_total_tokens))
-
         # TODO: by tokenizing all inputs at once we loose information on actual input lengths
         # this means that we cannot shift inputs to the left after a long input sequence
         # was filtered out
@@ -394,10 +404,6 @@ class CausalLMBatch(Batch):
         bucket_size = round_up(input_len + 1, PAD_SEQUENCE_TO_MULTIPLE_OF) - 1
         left_padding = bucket_size - input_len
 
-        extra_padding = 0
-        if is_optimized_for_gaudi and max_total_tokens > 0:
-            extra_padding = max(extra_padding, max_total_tokens - (bucket_size + 1) - max_new_tokens)
-
         input_ids = tokenized_inputs["input_ids"]
         attention_mask = tokenized_inputs["attention_mask"]
 
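A worked example of the bucketing kept above, with PAD_SEQUENCE_TO_MULTIPLE_OF = 128; round_up is defined elsewhere in the file, so the helper below is an assumed equivalent used only to make the arithmetic concrete.

def round_up(value, multiple):
    # Assumed helper: smallest multiple of `multiple` that is >= value.
    return (value + multiple - 1) // multiple * multiple

PAD_SEQUENCE_TO_MULTIPLE_OF = 128
input_len = 57
bucket_size = round_up(input_len + 1, PAD_SEQUENCE_TO_MULTIPLE_OF) - 1   # 127
left_padding = bucket_size - input_len                                   # 70
print(bucket_size, left_padding)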
@@ -410,7 +416,7 @@ class CausalLMBatch(Batch):
                 attention_mask, (left_padding, 1), value=0
             )
             all_input_ids = torch.nn.functional.pad(
-                input_ids, (0, max_new_tokens + extra_padding), value=tokenizer.pad_token_id
+                input_ids, (0, max_new_tokens), value=tokenizer.pad_token_id
             ).T.split(1, dim=1)
         else:
             all_input_ids = input_ids.clone().T.split(1, dim=1)
@@ -441,7 +447,6 @@ class CausalLMBatch(Batch):
             top_n_tokens=top_n_tokens,
             top_n_tokens_tensor=top_n_tokens_tensor,
             input_length=input_len,
-            right_padding=max_new_tokens + extra_padding if is_optimized_for_gaudi else 0
         )
 
     @tracer.start_as_current_span("filter")
@@ -471,6 +476,10 @@ class CausalLMBatch(Batch):
     def seq_length(self):
         return self.attention_mask.size(1)
 
+    @property
+    def right_padding(self):
+        return self.seq_length - self.input_length
+
     # Maximum number of tokens this batch will grow to
     @property
     def max_tokens(self):
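Making right_padding a derived property means it no longer has to be decremented by hand on every decode step (see the removal in the last hunk below). A minimal sketch of the invariant, using a hypothetical stand-in class:

import torch

class _Batch:                      # hypothetical stand-in for CausalLMBatch
    def __init__(self, attention_mask, input_length):
        self.attention_mask = attention_mask
        self.input_length = input_length

    @property
    def seq_length(self):
        return self.attention_mask.size(1)

    @property
    def right_padding(self):
        return self.seq_length - self.input_length

b = _Batch(torch.zeros(1, 10), input_length=7)
print(b.right_padding)   # 3
b.input_length += 1      # one decode step consumes one more slot
print(b.right_padding)   # 2, with no separate counter to keep in sync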
@@ -914,8 +923,6 @@ class CausalLM(Model):
 
         # Adjust lengths
         batch.input_length += 1
-        if batch.right_padding > 0:
-            batch.right_padding -= 1
 
         # Update position_ids
         if prefill: