mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 04:14:52 +00:00
Fix mamba load.
This commit is contained in:
parent
3caa9b9cb7
commit
8319e854c8
@ -6,82 +6,97 @@
|
||||
"generated_tokens": 10,
|
||||
"prefill": [
|
||||
{
|
||||
"id": 5089,
|
||||
"id": 1276,
|
||||
"logprob": null,
|
||||
"text": "Test"
|
||||
"text": "What"
|
||||
},
|
||||
{
|
||||
"id": 2748,
|
||||
"logprob": -9.7265625,
|
||||
"text": " request"
|
||||
"id": 310,
|
||||
"logprob": -0.8125,
|
||||
"text": " is"
|
||||
},
|
||||
{
|
||||
"id": 18147,
|
||||
"logprob": -12.828125,
|
||||
"text": " Deep"
|
||||
},
|
||||
{
|
||||
"id": 20727,
|
||||
"logprob": -3.0,
|
||||
"text": " Learning"
|
||||
},
|
||||
{
|
||||
"id": 32,
|
||||
"logprob": -1.1484375,
|
||||
"text": "?"
|
||||
}
|
||||
],
|
||||
"seed": null,
|
||||
"tokens": [
|
||||
{
|
||||
"id": 187,
|
||||
"logprob": -2.4746094,
|
||||
"logprob": -0.3552246,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 187,
|
||||
"logprob": -1.3857422,
|
||||
"logprob": -0.38989258,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 510,
|
||||
"logprob": -2.703125,
|
||||
"id": 30763,
|
||||
"logprob": -1.1386719,
|
||||
"special": false,
|
||||
"text": "The"
|
||||
"text": "Deep"
|
||||
},
|
||||
{
|
||||
"id": 806,
|
||||
"logprob": -4.1992188,
|
||||
"id": 4715,
|
||||
"logprob": -0.5576172,
|
||||
"special": false,
|
||||
"text": " first"
|
||||
"text": " learning"
|
||||
},
|
||||
{
|
||||
"id": 2181,
|
||||
"logprob": -2.703125,
|
||||
"id": 310,
|
||||
"logprob": -0.5913086,
|
||||
"special": false,
|
||||
"text": " thing"
|
||||
"text": " is"
|
||||
},
|
||||
{
|
||||
"id": 309,
|
||||
"logprob": -1.4160156,
|
||||
"id": 247,
|
||||
"logprob": -0.69970703,
|
||||
"special": false,
|
||||
"text": " I"
|
||||
"text": " a"
|
||||
},
|
||||
{
|
||||
"id": 8344,
|
||||
"logprob": -1.6171875,
|
||||
"id": 747,
|
||||
"logprob": -2.0449219,
|
||||
"special": false,
|
||||
"text": " noticed"
|
||||
"text": " new"
|
||||
},
|
||||
{
|
||||
"id": 369,
|
||||
"logprob": -1.0039062,
|
||||
"id": 1511,
|
||||
"logprob": -2.3847656,
|
||||
"special": false,
|
||||
"text": " was"
|
||||
"text": " type"
|
||||
},
|
||||
{
|
||||
"id": 326,
|
||||
"logprob": -0.8823242,
|
||||
"id": 273,
|
||||
"logprob": -0.0026626587,
|
||||
"special": false,
|
||||
"text": " that"
|
||||
"text": " of"
|
||||
},
|
||||
{
|
||||
"id": 253,
|
||||
"logprob": -1.3173828,
|
||||
"id": 5145,
|
||||
"logprob": -1.2841797,
|
||||
"special": false,
|
||||
"text": " the"
|
||||
"text": " machine"
|
||||
}
|
||||
],
|
||||
"top_tokens": null
|
||||
},
|
||||
"generated_text": "\n\nThe first thing I noticed was that the"
|
||||
"generated_text": "\n\nDeep learning is a new type of machine"
|
||||
},
|
||||
{
|
||||
"details": {
|
||||
@ -90,82 +105,97 @@
|
||||
"generated_tokens": 10,
|
||||
"prefill": [
|
||||
{
|
||||
"id": 5089,
|
||||
"id": 1276,
|
||||
"logprob": null,
|
||||
"text": "Test"
|
||||
"text": "What"
|
||||
},
|
||||
{
|
||||
"id": 2748,
|
||||
"logprob": -9.7265625,
|
||||
"text": " request"
|
||||
"id": 310,
|
||||
"logprob": -0.78027344,
|
||||
"text": " is"
|
||||
},
|
||||
{
|
||||
"id": 18147,
|
||||
"logprob": -12.8203125,
|
||||
"text": " Deep"
|
||||
},
|
||||
{
|
||||
"id": 20727,
|
||||
"logprob": -2.9902344,
|
||||
"text": " Learning"
|
||||
},
|
||||
{
|
||||
"id": 32,
|
||||
"logprob": -1.1523438,
|
||||
"text": "?"
|
||||
}
|
||||
],
|
||||
"seed": null,
|
||||
"tokens": [
|
||||
{
|
||||
"id": 187,
|
||||
"logprob": -2.4941406,
|
||||
"logprob": -0.35351562,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 187,
|
||||
"logprob": -1.3857422,
|
||||
"logprob": -0.38476562,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 510,
|
||||
"logprob": -2.703125,
|
||||
"id": 30763,
|
||||
"logprob": -1.1308594,
|
||||
"special": false,
|
||||
"text": "The"
|
||||
"text": "Deep"
|
||||
},
|
||||
{
|
||||
"id": 806,
|
||||
"logprob": -4.1992188,
|
||||
"id": 4715,
|
||||
"logprob": -0.5522461,
|
||||
"special": false,
|
||||
"text": " first"
|
||||
"text": " learning"
|
||||
},
|
||||
{
|
||||
"id": 2181,
|
||||
"logprob": -2.703125,
|
||||
"id": 310,
|
||||
"logprob": -0.59375,
|
||||
"special": false,
|
||||
"text": " thing"
|
||||
"text": " is"
|
||||
},
|
||||
{
|
||||
"id": 309,
|
||||
"logprob": -1.4160156,
|
||||
"id": 247,
|
||||
"logprob": -0.7036133,
|
||||
"special": false,
|
||||
"text": " I"
|
||||
"text": " a"
|
||||
},
|
||||
{
|
||||
"id": 8344,
|
||||
"logprob": -1.6171875,
|
||||
"id": 747,
|
||||
"logprob": -2.0507812,
|
||||
"special": false,
|
||||
"text": " noticed"
|
||||
"text": " new"
|
||||
},
|
||||
{
|
||||
"id": 369,
|
||||
"logprob": -1.0039062,
|
||||
"id": 1511,
|
||||
"logprob": -2.3808594,
|
||||
"special": false,
|
||||
"text": " was"
|
||||
"text": " type"
|
||||
},
|
||||
{
|
||||
"id": 326,
|
||||
"logprob": -0.8823242,
|
||||
"id": 273,
|
||||
"logprob": -0.002664566,
|
||||
"special": false,
|
||||
"text": " that"
|
||||
"text": " of"
|
||||
},
|
||||
{
|
||||
"id": 253,
|
||||
"logprob": -1.3173828,
|
||||
"id": 5145,
|
||||
"logprob": -1.2851562,
|
||||
"special": false,
|
||||
"text": " the"
|
||||
"text": " machine"
|
||||
}
|
||||
],
|
||||
"top_tokens": null
|
||||
},
|
||||
"generated_text": "\n\nThe first thing I noticed was that the"
|
||||
"generated_text": "\n\nDeep learning is a new type of machine"
|
||||
},
|
||||
{
|
||||
"details": {
|
||||
@ -174,82 +204,97 @@
|
||||
"generated_tokens": 10,
|
||||
"prefill": [
|
||||
{
|
||||
"id": 5089,
|
||||
"id": 1276,
|
||||
"logprob": null,
|
||||
"text": "Test"
|
||||
"text": "What"
|
||||
},
|
||||
{
|
||||
"id": 2748,
|
||||
"logprob": -9.7265625,
|
||||
"text": " request"
|
||||
"id": 310,
|
||||
"logprob": -0.78027344,
|
||||
"text": " is"
|
||||
},
|
||||
{
|
||||
"id": 18147,
|
||||
"logprob": -12.8203125,
|
||||
"text": " Deep"
|
||||
},
|
||||
{
|
||||
"id": 20727,
|
||||
"logprob": -2.9902344,
|
||||
"text": " Learning"
|
||||
},
|
||||
{
|
||||
"id": 32,
|
||||
"logprob": -1.1523438,
|
||||
"text": "?"
|
||||
}
|
||||
],
|
||||
"seed": null,
|
||||
"tokens": [
|
||||
{
|
||||
"id": 187,
|
||||
"logprob": -2.4941406,
|
||||
"logprob": -0.35351562,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 187,
|
||||
"logprob": -1.3857422,
|
||||
"logprob": -0.38476562,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 510,
|
||||
"logprob": -2.703125,
|
||||
"id": 30763,
|
||||
"logprob": -1.1308594,
|
||||
"special": false,
|
||||
"text": "The"
|
||||
"text": "Deep"
|
||||
},
|
||||
{
|
||||
"id": 806,
|
||||
"logprob": -4.1992188,
|
||||
"id": 4715,
|
||||
"logprob": -0.5522461,
|
||||
"special": false,
|
||||
"text": " first"
|
||||
"text": " learning"
|
||||
},
|
||||
{
|
||||
"id": 2181,
|
||||
"logprob": -2.703125,
|
||||
"id": 310,
|
||||
"logprob": -0.59375,
|
||||
"special": false,
|
||||
"text": " thing"
|
||||
"text": " is"
|
||||
},
|
||||
{
|
||||
"id": 309,
|
||||
"logprob": -1.4160156,
|
||||
"id": 247,
|
||||
"logprob": -0.7036133,
|
||||
"special": false,
|
||||
"text": " I"
|
||||
"text": " a"
|
||||
},
|
||||
{
|
||||
"id": 8344,
|
||||
"logprob": -1.6171875,
|
||||
"id": 747,
|
||||
"logprob": -2.0507812,
|
||||
"special": false,
|
||||
"text": " noticed"
|
||||
"text": " new"
|
||||
},
|
||||
{
|
||||
"id": 369,
|
||||
"logprob": -1.0039062,
|
||||
"id": 1511,
|
||||
"logprob": -2.3808594,
|
||||
"special": false,
|
||||
"text": " was"
|
||||
"text": " type"
|
||||
},
|
||||
{
|
||||
"id": 326,
|
||||
"logprob": -0.8823242,
|
||||
"id": 273,
|
||||
"logprob": -0.002664566,
|
||||
"special": false,
|
||||
"text": " that"
|
||||
"text": " of"
|
||||
},
|
||||
{
|
||||
"id": 253,
|
||||
"logprob": -1.3173828,
|
||||
"id": 5145,
|
||||
"logprob": -1.2851562,
|
||||
"special": false,
|
||||
"text": " the"
|
||||
"text": " machine"
|
||||
}
|
||||
],
|
||||
"top_tokens": null
|
||||
},
|
||||
"generated_text": "\n\nThe first thing I noticed was that the"
|
||||
"generated_text": "\n\nDeep learning is a new type of machine"
|
||||
},
|
||||
{
|
||||
"details": {
|
||||
@ -258,81 +303,96 @@
|
||||
"generated_tokens": 10,
|
||||
"prefill": [
|
||||
{
|
||||
"id": 5089,
|
||||
"id": 1276,
|
||||
"logprob": null,
|
||||
"text": "Test"
|
||||
"text": "What"
|
||||
},
|
||||
{
|
||||
"id": 2748,
|
||||
"logprob": -9.7265625,
|
||||
"text": " request"
|
||||
"id": 310,
|
||||
"logprob": -0.78027344,
|
||||
"text": " is"
|
||||
},
|
||||
{
|
||||
"id": 18147,
|
||||
"logprob": -12.8203125,
|
||||
"text": " Deep"
|
||||
},
|
||||
{
|
||||
"id": 20727,
|
||||
"logprob": -2.9902344,
|
||||
"text": " Learning"
|
||||
},
|
||||
{
|
||||
"id": 32,
|
||||
"logprob": -1.1523438,
|
||||
"text": "?"
|
||||
}
|
||||
],
|
||||
"seed": null,
|
||||
"tokens": [
|
||||
{
|
||||
"id": 187,
|
||||
"logprob": -2.4941406,
|
||||
"logprob": -0.35351562,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 187,
|
||||
"logprob": -1.3857422,
|
||||
"logprob": -0.38476562,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 510,
|
||||
"logprob": -2.703125,
|
||||
"id": 30763,
|
||||
"logprob": -1.1308594,
|
||||
"special": false,
|
||||
"text": "The"
|
||||
"text": "Deep"
|
||||
},
|
||||
{
|
||||
"id": 806,
|
||||
"logprob": -4.1992188,
|
||||
"id": 4715,
|
||||
"logprob": -0.5522461,
|
||||
"special": false,
|
||||
"text": " first"
|
||||
"text": " learning"
|
||||
},
|
||||
{
|
||||
"id": 2181,
|
||||
"logprob": -2.703125,
|
||||
"id": 310,
|
||||
"logprob": -0.59375,
|
||||
"special": false,
|
||||
"text": " thing"
|
||||
"text": " is"
|
||||
},
|
||||
{
|
||||
"id": 309,
|
||||
"logprob": -1.4160156,
|
||||
"id": 247,
|
||||
"logprob": -0.7036133,
|
||||
"special": false,
|
||||
"text": " I"
|
||||
"text": " a"
|
||||
},
|
||||
{
|
||||
"id": 8344,
|
||||
"logprob": -1.6171875,
|
||||
"id": 747,
|
||||
"logprob": -2.0507812,
|
||||
"special": false,
|
||||
"text": " noticed"
|
||||
"text": " new"
|
||||
},
|
||||
{
|
||||
"id": 369,
|
||||
"logprob": -1.0039062,
|
||||
"id": 1511,
|
||||
"logprob": -2.3808594,
|
||||
"special": false,
|
||||
"text": " was"
|
||||
"text": " type"
|
||||
},
|
||||
{
|
||||
"id": 326,
|
||||
"logprob": -0.8823242,
|
||||
"id": 273,
|
||||
"logprob": -0.002664566,
|
||||
"special": false,
|
||||
"text": " that"
|
||||
"text": " of"
|
||||
},
|
||||
{
|
||||
"id": 253,
|
||||
"logprob": -1.3173828,
|
||||
"id": 5145,
|
||||
"logprob": -1.2851562,
|
||||
"special": false,
|
||||
"text": " the"
|
||||
"text": " machine"
|
||||
}
|
||||
],
|
||||
"top_tokens": null
|
||||
},
|
||||
"generated_text": "\n\nThe first thing I noticed was that the"
|
||||
"generated_text": "\n\nDeep learning is a new type of machine"
|
||||
}
|
||||
]
|
||||
|
@ -17,10 +17,11 @@ async def fused_kernel_mamba(fused_kernel_mamba_handle):
|
||||
@pytest.mark.private
|
||||
async def test_fused_kernel_mamba(fused_kernel_mamba, response_snapshot):
|
||||
response = await fused_kernel_mamba.generate(
|
||||
"Test request", max_new_tokens=10, decoder_input_details=True
|
||||
"What is Deep Learning?", max_new_tokens=10
|
||||
)
|
||||
|
||||
assert response.details.generated_tokens == 10
|
||||
assert response.generated_text == "\n\nDeep learning is a new type of machine"
|
||||
assert response == response_snapshot
|
||||
|
||||
|
||||
@ -50,9 +51,10 @@ async def test_fused_kernel_mamba_all_params(fused_kernel_mamba, response_snapsh
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.private
|
||||
async def test_fused_kernel_mamba_load(fused_kernel_mamba, generate_load, response_snapshot):
|
||||
responses = await generate_load(fused_kernel_mamba, "Test request", max_new_tokens=10, n=4)
|
||||
responses = await generate_load(fused_kernel_mamba, "What is Deep Learning?", max_new_tokens=10, n=4)
|
||||
|
||||
assert len(responses) == 4
|
||||
assert all([r.generated_text == responses[0].generated_text for r in responses])
|
||||
assert responses[0].generated_text == "\n\nDeep learning is a new type of machine"
|
||||
|
||||
assert responses == response_snapshot
|
||||
|
@ -169,6 +169,7 @@ class MambaBatch(Batch):
|
||||
total_remaining_decode_tokens = 0
|
||||
new_padding_right_offset = 0
|
||||
|
||||
indices = []
|
||||
for i, request_id in enumerate(request_ids):
|
||||
idx = self.requests_idx_mapping[request_id]
|
||||
requests_idx_mapping[request_id] = i
|
||||
@ -182,6 +183,7 @@ class MambaBatch(Batch):
|
||||
request_input_length = self.input_lengths[idx]
|
||||
input_lengths.append(request_input_length)
|
||||
max_input_length = max(max_input_length, request_input_length)
|
||||
indices.append(idx)
|
||||
|
||||
next_token_choosers.append(self.next_token_choosers[idx])
|
||||
stopping_criteria = self.stopping_criterias[idx]
|
||||
@ -216,6 +218,13 @@ class MambaBatch(Batch):
|
||||
self.padding_right_offset = new_padding_right_offset
|
||||
self.max_tokens = max_tokens
|
||||
|
||||
# TODO
|
||||
# Kept it simple by just updating the state, maybe updating the other CPU values is necessary.
|
||||
key_value_memory_dict = {}
|
||||
for i, (conv_state, ssm_state) in self.inference_params.key_value_memory_dict.items():
|
||||
key_value_memory_dict[i] = (conv_state[indices], ssm_state[indices])
|
||||
self.inference_params.key_value_memory_dict = key_value_memory_dict
|
||||
|
||||
return self
|
||||
|
||||
@classmethod
|
||||
@ -240,6 +249,9 @@ class MambaBatch(Batch):
|
||||
stopping_criterias = []
|
||||
top_n_tokens = []
|
||||
max_tokens = 0
|
||||
max_seqlen = 0
|
||||
batch_size = 0
|
||||
seqlen_offset = 0
|
||||
|
||||
# Batch tensors
|
||||
input_ids = None
|
||||
@ -287,8 +299,60 @@ class MambaBatch(Batch):
|
||||
max_input_length - batch.max_input_length
|
||||
) * len(batch)
|
||||
|
||||
max_seqlen = max(max_seqlen, batch.inference_params.max_seqlen)
|
||||
seqlen_offset = max(seqlen_offset, batch.inference_params.seqlen_offset)
|
||||
batch_size += batch.inference_params.max_batch_size
|
||||
|
||||
start_index = end_index
|
||||
|
||||
|
||||
(_, d_model, d_conv) = batches[0].inference_params.key_value_memory_dict[0][0].shape
|
||||
(_, _, d_state) = batches[0].inference_params.key_value_memory_dict[0][1].shape
|
||||
n_blocks = len(batches[0].inference_params.key_value_memory_dict)
|
||||
dtype = batches[0].inference_params.key_value_memory_dict[0][0].dtype
|
||||
device = batches[0].inference_params.key_value_memory_dict[0][0].device
|
||||
|
||||
key_value_memory_dict = {}
|
||||
for i in range(n_blocks):
|
||||
conv_state = torch.zeros(
|
||||
batch_size,
|
||||
d_model,
|
||||
d_conv,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
ssm_state = torch.zeros(
|
||||
batch_size,
|
||||
d_model,
|
||||
d_state,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
key_value_memory_dict[i] = (conv_state, ssm_state)
|
||||
lengths_per_sample = torch.zeros(batch_size, dtype=torch.int32, device=device)
|
||||
|
||||
inference_params = InferenceParams(
|
||||
max_seqlen=max_seqlen,
|
||||
max_batch_size=batch_size,
|
||||
seqlen_offset=seqlen_offset,
|
||||
key_value_memory_dict=key_value_memory_dict,
|
||||
lengths_per_sample=lengths_per_sample,
|
||||
)
|
||||
|
||||
current_batch = 0
|
||||
for batch in batches:
|
||||
for i in range(n_blocks):
|
||||
conv_state, ssm_state = batch.inference_params.key_value_memory_dict[i]
|
||||
batch_size = batch.inference_params.max_batch_size
|
||||
try:
|
||||
inference_params.key_value_memory_dict[i][0][current_batch:current_batch + batch_size] = conv_state
|
||||
inference_params.key_value_memory_dict[i][1][current_batch:current_batch + batch_size] = ssm_state
|
||||
except Exception:
|
||||
import ipdb;ipdb.set_trace()
|
||||
pass
|
||||
inference_params.lengths_per_sample[current_batch: current_batch + batch_size] = batch.inference_params.lengths_per_sample
|
||||
current_batch += batch_size
|
||||
|
||||
return cls(
|
||||
batch_id=batches[0].batch_id,
|
||||
requests=requests,
|
||||
@ -306,6 +370,7 @@ class MambaBatch(Batch):
|
||||
padding_right_offset=padding_right_offset,
|
||||
keys_head_dim_last=batches[0].keys_head_dim_last,
|
||||
max_tokens=max_tokens,
|
||||
inference_params=inference_params
|
||||
)
|
||||
|
||||
def __len__(self):
|
||||
@ -380,7 +445,6 @@ class Mamba(Model):
|
||||
|
||||
def generate_token(self, batch) -> Tuple[List[Any], Optional[Any], Tuple[int, int]]:
|
||||
start = time.time_ns()
|
||||
|
||||
input_ids = batch.input_ids # batch.past_input_ids if batch.past_input_ids is not None else batch.input_ids
|
||||
|
||||
batch_size = input_ids.shape[0]
|
||||
|
Loading…
Reference in New Issue
Block a user