Fix mamba load.

This commit is contained in:
Nicolas Patry 2024-02-06 18:57:24 +00:00
parent 3caa9b9cb7
commit 8319e854c8
3 changed files with 257 additions and 131 deletions

View File

@ -6,82 +6,97 @@
"generated_tokens": 10,
"prefill": [
{
"id": 5089,
"id": 1276,
"logprob": null,
"text": "Test"
"text": "What"
},
{
"id": 2748,
"logprob": -9.7265625,
"text": " request"
"id": 310,
"logprob": -0.8125,
"text": " is"
},
{
"id": 18147,
"logprob": -12.828125,
"text": " Deep"
},
{
"id": 20727,
"logprob": -3.0,
"text": " Learning"
},
{
"id": 32,
"logprob": -1.1484375,
"text": "?"
}
],
"seed": null,
"tokens": [
{
"id": 187,
"logprob": -2.4746094,
"logprob": -0.3552246,
"special": false,
"text": "\n"
},
{
"id": 187,
"logprob": -1.3857422,
"logprob": -0.38989258,
"special": false,
"text": "\n"
},
{
"id": 510,
"logprob": -2.703125,
"id": 30763,
"logprob": -1.1386719,
"special": false,
"text": "The"
"text": "Deep"
},
{
"id": 806,
"logprob": -4.1992188,
"id": 4715,
"logprob": -0.5576172,
"special": false,
"text": " first"
"text": " learning"
},
{
"id": 2181,
"logprob": -2.703125,
"id": 310,
"logprob": -0.5913086,
"special": false,
"text": " thing"
"text": " is"
},
{
"id": 309,
"logprob": -1.4160156,
"id": 247,
"logprob": -0.69970703,
"special": false,
"text": " I"
"text": " a"
},
{
"id": 8344,
"logprob": -1.6171875,
"id": 747,
"logprob": -2.0449219,
"special": false,
"text": " noticed"
"text": " new"
},
{
"id": 369,
"logprob": -1.0039062,
"id": 1511,
"logprob": -2.3847656,
"special": false,
"text": " was"
"text": " type"
},
{
"id": 326,
"logprob": -0.8823242,
"id": 273,
"logprob": -0.0026626587,
"special": false,
"text": " that"
"text": " of"
},
{
"id": 253,
"logprob": -1.3173828,
"id": 5145,
"logprob": -1.2841797,
"special": false,
"text": " the"
"text": " machine"
}
],
"top_tokens": null
},
"generated_text": "\n\nThe first thing I noticed was that the"
"generated_text": "\n\nDeep learning is a new type of machine"
},
{
"details": {
@ -90,82 +105,97 @@
"generated_tokens": 10,
"prefill": [
{
"id": 5089,
"id": 1276,
"logprob": null,
"text": "Test"
"text": "What"
},
{
"id": 2748,
"logprob": -9.7265625,
"text": " request"
"id": 310,
"logprob": -0.78027344,
"text": " is"
},
{
"id": 18147,
"logprob": -12.8203125,
"text": " Deep"
},
{
"id": 20727,
"logprob": -2.9902344,
"text": " Learning"
},
{
"id": 32,
"logprob": -1.1523438,
"text": "?"
}
],
"seed": null,
"tokens": [
{
"id": 187,
"logprob": -2.4941406,
"logprob": -0.35351562,
"special": false,
"text": "\n"
},
{
"id": 187,
"logprob": -1.3857422,
"logprob": -0.38476562,
"special": false,
"text": "\n"
},
{
"id": 510,
"logprob": -2.703125,
"id": 30763,
"logprob": -1.1308594,
"special": false,
"text": "The"
"text": "Deep"
},
{
"id": 806,
"logprob": -4.1992188,
"id": 4715,
"logprob": -0.5522461,
"special": false,
"text": " first"
"text": " learning"
},
{
"id": 2181,
"logprob": -2.703125,
"id": 310,
"logprob": -0.59375,
"special": false,
"text": " thing"
"text": " is"
},
{
"id": 309,
"logprob": -1.4160156,
"id": 247,
"logprob": -0.7036133,
"special": false,
"text": " I"
"text": " a"
},
{
"id": 8344,
"logprob": -1.6171875,
"id": 747,
"logprob": -2.0507812,
"special": false,
"text": " noticed"
"text": " new"
},
{
"id": 369,
"logprob": -1.0039062,
"id": 1511,
"logprob": -2.3808594,
"special": false,
"text": " was"
"text": " type"
},
{
"id": 326,
"logprob": -0.8823242,
"id": 273,
"logprob": -0.002664566,
"special": false,
"text": " that"
"text": " of"
},
{
"id": 253,
"logprob": -1.3173828,
"id": 5145,
"logprob": -1.2851562,
"special": false,
"text": " the"
"text": " machine"
}
],
"top_tokens": null
},
"generated_text": "\n\nThe first thing I noticed was that the"
"generated_text": "\n\nDeep learning is a new type of machine"
},
{
"details": {
@ -174,82 +204,97 @@
"generated_tokens": 10,
"prefill": [
{
"id": 5089,
"id": 1276,
"logprob": null,
"text": "Test"
"text": "What"
},
{
"id": 2748,
"logprob": -9.7265625,
"text": " request"
"id": 310,
"logprob": -0.78027344,
"text": " is"
},
{
"id": 18147,
"logprob": -12.8203125,
"text": " Deep"
},
{
"id": 20727,
"logprob": -2.9902344,
"text": " Learning"
},
{
"id": 32,
"logprob": -1.1523438,
"text": "?"
}
],
"seed": null,
"tokens": [
{
"id": 187,
"logprob": -2.4941406,
"logprob": -0.35351562,
"special": false,
"text": "\n"
},
{
"id": 187,
"logprob": -1.3857422,
"logprob": -0.38476562,
"special": false,
"text": "\n"
},
{
"id": 510,
"logprob": -2.703125,
"id": 30763,
"logprob": -1.1308594,
"special": false,
"text": "The"
"text": "Deep"
},
{
"id": 806,
"logprob": -4.1992188,
"id": 4715,
"logprob": -0.5522461,
"special": false,
"text": " first"
"text": " learning"
},
{
"id": 2181,
"logprob": -2.703125,
"id": 310,
"logprob": -0.59375,
"special": false,
"text": " thing"
"text": " is"
},
{
"id": 309,
"logprob": -1.4160156,
"id": 247,
"logprob": -0.7036133,
"special": false,
"text": " I"
"text": " a"
},
{
"id": 8344,
"logprob": -1.6171875,
"id": 747,
"logprob": -2.0507812,
"special": false,
"text": " noticed"
"text": " new"
},
{
"id": 369,
"logprob": -1.0039062,
"id": 1511,
"logprob": -2.3808594,
"special": false,
"text": " was"
"text": " type"
},
{
"id": 326,
"logprob": -0.8823242,
"id": 273,
"logprob": -0.002664566,
"special": false,
"text": " that"
"text": " of"
},
{
"id": 253,
"logprob": -1.3173828,
"id": 5145,
"logprob": -1.2851562,
"special": false,
"text": " the"
"text": " machine"
}
],
"top_tokens": null
},
"generated_text": "\n\nThe first thing I noticed was that the"
"generated_text": "\n\nDeep learning is a new type of machine"
},
{
"details": {
@ -258,81 +303,96 @@
"generated_tokens": 10,
"prefill": [
{
"id": 5089,
"id": 1276,
"logprob": null,
"text": "Test"
"text": "What"
},
{
"id": 2748,
"logprob": -9.7265625,
"text": " request"
"id": 310,
"logprob": -0.78027344,
"text": " is"
},
{
"id": 18147,
"logprob": -12.8203125,
"text": " Deep"
},
{
"id": 20727,
"logprob": -2.9902344,
"text": " Learning"
},
{
"id": 32,
"logprob": -1.1523438,
"text": "?"
}
],
"seed": null,
"tokens": [
{
"id": 187,
"logprob": -2.4941406,
"logprob": -0.35351562,
"special": false,
"text": "\n"
},
{
"id": 187,
"logprob": -1.3857422,
"logprob": -0.38476562,
"special": false,
"text": "\n"
},
{
"id": 510,
"logprob": -2.703125,
"id": 30763,
"logprob": -1.1308594,
"special": false,
"text": "The"
"text": "Deep"
},
{
"id": 806,
"logprob": -4.1992188,
"id": 4715,
"logprob": -0.5522461,
"special": false,
"text": " first"
"text": " learning"
},
{
"id": 2181,
"logprob": -2.703125,
"id": 310,
"logprob": -0.59375,
"special": false,
"text": " thing"
"text": " is"
},
{
"id": 309,
"logprob": -1.4160156,
"id": 247,
"logprob": -0.7036133,
"special": false,
"text": " I"
"text": " a"
},
{
"id": 8344,
"logprob": -1.6171875,
"id": 747,
"logprob": -2.0507812,
"special": false,
"text": " noticed"
"text": " new"
},
{
"id": 369,
"logprob": -1.0039062,
"id": 1511,
"logprob": -2.3808594,
"special": false,
"text": " was"
"text": " type"
},
{
"id": 326,
"logprob": -0.8823242,
"id": 273,
"logprob": -0.002664566,
"special": false,
"text": " that"
"text": " of"
},
{
"id": 253,
"logprob": -1.3173828,
"id": 5145,
"logprob": -1.2851562,
"special": false,
"text": " the"
"text": " machine"
}
],
"top_tokens": null
},
"generated_text": "\n\nThe first thing I noticed was that the"
"generated_text": "\n\nDeep learning is a new type of machine"
}
]

View File

@ -17,10 +17,11 @@ async def fused_kernel_mamba(fused_kernel_mamba_handle):
@pytest.mark.private
async def test_fused_kernel_mamba(fused_kernel_mamba, response_snapshot):
    """Greedy single-request generation on the fused-kernel Mamba model.

    Greedy decoding is deterministic, so the exact continuation is pinned
    in addition to the snapshot comparison.
    """
    # The rendered diff had left the stale pre-change call
    # ("Test request", decoder_input_details=True) interleaved here;
    # only the post-change call belongs in the function.
    response = await fused_kernel_mamba.generate(
        "What is Deep Learning?", max_new_tokens=10
    )

    assert response.details.generated_tokens == 10
    assert response.generated_text == "\n\nDeep learning is a new type of machine"
    assert response == response_snapshot
@ -50,9 +51,10 @@ async def test_fused_kernel_mamba_all_params(fused_kernel_mamba, response_snapsh
@pytest.mark.asyncio
@pytest.mark.private
async def test_fused_kernel_mamba_load(
    fused_kernel_mamba, generate_load, response_snapshot
):
    """Concurrent (n=4) greedy generation must be deterministic under load.

    Every request uses the same prompt, so all four responses must carry
    the identical continuation, which is also pinned exactly.
    """
    # The rendered diff had left the stale pre-change call ("Test request")
    # interleaved here; only the post-change call belongs in the function.
    responses = await generate_load(
        fused_kernel_mamba, "What is Deep Learning?", max_new_tokens=10, n=4
    )

    assert len(responses) == 4
    # All concurrent completions must agree with each other ...
    assert all([r.generated_text == responses[0].generated_text for r in responses])
    # ... and with the known greedy continuation.
    assert responses[0].generated_text == "\n\nDeep learning is a new type of machine"
    assert responses == response_snapshot

View File

@ -169,6 +169,7 @@ class MambaBatch(Batch):
total_remaining_decode_tokens = 0
new_padding_right_offset = 0
indices = []
for i, request_id in enumerate(request_ids):
idx = self.requests_idx_mapping[request_id]
requests_idx_mapping[request_id] = i
@ -182,6 +183,7 @@ class MambaBatch(Batch):
request_input_length = self.input_lengths[idx]
input_lengths.append(request_input_length)
max_input_length = max(max_input_length, request_input_length)
indices.append(idx)
next_token_choosers.append(self.next_token_choosers[idx])
stopping_criteria = self.stopping_criterias[idx]
@ -216,6 +218,13 @@ class MambaBatch(Batch):
self.padding_right_offset = new_padding_right_offset
self.max_tokens = max_tokens
# TODO
# Kept it simple by just updating the state, maybe updating the other CPU values is necessary.
key_value_memory_dict = {}
for i, (conv_state, ssm_state) in self.inference_params.key_value_memory_dict.items():
key_value_memory_dict[i] = (conv_state[indices], ssm_state[indices])
self.inference_params.key_value_memory_dict = key_value_memory_dict
return self
@classmethod
@ -240,6 +249,9 @@ class MambaBatch(Batch):
stopping_criterias = []
top_n_tokens = []
max_tokens = 0
max_seqlen = 0
batch_size = 0
seqlen_offset = 0
# Batch tensors
input_ids = None
@ -287,8 +299,60 @@ class MambaBatch(Batch):
max_input_length - batch.max_input_length
) * len(batch)
max_seqlen = max(max_seqlen, batch.inference_params.max_seqlen)
seqlen_offset = max(seqlen_offset, batch.inference_params.seqlen_offset)
batch_size += batch.inference_params.max_batch_size
start_index = end_index
(_, d_model, d_conv) = batches[0].inference_params.key_value_memory_dict[0][0].shape
(_, _, d_state) = batches[0].inference_params.key_value_memory_dict[0][1].shape
n_blocks = len(batches[0].inference_params.key_value_memory_dict)
dtype = batches[0].inference_params.key_value_memory_dict[0][0].dtype
device = batches[0].inference_params.key_value_memory_dict[0][0].device
key_value_memory_dict = {}
for i in range(n_blocks):
conv_state = torch.zeros(
batch_size,
d_model,
d_conv,
device=device,
dtype=dtype,
)
ssm_state = torch.zeros(
batch_size,
d_model,
d_state,
device=device,
dtype=dtype,
)
key_value_memory_dict[i] = (conv_state, ssm_state)
lengths_per_sample = torch.zeros(batch_size, dtype=torch.int32, device=device)
inference_params = InferenceParams(
max_seqlen=max_seqlen,
max_batch_size=batch_size,
seqlen_offset=seqlen_offset,
key_value_memory_dict=key_value_memory_dict,
lengths_per_sample=lengths_per_sample,
)
current_batch = 0
for batch in batches:
for i in range(n_blocks):
conv_state, ssm_state = batch.inference_params.key_value_memory_dict[i]
batch_size = batch.inference_params.max_batch_size
try:
inference_params.key_value_memory_dict[i][0][current_batch:current_batch + batch_size] = conv_state
inference_params.key_value_memory_dict[i][1][current_batch:current_batch + batch_size] = ssm_state
except Exception:
import ipdb;ipdb.set_trace()
pass
inference_params.lengths_per_sample[current_batch: current_batch + batch_size] = batch.inference_params.lengths_per_sample
current_batch += batch_size
return cls(
batch_id=batches[0].batch_id,
requests=requests,
@ -306,6 +370,7 @@ class MambaBatch(Batch):
padding_right_offset=padding_right_offset,
keys_head_dim_last=batches[0].keys_head_dim_last,
max_tokens=max_tokens,
inference_params=inference_params
)
def __len__(self):
@ -380,7 +445,6 @@ class Mamba(Model):
def generate_token(self, batch) -> Tuple[List[Any], Optional[Any], Tuple[int, int]]:
start = time.time_ns()
input_ids = batch.input_ids # batch.past_input_ids if batch.past_input_ids is not None else batch.input_ids
batch_size = input_ids.shape[0]