diff --git a/server/tests/models/test_vectorized_causal_lm.py b/server/tests/models/test_vectorized_causal_lm.py index 160c7670..d922865c 100644 --- a/server/tests/models/test_vectorized_causal_lm.py +++ b/server/tests/models/test_vectorized_causal_lm.py @@ -5,7 +5,10 @@ from copy import copy from transformers import AutoTokenizer from text_generation_server.pb import generate_pb2 -from text_generation_server.models.vectorized_causal_lm import VectorizedCausalLM, VectorizedCausalLMBatch +from text_generation_server.models.vectorized_causal_lm import ( + VectorizedCausalLM, + VectorizedCausalLMBatch, +) @pytest.fixture(scope="session") @@ -38,7 +41,9 @@ def default_pb_batch(default_pb_request): @pytest.fixture def default_causal_lm_batch(default_pb_batch, gpt2_tokenizer): - return VectorizedCausalLMBatch.from_pb(default_pb_batch, gpt2_tokenizer, torch.device("cpu")) + return VectorizedCausalLMBatch.from_pb( + default_pb_batch, gpt2_tokenizer, torch.device("cpu") + ) @pytest.fixture @@ -50,7 +55,9 @@ def default_multi_requests_causal_lm_batch(default_pb_request, gpt2_tokenizer): req_1.stopping_parameters.max_new_tokens = 5 batch_pb = generate_pb2.Batch(id=1, requests=[req_0, req_1], size=2) - return VectorizedCausalLMBatch.from_pb(batch_pb, gpt2_tokenizer, torch.device("cpu")) + return VectorizedCausalLMBatch.from_pb( + batch_pb, gpt2_tokenizer, torch.device("cpu") + ) def test_batch_from_pb(default_pb_batch, default_causal_lm_batch): @@ -59,33 +66,29 @@ def test_batch_from_pb(default_pb_batch, default_causal_lm_batch): assert batch.batch_id == default_pb_batch.id assert batch.requests == default_pb_batch.requests - assert len(batch.input_ids) == default_pb_batch.size - assert batch.input_ids[0][-1] == 14402 - assert torch.all(batch.input_ids[0][:-1] == 50256) + assert batch.input_ids.shape == (1, 11) + assert batch.input_ids[0, 0] == 14402 + assert batch.attention_mask.shape == (1, 11) assert batch.attention_mask[0, 0] == 1 - assert torch.all(batch.attention_mask[0, 1:] == 0) + assert batch.attention_mask.all() + assert batch.position_ids.shape == (1, 11) assert batch.past_key_values is None - assert all( - [ - torch.equal(input_ids, all_input_ids[:, 0]) - for input_ids, all_input_ids in zip(batch.input_ids, batch.all_input_ids) - ] - ) - assert batch.input_lengths == [1] - assert len(batch) == default_pb_batch.size - assert len(batch.next_token_choosers) == len(batch.stopping_criterias) == len(batch) + assert len(batch) == 1 + assert len(batch.stopping_criterias) == 1 - assert batch.max_input_length == batch.input_lengths[0] + assert batch.max_input_length == 1 def test_batch_concatenate_no_prefill(default_causal_lm_batch): with pytest.raises(ValueError): - VectorizedCausalLMBatch.concatenate([default_causal_lm_batch, default_causal_lm_batch]) + VectorizedCausalLMBatch.concatenate( + [default_causal_lm_batch, default_causal_lm_batch] + ) def test_causal_lm_batch_type(default_causal_lm): @@ -93,39 +96,29 @@ def test_causal_lm_batch_type(default_causal_lm): def test_causal_lm_generate_token(default_causal_lm, default_causal_lm_batch): - sequence_length = len(default_causal_lm_batch.all_input_ids[0]) generations, next_batch = default_causal_lm.generate_token(default_causal_lm_batch) - assert len(generations) == len(next_batch) + assert len(generations) == len(next_batch) == 1 assert isinstance(next_batch, VectorizedCausalLMBatch) - assert len(next_batch.all_input_ids) == len(next_batch) - assert len(next_batch.all_input_ids[0]) == sequence_length + 1 - assert len(next_batch.attention_mask[0]) == 11 - assert next_batch.all_input_ids[0][-1] == 13 - assert next_batch.all_input_ids[0][-2] == 14402 - assert torch.all(next_batch.all_input_ids[0][:-2] == 50256) + assert next_batch.input_ids.shape == (1, 11) + assert next_batch.input_ids[0, 0] == 14402 + assert next_batch.input_ids[0, 1] == 13 + assert next_batch.max_input_length == 2 + assert next_batch.attention_mask.shape == (1, 11) - assert torch.all(next_batch.attention_mask[0][0:2] == 1) - assert torch.all(next_batch.attention_mask[0][2:] == 0) - - assert next_batch.input_ids.shape == (len(next_batch), 1) - assert next_batch.input_ids[0, 0] == 13 + assert next_batch.attention_mask.all() assert next_batch.input_lengths == [2] - assert next_batch.max_input_length == next_batch.input_lengths[0] assert next_batch.past_key_values is not None - assert all( - [p[0].shape == (1, 12, sequence_length, 64) for p in next_batch.past_key_values] - ) - assert all( - [p[1].shape == (1, 12, sequence_length, 64) for p in next_batch.past_key_values] - ) - assert all([generation.generated_text is None for generation in generations]) - assert all([len(generation.prefill_tokens) == 1 for generation in generations]) - assert all([generation.token_id.item() == 13 for generation in generations]) - assert all([generation.token_text == "." for generation in generations]) + assert all([p[0].shape == (1, 12, 1, 64) for p in next_batch.past_key_values]) + assert all([p[1].shape == (1, 12, 1, 64) for p in next_batch.past_key_values]) + + assert generations[0].generated_text is None + assert len(generations[0].prefill_tokens) == 1 + assert generations[0].token_id.item() == 13 + assert generations[0].token_text == "." assert generations[0].request_id == 0 @@ -222,21 +215,53 @@ def test_batch_concatenate( next_batch = VectorizedCausalLMBatch.concatenate([next_batch_0, next_batch_1]) - assert torch.equal(next_batch.all_input_ids[0], next_batch_0.all_input_ids[0]) - assert torch.equal(next_batch.all_input_ids[1], next_batch_1.all_input_ids[0]) - assert torch.equal(next_batch.all_input_ids[2], next_batch_1.all_input_ids[1]) + assert torch.equal( + next_batch.input_ids[ + 0, + next_batch.max_input_length + - next_batch.input_lengths[0] : next_batch.max_input_length, + ], + next_batch_0.input_ids[ + 0, + next_batch_0.max_input_length + - next_batch_0.input_lengths[0] : next_batch_0.max_input_length, + ], + ) + assert torch.equal( + next_batch.input_ids[ + 1, + next_batch.max_input_length + - next_batch.input_lengths[1] : next_batch.max_input_length, + ], + next_batch_1.input_ids[ + 0, + next_batch_1.max_input_length + - next_batch_1.input_lengths[0] : next_batch_1.max_input_length, + ], + ) + assert torch.equal( + next_batch.input_ids[ + 2, + next_batch.max_input_length + - next_batch.input_lengths[2] : next_batch.max_input_length, + ], + next_batch_1.input_ids[ + 1, + next_batch_1.max_input_length + - next_batch_1.input_lengths[1] : next_batch_1.max_input_length, + ], + ) - assert torch.all( - next_batch.attention_mask[0, : -next_batch.padding_right_offset] == 1 - ) - assert torch.all( - next_batch.attention_mask[1:, 1 : -next_batch.padding_right_offset] == 1 - ) - assert torch.all(next_batch.attention_mask[1:, 3:] == 0) + assert next_batch.attention_mask[0].all() + assert next_batch.attention_mask[1:, 1:].all() + assert next_batch.attention_mask[1:, :1].logical_not().all() assert next_batch.batch_id == 0 - assert next_batch.input_ids[0, 0] == 12355 - assert torch.all(next_batch.input_ids[1:] == 13) + assert next_batch.input_ids[:, next_batch.max_input_length - 1].tolist() == [ + 12355, + 13, + 13, + ] assert next_batch.input_lengths == [3, 2, 2] assert next_batch.max_input_length == 3 @@ -244,9 +269,6 @@ def test_batch_concatenate( assert next_batch.requests[0] == next_batch_0.requests[0] assert next_batch.requests[1:] == next_batch_1.requests - assert next_batch.next_token_choosers[0] == next_batch_0.next_token_choosers[0] - assert next_batch.next_token_choosers[1:] == next_batch_1.next_token_choosers - assert next_batch.stopping_criterias[0] == next_batch_0.stopping_criterias[0] assert next_batch.stopping_criterias[1:] == next_batch_1.stopping_criterias diff --git a/server/text_generation_server/models/vectorized_causal_lm.py b/server/text_generation_server/models/vectorized_causal_lm.py index ec0ed18e..a3f1b634 100644 --- a/server/text_generation_server/models/vectorized_causal_lm.py +++ b/server/text_generation_server/models/vectorized_causal_lm.py @@ -212,7 +212,7 @@ class VectorizedCausalLMBatch(Batch): ] for tensor in tensors_to_update: # Update tensors in-place to allow incremental garbage collection - tensors_to_update.data = tensor[kv_cache_slice] + tensor.data = tensor[kv_cache_slice] return self @@ -320,7 +320,12 @@ class VectorizedCausalLMBatch(Batch): past_key_values = [] for i, kv_format in enumerate(kv_formats): for j in range(1 if kv_format is None else kv_format): - tensors_to_merge = [batch.past_key_values[i] for batch in batches] + tensors_to_merge = [ + batch.past_key_values[i] + if kv_format is None + else batch.past_key_values[i][j] + for batch in batches + ] # Generally `max_input_length`, unless the model allocates more than needed. right_indices = [ left_index + tensor.size(kv_cache_seq_dim) @@ -461,7 +466,8 @@ class VectorizedCausalLM(Model): .squeeze(1) .tolist() ) - if query_length > 1: + is_prefill = batch.past_key_values is None + if is_prefill: prefill_token_ids = batch.input_ids[:, :key_length].tolist() prefill_logprobs = ( logprobs.gather(2, batch.input_ids[:, 1:key_length, None]) @@ -509,7 +515,11 @@ class VectorizedCausalLM(Model): # Decode generated tokens # TODO: Same as stopping_criteria.current_output? output_text = self.decode( - batch.input_ids[i, -stopping_criterias.current_tokens :] + batch.input_ids[ + i, + batch.max_input_length + - stopping_criterias.current_tokens : batch.max_input_length, + ] ) # TODO: Seed generated_text = GeneratedText( @@ -522,7 +532,7 @@ class VectorizedCausalLM(Model): generation = Generation( batch.requests[i].id, - prefill_tokens[i] if batch.details and query_length > 1 else None, + prefill_tokens[i] if batch.details and is_prefill else None, next_token_id, token_logprobs[i] if batch.details else 0.0, next_token_text, diff --git a/server/text_generation_server/utils/tokens_heterogeneous.py b/server/text_generation_server/utils/tokens_heterogeneous.py index ac9aae62..d593fb88 100644 --- a/server/text_generation_server/utils/tokens_heterogeneous.py +++ b/server/text_generation_server/utils/tokens_heterogeneous.py @@ -214,6 +214,7 @@ class HeterogeneousTypicalLogitsWarper(LogitsWarper): scores = scores.masked_fill_(indices_to_remove, self.filter_value) return scores + class HeterogeneousProcessorWrapper(LogitsProcessor): r""" A wrapper for logit warpers or processors without heterogeneous parameter support. @@ -227,12 +228,12 @@ class HeterogeneousProcessorWrapper(LogitsProcessor): self, processors: Dict[int, Union[LogitsProcessor, LogitsWarper]], ): - self.processors=processors - self.max_index=max(processors) + self.processors = processors + self.max_index = max(processors) def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor: for i, processor in self.processors.items(): - scores[i:i+1]=processor(input_ids[i:i+1], scores[i:i+1]) + scores[i : i + 1] = processor(input_ids[i : i + 1], scores[i : i + 1]) return scores @@ -275,9 +276,15 @@ class HeterogeneousNextTokenChooser: watermark = self._standardize(watermark, batch_size, False) if any(watermark): - warpers.append(HeterogeneousProcessorWrapper( - {i:WatermarkLogitsProcessor(device=device) for i, x in watermark if x} - )) + warpers.append( + HeterogeneousProcessorWrapper( + { + i: WatermarkLogitsProcessor(device=device) + for i, x in watermark + if x + } + ) + ) repetition_penalty = self._standardize(repetition_penalty, batch_size, 1.0) if any([x != 1.0 for x in repetition_penalty]):