diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py index 336c9823..3354d426 100644 --- a/server/text_generation_server/models/causal_lm.py +++ b/server/text_generation_server/models/causal_lm.py @@ -48,6 +48,8 @@ class CausalLMBatch(Batch): # Maximum number of tokens this batch will grow to max_tokens: int + # Maximum number of decode steps before at least one request finish + max_decode_steps: int # Past metadata keys_head_dim_last: bool = True @@ -77,7 +79,7 @@ class CausalLMBatch(Batch): # Parse batch max_truncation = 0 padding_right_offset = 0 - max_decode_tokens = 0 + max_decode_steps = None for i, r in enumerate(pb.requests): requests_idx_mapping[r.id] = i inputs.append(r.inputs) @@ -89,7 +91,15 @@ class CausalLMBatch(Batch): ) stopping_criterias.append(stopping_criteria) max_truncation = max(max_truncation, r.truncate) - max_decode_tokens += stopping_criteria.max_new_tokens + + # Maximum number of decode steps before one request finish + if max_decode_steps is None: + max_decode_steps = stopping_criteria.max_new_tokens + else: + max_decode_steps = min( + max_decode_steps, stopping_criteria.max_new_tokens + ) + padding_right_offset = max( padding_right_offset, stopping_criteria.max_new_tokens ) @@ -118,7 +128,10 @@ class CausalLMBatch(Batch): position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1) all_input_ids = tokenized_inputs["input_ids"].T.split(1, dim=1) - max_tokens = len(inputs) * max_input_length + max_decode_tokens + # Since we are sure that at least one request will be dropped in max_decode_steps, + # we know the kv_cache will only grow to cumulative_length + batch_size * max_decode_steps + # before getting filtered and decreasing in size + max_tokens = len(inputs) * (max_input_length + max_decode_steps) return cls( batch_id=pb.id, @@ -137,6 +150,7 @@ class CausalLMBatch(Batch): max_input_length=max_input_length.item(), padding_right_offset=padding_right_offset, max_tokens=max_tokens, + max_decode_steps=max_decode_steps, ) @tracer.start_as_current_span("filter") @@ -159,8 +173,8 @@ class CausalLMBatch(Batch): next_token_choosers = [] stopping_criterias = [] - total_remaining_decode_tokens = 0 new_padding_right_offset = 0 + max_decode_steps = None for i, r in enumerate(requests): idx = self.requests_idx_mapping[r.id] @@ -178,13 +192,17 @@ class CausalLMBatch(Batch): next_token_choosers.append(self.next_token_choosers[idx]) stopping_criteria = self.stopping_criterias[idx] stopping_criterias.append(stopping_criteria) - remaining_decode_tokens = ( + + # Remaining decode steps for this request + remaining_decode = ( stopping_criteria.max_new_tokens - stopping_criteria.current_tokens ) - total_remaining_decode_tokens += remaining_decode_tokens - new_padding_right_offset = max( - new_padding_right_offset, remaining_decode_tokens - ) + if max_decode_steps is None: + max_decode_steps = remaining_decode + else: + max_decode_steps = min(max_decode_steps, remaining_decode) + + new_padding_right_offset = max(new_padding_right_offset, remaining_decode) # Apply indices to input_ids, attention mask, past key values and other items that need to be cached input_ids = self.input_ids[keep_indices] @@ -217,7 +235,10 @@ class CausalLMBatch(Batch): layer[1] = past_values[keep_indices, :, -past_kv_length:, :] del past_values - max_tokens = len(requests) * max_input_length + total_remaining_decode_tokens + # Since we are sure that at least one request will be dropped in max_decode_steps, + # we know the kv_cache will only grow to cumulative_length + batch_size * max_decode_steps + # before getting filtered and decreasing in size + max_tokens = len(requests) * (max_input_length + max_decode_steps) self.requests = requests self.requests_idx_mapping = requests_idx_mapping @@ -232,6 +253,7 @@ class CausalLMBatch(Batch): self.max_input_length = max_input_length self.padding_right_offset = new_padding_right_offset self.max_tokens = max_tokens + self.max_decode_steps = max_decode_steps return self @@ -256,7 +278,6 @@ class CausalLMBatch(Batch): all_input_ids = [] next_token_choosers = [] stopping_criterias = [] - max_tokens = 0 # Batch tensors input_ids = None @@ -264,6 +285,8 @@ class CausalLMBatch(Batch): position_ids = None past_key_values = [] + max_decode_steps = None + # Used for slicing correctly inside the tensors # Equivalent to a cumsum on batch sizes start_index = 0 @@ -341,10 +364,11 @@ class CausalLMBatch(Batch): layer[k] = t.view(len(batch), -1, *t.shape[-2:]) start_index = end_index - # Add eventual padding tokens that were added while concatenating - max_tokens += batch.max_tokens + ( - max_input_length - batch.max_input_length - ) * len(batch) + + if max_decode_steps is None: + max_decode_steps = batch.max_decode_steps + else: + max_decode_steps = min(max_decode_steps, batch.max_decode_steps) first_past_kvs = batches[0].past_key_values _, num_heads, padded_sequence_length, head_dim = first_past_kvs[0][1].shape @@ -417,6 +441,8 @@ class CausalLMBatch(Batch): past_key_values.append([padded_past_keys, padded_past_values]) + max_tokens = len(requests) * (max_input_length + max_decode_steps) + return cls( batch_id=batches[0].batch_id, requests=requests, @@ -435,6 +461,7 @@ class CausalLMBatch(Batch): padding_right_offset=padding_right_offset, keys_head_dim_last=batches[0].keys_head_dim_last, max_tokens=max_tokens, + max_decode_steps=max_decode_steps, ) def __len__(self): @@ -636,6 +663,8 @@ class CausalLM(Model): batch.attention_mask[:, -batch.padding_right_offset] = 1 # Decrease right offset batch.padding_right_offset -= 1 + # Decrease max_decode_steps + batch.max_decode_steps -= 1 # Update position_ids batch.position_ids = batch.position_ids[:, -1:] + 1 diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index 61ccca84..7c8865f9 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -58,6 +58,8 @@ class FlashCausalLMBatch(Batch): # Maximum number of tokens this batch will grow to max_tokens: int + # Maximum number of decode steps before at least one request finish + max_decode_steps: int def to_pb(self) -> generate_pb2.Batch: return generate_pb2.Batch( @@ -92,7 +94,7 @@ class FlashCausalLMBatch(Batch): # Cumulative length cumulative_length = 0 - max_tokens = 0 + max_decode_steps = None # Parse batch for i, r in enumerate(pb.requests): @@ -127,7 +129,15 @@ class FlashCausalLMBatch(Batch): stopping_criteria = StoppingCriteria.from_pb( r.stopping_parameters, tokenizer ) - max_new_tokens = stopping_criteria.max_new_tokens + + # Maximum number of decode steps before one request finish + if max_decode_steps is None: + max_decode_steps = stopping_criteria.max_new_tokens + else: + max_decode_steps = min( + max_decode_steps, stopping_criteria.max_new_tokens + ) + stopping_criterias.append(stopping_criteria) all_input_ids_tensor.append( @@ -136,7 +146,11 @@ class FlashCausalLMBatch(Batch): # Update cumulative_length += input_length - max_tokens += input_length + max_new_tokens + + # Since we are sure that at least one request will be dropped in max_decode_steps, + # we know the kv_cache will only grow to cumulative_length + batch_size * max_decode_steps + # before getting filtered and decreasing in size + max_tokens = cumulative_length + max_decode_steps * len(pb.requests) return cls( batch_id=pb.id, @@ -156,6 +170,7 @@ class FlashCausalLMBatch(Batch): stopping_criterias=stopping_criterias, past_pad=None, max_tokens=max_tokens, + max_decode_steps=max_decode_steps, ) @tracer.start_as_current_span("filter") @@ -190,7 +205,7 @@ class FlashCausalLMBatch(Batch): next_token_choosers = [] stopping_criterias = [] - max_tokens = 0 + max_decode_steps = None for i, r in enumerate(requests): idx = self.requests_idx_mapping[r.id] @@ -221,11 +236,21 @@ class FlashCausalLMBatch(Batch): stopping_criteria = self.stopping_criterias[idx] stopping_criterias.append(stopping_criteria) - - cumulative_length += request_input_length - max_tokens += request_input_length + ( + # Remaining decode steps for this request + remaining_decode = ( stopping_criteria.max_new_tokens - stopping_criteria.current_tokens ) + if max_decode_steps is None: + max_decode_steps = remaining_decode + else: + max_decode_steps = min(max_decode_steps, remaining_decode) + + cumulative_length += request_input_length + + # Since we are sure that at least one request will be dropped in max_decode_steps, + # we know the kv_cache will only grow to cumulative_length + batch_size * max_decode_steps + # before getting filtered and decreasing in size + max_tokens = cumulative_length + max_decode_steps * len(requests) if single_request: # Preallocate tensor for bs = 1 case @@ -290,7 +315,8 @@ class FlashCausalLMBatch(Batch): # Cumulative length cumulative_batch_size = 0 cumulative_length = 0 - max_tokens = 0 + + max_decode_steps = None for i, batch in enumerate(batches): requests.extend(batch.requests) @@ -329,10 +355,16 @@ class FlashCausalLMBatch(Batch): next_token_choosers.extend(batch.next_token_choosers) stopping_criterias.extend(batch.stopping_criterias) + if max_decode_steps is None: + max_decode_steps = batch.max_decode_steps + else: + max_decode_steps = min(max_decode_steps, batch.max_decode_steps) + # Update cumulative_length += batch.cu_seqlens[-1] cumulative_batch_size += len(batch) - max_tokens += batch.max_tokens + + max_tokens = cumulative_length + max_decode_steps * cumulative_batch_size return FlashCausalLMBatch( batch_id=batches[0].batch_id, @@ -352,6 +384,7 @@ class FlashCausalLMBatch(Batch): next_token_choosers=next_token_choosers, stopping_criterias=stopping_criterias, max_tokens=max_tokens, + max_decode_steps=max_decode_steps, ) def __len__(self): @@ -617,6 +650,7 @@ class FlashCausalLM(Model): batch.all_input_ids[i] = all_input_ids batch.all_input_ids_tensor[i] = all_input_ids_tensor batch.max_seqlen = max(batch.max_seqlen, new_input_length) + batch.max_decode_steps -= 1 if len(batch) != 1: # Add each sequence before its padding batch.past_key_values[i * 2] = present[:, start_index:end_index] diff --git a/server/text_generation_server/models/seq2seq_lm.py b/server/text_generation_server/models/seq2seq_lm.py index 0cb20760..627eef3a 100644 --- a/server/text_generation_server/models/seq2seq_lm.py +++ b/server/text_generation_server/models/seq2seq_lm.py @@ -56,6 +56,8 @@ class Seq2SeqLMBatch(Batch): # Maximum number of tokens this batch will grow to max_tokens: int + # Maximum number of decode steps before at least one request finish + max_decode_steps: int def to_pb(self) -> generate_pb2.Batch: """Convert a Seq2SeqLMBatch to a text_generation_server.v1.Batch protobuf""" @@ -86,7 +88,7 @@ class Seq2SeqLMBatch(Batch): # Parse batch max_truncation = 0 padding_right_offset = 0 - max_decode_tokens = 0 + max_decode_steps = None for i, r in enumerate(pb.requests): inputs.append(r.inputs) requests_idx_mapping[r.id] = i @@ -99,7 +101,15 @@ class Seq2SeqLMBatch(Batch): ) stopping_criterias.append(stopping_criteria) max_truncation = max(max_truncation, r.truncate) - max_decode_tokens += stopping_criteria.max_new_tokens + + # Maximum number of decode steps before one request finish + if max_decode_steps is None: + max_decode_steps = stopping_criteria.max_new_tokens + else: + max_decode_steps = min( + max_decode_steps, stopping_criteria.max_new_tokens + ) + padding_right_offset = max( padding_right_offset, stopping_criteria.max_new_tokens ) @@ -125,7 +135,7 @@ class Seq2SeqLMBatch(Batch): ) all_decoder_input_ids = decoder_input_ids.view(-1).split(1) - max_tokens = len(inputs) * max_input_length + max_decode_tokens + max_tokens = len(inputs) * (max_input_length + max_decode_steps) return cls( batch_id=pb.id, @@ -148,6 +158,7 @@ class Seq2SeqLMBatch(Batch): max_decoder_input_length=1, padding_right_offset=padding_right_offset, max_tokens=max_tokens, + max_decode_steps=max_decode_steps, ) @tracer.start_as_current_span("filter") @@ -177,7 +188,7 @@ class Seq2SeqLMBatch(Batch): max_decoder_input_length = 0 padding_right_offset = 0 - remaining_decode_tokens = 0 + max_decode_steps = None for i, r in enumerate(requests): idx = self.requests_idx_mapping[r.id] @@ -207,9 +218,15 @@ class Seq2SeqLMBatch(Batch): next_token_choosers.append(self.next_token_choosers[idx]) stopping_criteria = self.stopping_criterias[idx] stopping_criterias.append(stopping_criteria) - remaining_decode_tokens += ( + + # Remaining decode steps for this request + remaining_decode = ( stopping_criteria.max_new_tokens - stopping_criteria.current_tokens ) + if max_decode_steps is None: + max_decode_steps = remaining_decode + else: + max_decode_steps = min(max_decode_steps, remaining_decode) # Apply indices to input_ids, attention mask, past key values and other items that need to be cached self.decoder_input_ids = self.decoder_input_ids[keep_indices] @@ -240,9 +257,8 @@ class Seq2SeqLMBatch(Batch): layer[2] = layer[2][keep_indices, :, -max_input_length:] layer[3] = layer[3][keep_indices, :, -max_input_length:] - max_tokens = ( - len(requests) * (max_input_length + max_decoder_input_length) - + remaining_decode_tokens + max_tokens = len(requests) * ( + max_input_length + max_decoder_input_length + max_decode_steps ) self.requests = requests @@ -259,6 +275,7 @@ class Seq2SeqLMBatch(Batch): self.max_decoder_input_length = max_decoder_input_length self.padding_right_offset = padding_right_offset self.max_tokens = max_tokens + self.max_decode_steps = max_decode_steps return self @@ -290,7 +307,7 @@ class Seq2SeqLMBatch(Batch): token_offsets = [] next_token_choosers = [] stopping_criterias = [] - max_tokens = 0 + max_decode_steps = 0 # Batch tensors attention_mask = None @@ -398,13 +415,11 @@ class Seq2SeqLMBatch(Batch): ] start_index = end_index - # Add eventual padding tokens that were added while concatenating - max_tokens += batch.max_tokens + ( - max_input_length - - batch.max_input_length - + max_decoder_input_length - - batch.max_decoder_input_length - ) * len(batch) + + if max_decode_steps is None: + max_decode_steps = batch.max_decode_steps + else: + max_decode_steps = min(max_decode_steps, batch.max_decode_steps) # Determine shapes for new past kv tensors first_past_kvs = batches[0].past_key_values @@ -471,6 +486,10 @@ class Seq2SeqLMBatch(Batch): start_index = end_index + max_tokens = len(requests) * ( + max_input_length + max_decoder_input_length + max_decode_steps + ) + return cls( batch_id=batches[0].batch_id, requests=requests, @@ -492,6 +511,7 @@ class Seq2SeqLMBatch(Batch): max_decoder_input_length=max_decoder_input_length, padding_right_offset=padding_right_offset, max_tokens=max_tokens, + max_decode_steps=max_decode_steps, ) def __len__(self): @@ -717,5 +737,6 @@ class Seq2SeqLM(Model): if batch.decoder_attention_mask is not None: batch.decoder_attention_mask[:, -batch.padding_right_offset] = 1 batch.padding_right_offset -= 1 + batch.max_decode_steps -= 1 return generations, batch