From 0f124cbc52c6fe744fc34a293d8e0455db3c1e4a Mon Sep 17 00:00:00 2001
From: drbh
Date: Sat, 3 Feb 2024 05:04:00 +0000
Subject: [PATCH] fix: revise non batching tests

---
 .../test_fused_kernel_mamba.json              | 34 +++++-----
 .../test_fused_kernel_mamba_all_params.json   | 66 +++++++++----------
 .../models/test_fused_kernel_mamba.py         |  9 +--
 server/text_generation_server/models/mamba.py | 25 +------
 4 files changed, 54 insertions(+), 80 deletions(-)

diff --git a/integration-tests/models/__snapshots__/test_fused_kernel_mamba/test_fused_kernel_mamba.json b/integration-tests/models/__snapshots__/test_fused_kernel_mamba/test_fused_kernel_mamba.json
index ae6ee35e..d75c959f 100644
--- a/integration-tests/models/__snapshots__/test_fused_kernel_mamba/test_fused_kernel_mamba.json
+++ b/integration-tests/models/__snapshots__/test_fused_kernel_mamba/test_fused_kernel_mamba.json
@@ -11,7 +11,7 @@
       },
       {
         "id": 2748,
-        "logprob": -9.7421875,
+        "logprob": -9.7265625,
         "text": " request"
       }
     ],
@@ -19,66 +19,66 @@
     "tokens": [
       {
         "id": 187,
-        "logprob": -2.4824219,
-        "special": false,
-        "text": "\n"
-      },
-      {
-        "id": 187,
-        "logprob": -2.4824219,
+        "logprob": -2.4746094,
         "special": false,
         "text": "\n"
       },
       {
         "id": 50274,
-        "logprob": -1.7880859,
+        "logprob": -1.7861328,
         "special": false,
         "text": " "
       },
       {
         "id": 92,
-        "logprob": -2.0703125,
+        "logprob": -2.03125,
         "special": false,
         "text": "{"
       },
       {
         "id": 187,
-        "logprob": -0.04827881,
+        "logprob": -0.048706055,
         "special": false,
         "text": "\n"
       },
       {
         "id": 50270,
-        "logprob": -0.18896484,
+        "logprob": -0.19421387,
         "special": false,
         "text": " "
      },
       {
         "id": 3,
-        "logprob": -1.5234375,
+        "logprob": -1.5097656,
         "special": false,
         "text": "\""
       },
       {
         "id": 9629,
-        "logprob": -2.8203125,
+        "logprob": -2.7792969,
         "special": false,
         "text": "request"
       },
       {
         "id": 1381,
-        "logprob": -0.78759766,
+        "logprob": -0.78271484,
         "special": false,
         "text": "\":"
       },
       {
         "id": 551,
-        "logprob": -0.49169922,
+        "logprob": -0.49487305,
         "special": false,
         "text": " {"
+      },
+      {
+        "id": 187,
+        "logprob": -0.021087646,
+        "special": false,
+        "text": "\n"
       }
     ],
     "top_tokens": null
   },
-  "generated_text": "\n\n {\n \"request\": {"
+  "generated_text": "\n {\n \"request\": {\n"
 }
diff --git a/integration-tests/models/__snapshots__/test_fused_kernel_mamba/test_fused_kernel_mamba_all_params.json b/integration-tests/models/__snapshots__/test_fused_kernel_mamba/test_fused_kernel_mamba_all_params.json
index 0ab7cf11..052c1c69 100644
--- a/integration-tests/models/__snapshots__/test_fused_kernel_mamba/test_fused_kernel_mamba_all_params.json
+++ b/integration-tests/models/__snapshots__/test_fused_kernel_mamba/test_fused_kernel_mamba_all_params.json
@@ -16,17 +16,17 @@
       },
       {
         "id": 8862,
-        "logprob": -3.4746094,
+        "logprob": -3.4433594,
         "text": " yellow"
       },
       {
         "id": 13,
-        "logprob": -0.43579102,
+        "logprob": -0.43017578,
         "text": ","
       },
       {
         "id": 209,
-        "logprob": -8.2421875,
+        "logprob": -8.21875,
         "text": " "
       }
     ],
@@ -39,28 +39,28 @@
         "text": "\n"
       },
       {
-        "id": 2764,
-        "logprob": -0.37573242,
+        "id": 395,
+        "logprob": -0.46411133,
         "special": false,
-        "text": "umber"
+        "text": "and"
       },
       {
-        "id": 285,
-        "logprob": 0.0,
+        "id": 13735,
+        "logprob": -2.1132812,
         "special": false,
-        "text": " and"
+        "text": " orange"
       },
       {
-        "id": 3168,
-        "logprob": -0.9013672,
+        "id": 313,
+        "logprob": -1.2128906,
         "special": false,
-        "text": " white"
+        "text": " ("
       },
       {
-        "id": 28,
-        "logprob": -1.2314453,
+        "id": 249,
+        "logprob": -2.3671875,
         "special": false,
-        "text": ";"
+        "text": "in"
       },
       {
         "id": 253,
         "logprob": 0.0,
         "special": false,
         "text": " the"
       },
       {
-        "id": 3295,
"logprob": -1.2167969, + "id": 1340, + "logprob": -1.640625, "special": false, - "text": " color" + "text": " order" }, { - "id": 273, + "id": 597, + "logprob": -0.5488281, + "special": false, + "text": " they" + }, + { + "id": 3176, + "logprob": -0.48608398, + "special": false, + "text": " appear" + }, + { + "id": 275, "logprob": 0.0, "special": false, - "text": " of" - }, - { - "id": 697, - "logprob": -2.1015625, - "special": false, - "text": " its" - }, - { - "id": 17433, - "logprob": -2.4296875, - "special": false, - "text": " unders" + "text": " in" } ], "top_tokens": null }, - "generated_text": "blue, red, yellow, \number and white; the color of its unders" + "generated_text": "blue, red, yellow, \nand orange (in the order they appear in" } diff --git a/integration-tests/models/test_fused_kernel_mamba.py b/integration-tests/models/test_fused_kernel_mamba.py index 0a449332..98431298 100644 --- a/integration-tests/models/test_fused_kernel_mamba.py +++ b/integration-tests/models/test_fused_kernel_mamba.py @@ -44,17 +44,14 @@ async def test_fused_kernel_mamba_all_params(fused_kernel_mamba, response_snapsh ) assert response.details.generated_tokens == 10 - # TODO: fix so the input is not included in the output - assert response.generated_text == "blue, red, yellow, \number and white; the color of its unders" + assert response.generated_text == "blue, red, yellow, \nand orange (in the order they appear in" assert response == response_snapshot -# TODO: fix `Expected x0.dim() == 2 to be true, but got false.` -# 94: `hidden_states, _ = self.layer_norm(hidden_states.squeeze(0))` -# NOTE: the fast layer norm has strict requirements on the input shape +# TODO: Fix batching # @pytest.mark.asyncio # @pytest.mark.private # async def test_fused_kernel_mamba_load(fused_kernel_mamba, generate_load, response_snapshot): -# responses = await generate_load(fused_kernel_mamba, "Test request", max_new_tokens=10, n=4) +# responses = await generate_load(fused_kernel_mamba, "Test request", max_new_tokens=10, n=2) # assert len(responses) == 4 # assert all([r.generated_text == responses[0].generated_text for r in responses]) diff --git a/server/text_generation_server/models/mamba.py b/server/text_generation_server/models/mamba.py index f7d950e7..6c0568c3 100644 --- a/server/text_generation_server/models/mamba.py +++ b/server/text_generation_server/models/mamba.py @@ -34,7 +34,6 @@ class MambaBatch(Batch): # Decoder values input_ids: torch.Tensor - past_input_ids: Optional[torch.Tensor] # All tokens all_input_ids: List[torch.Tensor] @@ -132,7 +131,7 @@ class MambaBatch(Batch): requests=pb.requests, requests_idx_mapping=requests_idx_mapping, input_ids=input_ids, - past_input_ids=None, + # past_input_ids=None, all_input_ids=list(all_input_ids), input_lengths=input_lengths.tolist(), prefix_offsets=prefix_offsets, @@ -198,7 +197,6 @@ class MambaBatch(Batch): # Apply indices to input_ids, attention mask, past key values and other items that need to be cached input_ids = self.input_ids[keep_indices] - position_ids = self.position_ids[keep_indices] top_n_tokens_tensor = self.top_n_tokens_tensor[keep_indices] max_tokens = len(request_ids) * max_input_length + total_remaining_decode_tokens @@ -245,9 +243,6 @@ class MambaBatch(Batch): # Batch tensors input_ids = None - attention_mask = None - position_ids = None - past_key_values = [] top_n_tokens_tensor = None # Used for slicing correctly inside the tensors @@ -273,10 +268,6 @@ class MambaBatch(Batch): # Slicing end index for this batch end_index = start_index + 

-            # We only concatenate batches that did at least one step
-            if batch.past_key_values is None:
-                raise ValueError("only concatenate prefilled batches")
-
             # Create empty tensor
             # input_ids is always of shape [batch_size, 1]
             # We do not need to pad it
@@ -285,12 +276,6 @@ class MambaBatch(Batch):
             # Copy to correct indices
             input_ids[start_index:end_index] = batch.input_ids

-            # Create padded tensor
-            if attention_mask is None:
-                attention_mask = batch.attention_mask.new_zeros(
-                    (total_batch_size, max_input_length + padding_right_offset),
-                )
-
             if top_n_tokens_tensor is None:
                 top_n_tokens_tensor = batches[0].top_n_tokens_tensor.new_zeros(
                     total_batch_size,
@@ -309,9 +294,6 @@ class MambaBatch(Batch):
             requests=requests,
             requests_idx_mapping=requests_idx_mapping,
             input_ids=input_ids,
-            attention_mask=attention_mask,
-            position_ids=position_ids,
-            past_key_values=past_key_values,
             all_input_ids=all_input_ids,
             input_lengths=input_lengths,
             prefix_offsets=prefix_offsets,
@@ -555,8 +537,6 @@ class Mamba(Model):
                 )
             else:
                 prefill_tokens = None

-            past_input_ids = torch.cat([past_input_ids, next_token_id], dim=1)
-
             if top_n_tokens > 0:
                 toptoken_texts = self.tokenizer.batch_decode(
@@ -608,9 +588,6 @@ class Mamba(Model):

         # Slice unused values from prefill
         batch.input_ids = batch.input_ids[:, :1]
-
-        batch.past_input_ids = past_input_ids
-
         forward_ns = start_decode - start
         decode_ns = time.time_ns() - start_decode
         return generations, batch, (forward_ns, decode_ns)
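
Reviewer note (not part of the patch): the revised single-request behaviour asserted above can be checked by hand against a running server. The sketch below is illustrative only; it assumes a text-generation-inference instance already serving the Mamba model on a placeholder local URL and uses the `text_generation` Python client instead of the pytest fixtures.

# Hypothetical manual check mirroring the non-batched integration test.
# Assumes a TGI server with the Mamba model is listening on the URL below.
import asyncio

from text_generation import AsyncClient


async def main() -> None:
    client = AsyncClient("http://127.0.0.1:8080")
    response = await client.generate("Test request", max_new_tokens=10)
    # The revised test expects exactly 10 generated tokens and a
    # generated_text that no longer echoes the prompt.
    assert response.details.generated_tokens == 10
    print(repr(response.generated_text))


if __name__ == "__main__":
    asyncio.run(main())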