From 8319e854c8122d6e7d2bf87943c8dca2a65aab63 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Tue, 6 Feb 2024 18:57:24 +0000 Subject: [PATCH] Fix mamba load. --- .../test_fused_kernel_mamba_load.json | 316 +++++++++++------- .../models/test_fused_kernel_mamba.py | 6 +- server/text_generation_server/models/mamba.py | 66 +++- 3 files changed, 257 insertions(+), 131 deletions(-) diff --git a/integration-tests/models/__snapshots__/test_fused_kernel_mamba/test_fused_kernel_mamba_load.json b/integration-tests/models/__snapshots__/test_fused_kernel_mamba/test_fused_kernel_mamba_load.json index 1d2bec30..830b9f59 100644 --- a/integration-tests/models/__snapshots__/test_fused_kernel_mamba/test_fused_kernel_mamba_load.json +++ b/integration-tests/models/__snapshots__/test_fused_kernel_mamba/test_fused_kernel_mamba_load.json @@ -6,82 +6,97 @@ "generated_tokens": 10, "prefill": [ { - "id": 5089, + "id": 1276, "logprob": null, - "text": "Test" + "text": "What" }, { - "id": 2748, - "logprob": -9.7265625, - "text": " request" + "id": 310, + "logprob": -0.8125, + "text": " is" + }, + { + "id": 18147, + "logprob": -12.828125, + "text": " Deep" + }, + { + "id": 20727, + "logprob": -3.0, + "text": " Learning" + }, + { + "id": 32, + "logprob": -1.1484375, + "text": "?" } ], "seed": null, "tokens": [ { "id": 187, - "logprob": -2.4746094, + "logprob": -0.3552246, "special": false, "text": "\n" }, { "id": 187, - "logprob": -1.3857422, + "logprob": -0.38989258, "special": false, "text": "\n" }, { - "id": 510, - "logprob": -2.703125, + "id": 30763, + "logprob": -1.1386719, "special": false, - "text": "The" + "text": "Deep" }, { - "id": 806, - "logprob": -4.1992188, + "id": 4715, + "logprob": -0.5576172, "special": false, - "text": " first" + "text": " learning" }, { - "id": 2181, - "logprob": -2.703125, + "id": 310, + "logprob": -0.5913086, "special": false, - "text": " thing" + "text": " is" }, { - "id": 309, - "logprob": -1.4160156, + "id": 247, + "logprob": -0.69970703, "special": false, - "text": " I" + "text": " a" }, { - "id": 8344, - "logprob": -1.6171875, + "id": 747, + "logprob": -2.0449219, "special": false, - "text": " noticed" + "text": " new" }, { - "id": 369, - "logprob": -1.0039062, + "id": 1511, + "logprob": -2.3847656, "special": false, - "text": " was" + "text": " type" }, { - "id": 326, - "logprob": -0.8823242, + "id": 273, + "logprob": -0.0026626587, "special": false, - "text": " that" + "text": " of" }, { - "id": 253, - "logprob": -1.3173828, + "id": 5145, + "logprob": -1.2841797, "special": false, - "text": " the" + "text": " machine" } ], "top_tokens": null }, - "generated_text": "\n\nThe first thing I noticed was that the" + "generated_text": "\n\nDeep learning is a new type of machine" }, { "details": { @@ -90,82 +105,97 @@ "generated_tokens": 10, "prefill": [ { - "id": 5089, + "id": 1276, "logprob": null, - "text": "Test" + "text": "What" }, { - "id": 2748, - "logprob": -9.7265625, - "text": " request" + "id": 310, + "logprob": -0.78027344, + "text": " is" + }, + { + "id": 18147, + "logprob": -12.8203125, + "text": " Deep" + }, + { + "id": 20727, + "logprob": -2.9902344, + "text": " Learning" + }, + { + "id": 32, + "logprob": -1.1523438, + "text": "?" 
} ], "seed": null, "tokens": [ { "id": 187, - "logprob": -2.4941406, + "logprob": -0.35351562, "special": false, "text": "\n" }, { "id": 187, - "logprob": -1.3857422, + "logprob": -0.38476562, "special": false, "text": "\n" }, { - "id": 510, - "logprob": -2.703125, + "id": 30763, + "logprob": -1.1308594, "special": false, - "text": "The" + "text": "Deep" }, { - "id": 806, - "logprob": -4.1992188, + "id": 4715, + "logprob": -0.5522461, "special": false, - "text": " first" + "text": " learning" }, { - "id": 2181, - "logprob": -2.703125, + "id": 310, + "logprob": -0.59375, "special": false, - "text": " thing" + "text": " is" }, { - "id": 309, - "logprob": -1.4160156, + "id": 247, + "logprob": -0.7036133, "special": false, - "text": " I" + "text": " a" }, { - "id": 8344, - "logprob": -1.6171875, + "id": 747, + "logprob": -2.0507812, "special": false, - "text": " noticed" + "text": " new" }, { - "id": 369, - "logprob": -1.0039062, + "id": 1511, + "logprob": -2.3808594, "special": false, - "text": " was" + "text": " type" }, { - "id": 326, - "logprob": -0.8823242, + "id": 273, + "logprob": -0.002664566, "special": false, - "text": " that" + "text": " of" }, { - "id": 253, - "logprob": -1.3173828, + "id": 5145, + "logprob": -1.2851562, "special": false, - "text": " the" + "text": " machine" } ], "top_tokens": null }, - "generated_text": "\n\nThe first thing I noticed was that the" + "generated_text": "\n\nDeep learning is a new type of machine" }, { "details": { @@ -174,82 +204,97 @@ "generated_tokens": 10, "prefill": [ { - "id": 5089, + "id": 1276, "logprob": null, - "text": "Test" + "text": "What" }, { - "id": 2748, - "logprob": -9.7265625, - "text": " request" + "id": 310, + "logprob": -0.78027344, + "text": " is" + }, + { + "id": 18147, + "logprob": -12.8203125, + "text": " Deep" + }, + { + "id": 20727, + "logprob": -2.9902344, + "text": " Learning" + }, + { + "id": 32, + "logprob": -1.1523438, + "text": "?" 
} ], "seed": null, "tokens": [ { "id": 187, - "logprob": -2.4941406, + "logprob": -0.35351562, "special": false, "text": "\n" }, { "id": 187, - "logprob": -1.3857422, + "logprob": -0.38476562, "special": false, "text": "\n" }, { - "id": 510, - "logprob": -2.703125, + "id": 30763, + "logprob": -1.1308594, "special": false, - "text": "The" + "text": "Deep" }, { - "id": 806, - "logprob": -4.1992188, + "id": 4715, + "logprob": -0.5522461, "special": false, - "text": " first" + "text": " learning" }, { - "id": 2181, - "logprob": -2.703125, + "id": 310, + "logprob": -0.59375, "special": false, - "text": " thing" + "text": " is" }, { - "id": 309, - "logprob": -1.4160156, + "id": 247, + "logprob": -0.7036133, "special": false, - "text": " I" + "text": " a" }, { - "id": 8344, - "logprob": -1.6171875, + "id": 747, + "logprob": -2.0507812, "special": false, - "text": " noticed" + "text": " new" }, { - "id": 369, - "logprob": -1.0039062, + "id": 1511, + "logprob": -2.3808594, "special": false, - "text": " was" + "text": " type" }, { - "id": 326, - "logprob": -0.8823242, + "id": 273, + "logprob": -0.002664566, "special": false, - "text": " that" + "text": " of" }, { - "id": 253, - "logprob": -1.3173828, + "id": 5145, + "logprob": -1.2851562, "special": false, - "text": " the" + "text": " machine" } ], "top_tokens": null }, - "generated_text": "\n\nThe first thing I noticed was that the" + "generated_text": "\n\nDeep learning is a new type of machine" }, { "details": { @@ -258,81 +303,96 @@ "generated_tokens": 10, "prefill": [ { - "id": 5089, + "id": 1276, "logprob": null, - "text": "Test" + "text": "What" }, { - "id": 2748, - "logprob": -9.7265625, - "text": " request" + "id": 310, + "logprob": -0.78027344, + "text": " is" + }, + { + "id": 18147, + "logprob": -12.8203125, + "text": " Deep" + }, + { + "id": 20727, + "logprob": -2.9902344, + "text": " Learning" + }, + { + "id": 32, + "logprob": -1.1523438, + "text": "?" 
} ], "seed": null, "tokens": [ { "id": 187, - "logprob": -2.4941406, + "logprob": -0.35351562, "special": false, "text": "\n" }, { "id": 187, - "logprob": -1.3857422, + "logprob": -0.38476562, "special": false, "text": "\n" }, { - "id": 510, - "logprob": -2.703125, + "id": 30763, + "logprob": -1.1308594, "special": false, - "text": "The" + "text": "Deep" }, { - "id": 806, - "logprob": -4.1992188, + "id": 4715, + "logprob": -0.5522461, "special": false, - "text": " first" + "text": " learning" }, { - "id": 2181, - "logprob": -2.703125, + "id": 310, + "logprob": -0.59375, "special": false, - "text": " thing" + "text": " is" }, { - "id": 309, - "logprob": -1.4160156, + "id": 247, + "logprob": -0.7036133, "special": false, - "text": " I" + "text": " a" }, { - "id": 8344, - "logprob": -1.6171875, + "id": 747, + "logprob": -2.0507812, "special": false, - "text": " noticed" + "text": " new" }, { - "id": 369, - "logprob": -1.0039062, + "id": 1511, + "logprob": -2.3808594, "special": false, - "text": " was" + "text": " type" }, { - "id": 326, - "logprob": -0.8823242, + "id": 273, + "logprob": -0.002664566, "special": false, - "text": " that" + "text": " of" }, { - "id": 253, - "logprob": -1.3173828, + "id": 5145, + "logprob": -1.2851562, "special": false, - "text": " the" + "text": " machine" } ], "top_tokens": null }, - "generated_text": "\n\nThe first thing I noticed was that the" + "generated_text": "\n\nDeep learning is a new type of machine" } ] diff --git a/integration-tests/models/test_fused_kernel_mamba.py b/integration-tests/models/test_fused_kernel_mamba.py index 327113b2..9bd0052f 100644 --- a/integration-tests/models/test_fused_kernel_mamba.py +++ b/integration-tests/models/test_fused_kernel_mamba.py @@ -17,10 +17,11 @@ async def fused_kernel_mamba(fused_kernel_mamba_handle): @pytest.mark.private async def test_fused_kernel_mamba(fused_kernel_mamba, response_snapshot): response = await fused_kernel_mamba.generate( - "Test request", max_new_tokens=10, decoder_input_details=True + "What is Deep Learning?", max_new_tokens=10 ) assert response.details.generated_tokens == 10 + assert response.generated_text == "\n\nDeep learning is a new type of machine" assert response == response_snapshot @@ -50,9 +51,10 @@ async def test_fused_kernel_mamba_all_params(fused_kernel_mamba, response_snapsh @pytest.mark.asyncio @pytest.mark.private async def test_fused_kernel_mamba_load(fused_kernel_mamba, generate_load, response_snapshot): - responses = await generate_load(fused_kernel_mamba, "Test request", max_new_tokens=10, n=4) + responses = await generate_load(fused_kernel_mamba, "What is Deep Learning?", max_new_tokens=10, n=4) assert len(responses) == 4 assert all([r.generated_text == responses[0].generated_text for r in responses]) + assert responses[0].generated_text == "\n\nDeep learning is a new type of machine" assert responses == response_snapshot diff --git a/server/text_generation_server/models/mamba.py b/server/text_generation_server/models/mamba.py index ca3e6e92..4750d90a 100644 --- a/server/text_generation_server/models/mamba.py +++ b/server/text_generation_server/models/mamba.py @@ -169,6 +169,7 @@ class MambaBatch(Batch): total_remaining_decode_tokens = 0 new_padding_right_offset = 0 + indices = [] for i, request_id in enumerate(request_ids): idx = self.requests_idx_mapping[request_id] requests_idx_mapping[request_id] = i @@ -182,6 +183,7 @@ class MambaBatch(Batch): request_input_length = self.input_lengths[idx] input_lengths.append(request_input_length) max_input_length = 
max(max_input_length, request_input_length)
+            indices.append(idx)
 
             next_token_choosers.append(self.next_token_choosers[idx])
             stopping_criteria = self.stopping_criterias[idx]
@@ -216,6 +218,13 @@ class MambaBatch(Batch):
         self.padding_right_offset = new_padding_right_offset
         self.max_tokens = max_tokens
 
+        # TODO: kept simple by only reindexing the conv/ssm states here;
+        # the other CPU-side values in inference_params may also need to be updated.
+        key_value_memory_dict = {}
+        for i, (conv_state, ssm_state) in self.inference_params.key_value_memory_dict.items():
+            key_value_memory_dict[i] = (conv_state[indices], ssm_state[indices])
+        self.inference_params.key_value_memory_dict = key_value_memory_dict
+
         return self
 
     @classmethod
@@ -240,6 +249,9 @@ class MambaBatch(Batch):
         stopping_criterias = []
         top_n_tokens = []
         max_tokens = 0
+        max_seqlen = 0
+        batch_size = 0
+        seqlen_offset = 0
 
         # Batch tensors
         input_ids = None
@@ -287,8 +299,55 @@ class MambaBatch(Batch):
                     max_input_length - batch.max_input_length
                 ) * len(batch)
 
+            max_seqlen = max(max_seqlen, batch.inference_params.max_seqlen)
+            seqlen_offset = max(seqlen_offset, batch.inference_params.seqlen_offset)
+            batch_size += batch.inference_params.max_batch_size
+
             start_index = end_index
 
+        (_, d_model, d_conv) = batches[0].inference_params.key_value_memory_dict[0][0].shape
+        (_, _, d_state) = batches[0].inference_params.key_value_memory_dict[0][1].shape
+        n_blocks = len(batches[0].inference_params.key_value_memory_dict)
+        dtype = batches[0].inference_params.key_value_memory_dict[0][0].dtype
+        device = batches[0].inference_params.key_value_memory_dict[0][0].device
+
+        key_value_memory_dict = {}
+        for i in range(n_blocks):
+            conv_state = torch.zeros(
+                batch_size,
+                d_model,
+                d_conv,
+                device=device,
+                dtype=dtype,
+            )
+            ssm_state = torch.zeros(
+                batch_size,
+                d_model,
+                d_state,
+                device=device,
+                dtype=dtype,
+            )
+            key_value_memory_dict[i] = (conv_state, ssm_state)
+        lengths_per_sample = torch.zeros(batch_size, dtype=torch.int32, device=device)
+
+        inference_params = InferenceParams(
+            max_seqlen=max_seqlen,
+            max_batch_size=batch_size,
+            seqlen_offset=seqlen_offset,
+            key_value_memory_dict=key_value_memory_dict,
+            lengths_per_sample=lengths_per_sample,
+        )
+
+        # Copy each batch's states into its row range of the merged tensors.
+        current_batch = 0
+        for batch in batches:
+            for i in range(n_blocks):
+                conv_state, ssm_state = batch.inference_params.key_value_memory_dict[i]
+                batch_size = batch.inference_params.max_batch_size
+                inference_params.key_value_memory_dict[i][0][current_batch:current_batch + batch_size] = conv_state
+                inference_params.key_value_memory_dict[i][1][current_batch:current_batch + batch_size] = ssm_state
+            inference_params.lengths_per_sample[current_batch:current_batch + batch_size] = batch.inference_params.lengths_per_sample
+            current_batch += batch_size
+
         return cls(
             batch_id=batches[0].batch_id,
             requests=requests,
@@ -306,6 +370,7 @@ class MambaBatch(Batch):
             padding_right_offset=padding_right_offset,
             keys_head_dim_last=batches[0].keys_head_dim_last,
             max_tokens=max_tokens,
+            inference_params=inference_params
         )
 
     def __len__(self):
@@ -380,7 +445,6 @@ class Mamba(Model):
 
     def generate_token(self, batch) -> Tuple[List[Any], Optional[Any], Tuple[int, int]]:
         start = time.time_ns()
-
         input_ids = batch.input_ids # batch.past_input_ids if batch.past_input_ids is not None else batch.input_ids
 
         batch_size = input_ids.shape[0]
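---

Reviewer note: below is a minimal, standalone sketch of the state bookkeeping that the filter()/concatenate() changes implement. The {layer: (conv_state, ssm_state)} layout of key_value_memory_dict matches the patch; the tensor sizes are toy values invented purely for illustration, not the model's real dimensions.

import torch

# Toy dimensions, invented purely for illustration.
n_blocks, d_model, d_conv, d_state = 2, 8, 4, 16

def make_states(batch_size):
    # One (conv_state, ssm_state) pair per layer, keyed like
    # InferenceParams.key_value_memory_dict in the patch.
    return {
        i: (
            torch.randn(batch_size, d_model, d_conv),
            torch.randn(batch_size, d_model, d_state),
        )
        for i in range(n_blocks)
    }

# filter(): keep only the state rows belonging to surviving requests.
states = make_states(4)
indices = [0, 2]  # requests that remain after filtering
filtered = {i: (conv[indices], ssm[indices]) for i, (conv, ssm) in states.items()}
assert filtered[0][0].shape == (2, d_model, d_conv)

# concatenate(): allocate merged tensors, then copy each batch into its row range.
batches = [filtered, make_states(3)]
total = sum(b[0][0].shape[0] for b in batches)
merged = {
    i: (torch.zeros(total, d_model, d_conv), torch.zeros(total, d_model, d_state))
    for i in range(n_blocks)
}
offset = 0
for b in batches:
    bs = b[0][0].shape[0]
    for i in range(n_blocks):
        merged[i][0][offset:offset + bs] = b[i][0]
        merged[i][1][offset:offset + bs] = b[i][1]
    offset += bs
assert merged[0][0].shape == (5, d_model, d_conv)

Pre-allocating the merged tensors once and copying row ranges keeps each request's recurrent state aligned with its row in the batched input, which is what the n=4 load test above exercises.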