diff --git a/server/tests/models/test_causal_lm.py b/server/tests/models/test_causal_lm.py
index db68fc9c..7e860e07 100644
--- a/server/tests/models/test_causal_lm.py
+++ b/server/tests/models/test_causal_lm.py
@@ -44,11 +44,12 @@ def default_causal_lm_batch(default_pb_batch, gpt2_tokenizer):
 @pytest.fixture
 def default_multi_requests_causal_lm_batch(default_pb_request, gpt2_tokenizer):
     req_0 = copy(default_pb_request)
+    req_0.id = 1
     req_1 = default_pb_request
-    req_1.id = 1
+    req_1.id = 2
     req_1.stopping_parameters.max_new_tokens = 5
 
-    batch_pb = generate_pb2.Batch(id=0, requests=[req_0, req_1], size=2)
+    batch_pb = generate_pb2.Batch(id=1, requests=[req_0, req_1], size=2)
     return CausalLMBatch.from_pb(batch_pb, gpt2_tokenizer, torch.device("cpu"))
 
 
@@ -67,12 +68,12 @@ def test_batch_from_pb(default_pb_batch, default_causal_lm_batch):
 
     assert batch.past_key_values is None
 
-    assert torch.equal(batch.input_ids, batch.all_input_ids[:, :, 0])
+    assert all([torch.equal(input_ids, all_input_ids[:, 0]) for input_ids, all_input_ids in zip(batch.input_ids, batch.all_input_ids)])
 
     assert batch.input_lengths == [1]
 
-    assert batch.size == default_pb_batch.size
-    assert len(batch.next_token_choosers) == len(batch.stopping_criterias) == batch.size
+    assert len(batch) == default_pb_batch.size
+    assert len(batch.next_token_choosers) == len(batch.stopping_criterias) == len(batch)
 
     assert batch.max_input_length == batch.input_lengths[0]
 
@@ -93,7 +94,7 @@ def test_causal_lm_generate_token(default_causal_lm, default_causal_lm_batch):
     assert len(generations) == len(next_batch)
     assert isinstance(next_batch, CausalLMBatch)
 
-    assert len(next_batch.all_input_ids) == next_batch.size
+    assert len(next_batch.all_input_ids) == len(next_batch)
     assert len(next_batch.all_input_ids[0]) == sequence_length + 1
     assert len(next_batch.attention_mask[0]) == 11
     assert next_batch.all_input_ids[0][-1] == 13
@@ -103,7 +104,7 @@ def test_causal_lm_generate_token(default_causal_lm, default_causal_lm_batch):
     assert torch.all(next_batch.attention_mask[0][0:2] == 1)
     assert torch.all(next_batch.attention_mask[0][2:] == 0)
 
-    assert next_batch.input_ids.shape == (next_batch.size, 1)
+    assert next_batch.input_ids.shape == (len(next_batch), 1)
     assert next_batch.input_ids[0, 0] == 13
 
     assert next_batch.input_lengths == [2]
@@ -168,6 +169,8 @@ def test_causal_lm_generate_token_completion_multi(
         == default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens
     )
 
+    next_batch = next_batch.filter([next_batch.requests[0]])
+
     for _ in range(
         default_multi_requests_causal_lm_batch.stopping_criterias[0].max_new_tokens
         - default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens
@@ -266,6 +269,8 @@ def test_batch_concatenate(
         == default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens
     )
 
+    next_batch = next_batch.filter([next_batch.requests[0], next_batch.requests[1]])
+
     for _ in range(
         default_causal_lm_batch.stopping_criterias[0].max_new_tokens
         - default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens
@@ -285,6 +290,8 @@ def test_batch_concatenate(
         == default_causal_lm_batch.stopping_criterias[0].max_new_tokens
     )
 
+    next_batch = next_batch.filter([next_batch.requests[1]])
+
     for _ in range(
         default_multi_requests_causal_lm_batch.stopping_criterias[0].max_new_tokens
         - default_causal_lm_batch.stopping_criterias[0].max_new_tokens
diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py
index aa50f8f8..586334ed 100644
--- a/server/text_generation_server/models/causal_lm.py
+++ b/server/text_generation_server/models/causal_lm.py
@@ -147,10 +147,13 @@ class CausalLMBatch(Batch):
         all_input_ids = []
         max_input_length = 0
 
+        next_token_choosers = []
+        stopping_criterias = []
+
         for i, r in enumerate(requests):
             idx = self.requests_idx_mapping[r.id]
-            keep_indices.append(idx)
             requests_idx_mapping[r.id] = i
+            keep_indices.append(idx)
 
             offsets.append(self.offsets[idx])
             token_offsets.append(self.token_offsets[idx])
@@ -160,28 +163,37 @@ class CausalLMBatch(Batch):
             input_lengths.append(request_input_length)
             max_input_length = max(max_input_length, request_input_length)
 
-        # Replace metadata
-        self.requests_idx_mapping = requests_idx_mapping
-        self.input_lengths = input_lengths
-        self.offsets = offsets
-        self.token_offsets = token_offsets
-        self.all_input_ids = all_input_ids
-        self.max_input_length = max_input_length
+            next_token_choosers.append(self.next_token_choosers[idx])
+            stopping_criterias.append(self.stopping_criterias[idx])
 
         # Apply indices to input_ids, attention mask, past key values and other items that need to be cached
-        self.input_ids = self.input_ids[keep_indices]
-        self.attention_mask = self.attention_mask[keep_indices]
-        self.position_ids = self.position_ids[keep_indices]
+        input_ids = self.input_ids[keep_indices]
+        attention_mask = self.attention_mask[keep_indices]
+        position_ids = self.position_ids[keep_indices]
         # Force past to be of dim [self_size, num_heads, ...] for easy indexing
-        self.past_key_values = [
+        past_key_values = [
             [t.view(len(self), -1, *t.shape[-2:])[keep_indices] for t in layer]
             for layer in self.past_key_values
         ]
 
-        self.requests = requests
-        self.next_token_choosers = [self.next_token_choosers[i] for i in keep_indices]
-        self.stopping_criterias = [self.stopping_criterias[i] for i in keep_indices]
-        return self
+        return CausalLMBatch(
+            batch_id=self.batch_id,
+            requests=requests,
+            requests_idx_mapping=requests_idx_mapping,
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            all_input_ids=all_input_ids,
+            input_lengths=input_lengths,
+            offsets=offsets,
+            token_offsets=token_offsets,
+            next_token_choosers=next_token_choosers,
+            stopping_criterias=stopping_criterias,
+            max_input_length=max_input_length,
+            padding_right_offset=self.padding_right_offset,
+            keys_head_dim_last=self.keys_head_dim_last,
+        )
 
     @classmethod
     @tracer.start_as_current_span("concatenate")
@@ -224,8 +236,9 @@ class CausalLMBatch(Batch):
             stopping_criterias.extend(batch.stopping_criterias)
 
             if i == 0:
-                requests_idx_mapping = requests_idx_mapping
+                requests_idx_mapping = batch.requests_idx_mapping
             else:
+                # We need to offset the mapping for each batch by the cumulative batch size
                 for k, v in batch.requests_idx_mapping.items():
                     requests_idx_mapping[k] = v + start_index
 
@@ -525,17 +538,20 @@ class CausalLM(Model):
             generations.append(generation)
 
             # Update values
-            batch.input_ids[i] = next_token_id
+            batch.input_ids[i, 0] = next_token_id
             batch.all_input_ids[i] = all_input_ids
             batch.input_lengths[i] = new_input_length
             batch.offsets[i] = offset
             batch.token_offsets[i] = token_offset
             batch.max_input_length = max(batch.max_input_length, new_input_length)
 
-        # Decrease right offset
-        batch.padding_right_offset -= 1
-
+        # Slice unused values from prefill
+        batch.input_ids = batch.input_ids[:, :1]
+
+        # Update attention_mask as we added a new token to input_ids
         batch.attention_mask[:, -batch.padding_right_offset] = 1
+        # Decrease right offset
+        batch.padding_right_offset -= 1
 
         # Update position_ids
         batch.position_ids = batch.position_ids[:, -1:] + 1
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index 530ec9fe..382f917d 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -244,9 +244,12 @@ class FlashCausalLMBatch(Batch):
         for i, batch in enumerate(batches):
             requests.extend(batch.requests)
 
-            # We need to offset the mapping for each batch by the cumulative batch size
-            for k, v in batch.requests_idx_mapping.items():
-                requests_idx_mapping[k] = v + cumulative_batch_size
+            if i == 0:
+                requests_idx_mapping = batch.requests_idx_mapping
+            else:
+                # We need to offset the mapping for each batch by the cumulative batch size
+                for k, v in batch.requests_idx_mapping.items():
+                    requests_idx_mapping[k] = v + cumulative_batch_size
 
             input_ids.extend(batch.input_ids)
             position_ids.extend(batch.position_ids)