fix: revise non batching tests

drbh 2024-02-03 05:04:00 +00:00
parent 3a42765cab
commit 0f124cbc52
4 changed files with 54 additions and 80 deletions

View File

@@ -11,7 +11,7 @@
},
{
"id": 2748,
"logprob": -9.7421875,
"logprob": -9.7265625,
"text": " request"
}
],
@@ -19,66 +19,66 @@
"tokens": [
{
"id": 187,
"logprob": -2.4824219,
"special": false,
"text": "\n"
},
{
"id": 187,
"logprob": -2.4824219,
"logprob": -2.4746094,
"special": false,
"text": "\n"
},
{
"id": 50274,
"logprob": -1.7880859,
"logprob": -1.7861328,
"special": false,
"text": " "
},
{
"id": 92,
"logprob": -2.0703125,
"logprob": -2.03125,
"special": false,
"text": "{"
},
{
"id": 187,
"logprob": -0.04827881,
"logprob": -0.048706055,
"special": false,
"text": "\n"
},
{
"id": 50270,
"logprob": -0.18896484,
"logprob": -0.19421387,
"special": false,
"text": " "
},
{
"id": 3,
"logprob": -1.5234375,
"logprob": -1.5097656,
"special": false,
"text": "\""
},
{
"id": 9629,
"logprob": -2.8203125,
"logprob": -2.7792969,
"special": false,
"text": "request"
},
{
"id": 1381,
"logprob": -0.78759766,
"logprob": -0.78271484,
"special": false,
"text": "\":"
},
{
"id": 551,
"logprob": -0.49169922,
"logprob": -0.49487305,
"special": false,
"text": " {"
},
{
"id": 187,
"logprob": -0.021087646,
"special": false,
"text": "\n"
}
],
"top_tokens": null
},
"generated_text": "\n\n {\n \"request\": {"
"generated_text": "\n {\n \"request\": {\n"
}

View File

@@ -16,17 +16,17 @@
},
{
"id": 8862,
"logprob": -3.4746094,
"logprob": -3.4433594,
"text": " yellow"
},
{
"id": 13,
"logprob": -0.43579102,
"logprob": -0.43017578,
"text": ","
},
{
"id": 209,
"logprob": -8.2421875,
"logprob": -8.21875,
"text": " "
}
],
@@ -39,28 +39,28 @@
"text": "\n"
},
{
"id": 2764,
"logprob": -0.37573242,
"id": 395,
"logprob": -0.46411133,
"special": false,
"text": "umber"
"text": "and"
},
{
"id": 285,
"logprob": 0.0,
"id": 13735,
"logprob": -2.1132812,
"special": false,
"text": " and"
"text": " orange"
},
{
"id": 3168,
"logprob": -0.9013672,
"id": 313,
"logprob": -1.2128906,
"special": false,
"text": " white"
"text": " ("
},
{
"id": 28,
"logprob": -1.2314453,
"id": 249,
"logprob": -2.3671875,
"special": false,
"text": ";"
"text": "in"
},
{
"id": 253,
@@ -69,31 +69,31 @@
"text": " the"
},
{
"id": 3295,
"logprob": -1.2167969,
"id": 1340,
"logprob": -1.640625,
"special": false,
"text": " color"
"text": " order"
},
{
"id": 273,
"id": 597,
"logprob": -0.5488281,
"special": false,
"text": " they"
},
{
"id": 3176,
"logprob": -0.48608398,
"special": false,
"text": " appear"
},
{
"id": 275,
"logprob": 0.0,
"special": false,
"text": " of"
},
{
"id": 697,
"logprob": -2.1015625,
"special": false,
"text": " its"
},
{
"id": 17433,
"logprob": -2.4296875,
"special": false,
"text": " unders"
"text": " in"
}
],
"top_tokens": null
},
"generated_text": "blue, red, yellow, \number and white; the color of its unders"
"generated_text": "blue, red, yellow, \nand orange (in the order they appear in"
}

View File

@@ -44,17 +44,14 @@ async def test_fused_kernel_mamba_all_params(fused_kernel_mamba, response_snapshot
)
assert response.details.generated_tokens == 10
# TODO: fix so the input is not included in the output
assert response.generated_text == "blue, red, yellow, \number and white; the color of its unders"
assert response.generated_text == "blue, red, yellow, \nand orange (in the order they appear in"
assert response == response_snapshot
# TODO: fix `Expected x0.dim() == 2 to be true, but got false.`
# 94: `hidden_states, _ = self.layer_norm(hidden_states.squeeze(0))`
# NOTE: the fast layer norm has strict requirements on the input shape
# TODO: Fix batching
# @pytest.mark.asyncio
# @pytest.mark.private
# async def test_fused_kernel_mamba_load(fused_kernel_mamba, generate_load, response_snapshot):
# responses = await generate_load(fused_kernel_mamba, "Test request", max_new_tokens=10, n=4)
# responses = await generate_load(fused_kernel_mamba, "Test request", max_new_tokens=10, n=2)
# assert len(responses) == 4
# assert all([r.generated_text == responses[0].generated_text for r in responses])
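Editor's note on the NOTE/TODO above: the fused ("fast") layer norm referenced at line 94 only accepts a 2D input of shape [tokens, hidden], and `hidden_states.squeeze(0)` satisfies that only while the batch dimension is 1, which is why the load test stays disabled. A minimal sketch of the constraint follows; the tensor sizes are illustrative, and the flattening shown at the end is only one possible direction for the batching TODO, not part of this commit.

import torch

# Illustrative shapes only: the fast layer norm rejects anything that is not 2D.
hidden_states = torch.randn(1, 7, 768)        # [batch=1, seq, hidden]
assert hidden_states.squeeze(0).dim() == 2    # fine while batch_size == 1

batched = torch.randn(4, 7, 768)              # [batch=4, seq, hidden]
assert batched.squeeze(0).dim() == 3          # squeeze(0) removes nothing here, so the x0.dim() == 2 check fails

# One possible route for the batching TODO: flatten batch and sequence instead of squeezing.
flat = batched.view(-1, batched.size(-1))     # [batch * seq, hidden] is always 2D
assert flat.dim() == 2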

View File

@@ -34,7 +34,6 @@ class MambaBatch(Batch):
# Decoder values
input_ids: torch.Tensor
past_input_ids: Optional[torch.Tensor]
# All tokens
all_input_ids: List[torch.Tensor]
@@ -132,7 +131,7 @@ class MambaBatch(Batch):
requests=pb.requests,
requests_idx_mapping=requests_idx_mapping,
input_ids=input_ids,
past_input_ids=None,
# past_input_ids=None,
all_input_ids=list(all_input_ids),
input_lengths=input_lengths.tolist(),
prefix_offsets=prefix_offsets,
@@ -198,7 +197,6 @@ class MambaBatch(Batch):
# Apply indices to input_ids, attention mask, past key values and other items that need to be cached
input_ids = self.input_ids[keep_indices]
position_ids = self.position_ids[keep_indices]
top_n_tokens_tensor = self.top_n_tokens_tensor[keep_indices]
max_tokens = len(request_ids) * max_input_length + total_remaining_decode_tokens
@@ -245,9 +243,6 @@ class MambaBatch(Batch):
# Batch tensors
input_ids = None
attention_mask = None
position_ids = None
past_key_values = []
top_n_tokens_tensor = None
# Used for slicing correctly inside the tensors
@@ -273,10 +268,6 @@ class MambaBatch(Batch):
# Slicing end index for this batch
end_index = start_index + len(batch)
# We only concatenate batches that did at least one step
if batch.past_key_values is None:
raise ValueError("only concatenate prefilled batches")
# Create empty tensor
# input_ids is always of shape [batch_size, 1]
# We do not need to pad it
@@ -285,12 +276,6 @@ class MambaBatch(Batch):
# Copy to correct indices
input_ids[start_index:end_index] = batch.input_ids
# Create padded tensor
if attention_mask is None:
attention_mask = batch.attention_mask.new_zeros(
(total_batch_size, max_input_length + padding_right_offset),
)
if top_n_tokens_tensor is None:
top_n_tokens_tensor = batches[0].top_n_tokens_tensor.new_zeros(
total_batch_size,
@@ -309,9 +294,6 @@ class MambaBatch(Batch):
requests=requests,
requests_idx_mapping=requests_idx_mapping,
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
all_input_ids=all_input_ids,
input_lengths=input_lengths,
prefix_offsets=prefix_offsets,
@@ -555,8 +537,6 @@ class Mamba(Model):
)
else:
prefill_tokens = None
past_input_ids = torch.cat([past_input_ids, next_token_id], dim=1)
if top_n_tokens > 0:
toptoken_texts = self.tokenizer.batch_decode(
@@ -608,9 +588,6 @@ class Mamba(Model):
# Slice unused values from prefill
batch.input_ids = batch.input_ids[:, :1]
batch.past_input_ids = past_input_ids
forward_ns = start_decode - start
decode_ns = time.time_ns() - start_decode
return generations, batch, (forward_ns, decode_ns)
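Editor's note on the `past_input_ids` bookkeeping introduced above: with the attention-mask, position-id, and past-key-value fields removed from `MambaBatch`, the batch simply carries the full token history and appends each newly sampled token before the next forward pass. A minimal sketch of that update follows; values and shapes are illustrative, and `next_token_id` is assumed to be `[batch, 1]` as implied by the `dim=1` concatenation.

import torch

# Illustrative values; past_input_ids is [batch, seq_len], next_token_id is [batch, 1].
past_input_ids = torch.tensor([[50, 27, 13]])
next_token_id = torch.tensor([[42]])

past_input_ids = torch.cat([past_input_ids, next_token_id], dim=1)
print(past_input_ids)  # tensor([[50, 27, 13, 42]])
# The grown history is then stored back on the batch for the next decode step:
# batch.past_input_ids = past_input_ids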