fix: revise non batching tests

drbh 2024-02-03 05:04:00 +00:00
parent 3a42765cab
commit 0f124cbc52
4 changed files with 54 additions and 80 deletions

View File

@@ -11,7 +11,7 @@
       },
       {
         "id": 2748,
-        "logprob": -9.7421875,
+        "logprob": -9.7265625,
         "text": " request"
       }
     ],
@@ -19,66 +19,66 @@
     "tokens": [
       {
         "id": 187,
-        "logprob": -2.4824219,
-        "special": false,
-        "text": "\n"
-      },
-      {
-        "id": 187,
-        "logprob": -2.4824219,
+        "logprob": -2.4746094,
         "special": false,
         "text": "\n"
       },
       {
         "id": 50274,
-        "logprob": -1.7880859,
+        "logprob": -1.7861328,
         "special": false,
         "text": " "
       },
       {
         "id": 92,
-        "logprob": -2.0703125,
+        "logprob": -2.03125,
         "special": false,
         "text": "{"
       },
       {
         "id": 187,
-        "logprob": -0.04827881,
+        "logprob": -0.048706055,
         "special": false,
         "text": "\n"
       },
       {
         "id": 50270,
-        "logprob": -0.18896484,
+        "logprob": -0.19421387,
         "special": false,
         "text": " "
       },
       {
         "id": 3,
-        "logprob": -1.5234375,
+        "logprob": -1.5097656,
         "special": false,
         "text": "\""
       },
       {
         "id": 9629,
-        "logprob": -2.8203125,
+        "logprob": -2.7792969,
         "special": false,
         "text": "request"
       },
       {
         "id": 1381,
-        "logprob": -0.78759766,
+        "logprob": -0.78271484,
         "special": false,
         "text": "\":"
       },
       {
         "id": 551,
-        "logprob": -0.49169922,
+        "logprob": -0.49487305,
         "special": false,
         "text": " {"
+      },
+      {
+        "id": 187,
+        "logprob": -0.021087646,
+        "special": false,
+        "text": "\n"
       }
     ],
     "top_tokens": null
   },
-  "generated_text": "\n\n {\n \"request\": {"
+  "generated_text": "\n {\n \"request\": {\n"
 }
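
These snapshot files are what the integration tests compare full responses against, so even the small logprob drift above forces the JSON to be regenerated. A minimal sketch of the pattern, assuming the async client fixture these tests use elsewhere in the repo; the test name, prompt, and generate parameters are illustrative assumptions, only the response_snapshot fixture and the equality assertion mirror the test file later in this commit:

import pytest


@pytest.mark.asyncio
@pytest.mark.private
async def test_fused_kernel_mamba(fused_kernel_mamba, response_snapshot):
    # Hypothetical call: the prompt and parameters are assumptions. The point
    # is that the whole response (prefill logprobs, decode tokens,
    # generated_text) is compared against the stored JSON, so any numerical
    # drift shows up as a snapshot mismatch until the file is refreshed.
    response = await fused_kernel_mamba.generate(
        "Test request", max_new_tokens=10, decoder_input_details=True
    )

    assert response.details.generated_tokens == 10
    assert response == response_snapshot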

View File

@@ -16,17 +16,17 @@
       },
       {
         "id": 8862,
-        "logprob": -3.4746094,
+        "logprob": -3.4433594,
         "text": " yellow"
       },
       {
         "id": 13,
-        "logprob": -0.43579102,
+        "logprob": -0.43017578,
         "text": ","
       },
       {
         "id": 209,
-        "logprob": -8.2421875,
+        "logprob": -8.21875,
         "text": " "
       }
     ],
@@ -39,28 +39,28 @@
         "text": "\n"
       },
       {
-        "id": 2764,
-        "logprob": -0.37573242,
+        "id": 395,
+        "logprob": -0.46411133,
         "special": false,
-        "text": "umber"
+        "text": "and"
       },
       {
-        "id": 285,
-        "logprob": 0.0,
+        "id": 13735,
+        "logprob": -2.1132812,
         "special": false,
-        "text": " and"
+        "text": " orange"
       },
       {
-        "id": 3168,
-        "logprob": -0.9013672,
+        "id": 313,
+        "logprob": -1.2128906,
         "special": false,
-        "text": " white"
+        "text": " ("
       },
       {
-        "id": 28,
-        "logprob": -1.2314453,
+        "id": 249,
+        "logprob": -2.3671875,
         "special": false,
-        "text": ";"
+        "text": "in"
       },
       {
         "id": 253,
@@ -69,31 +69,31 @@
         "text": " the"
       },
       {
-        "id": 3295,
-        "logprob": -1.2167969,
+        "id": 1340,
+        "logprob": -1.640625,
         "special": false,
-        "text": " color"
+        "text": " order"
       },
       {
-        "id": 273,
+        "id": 597,
+        "logprob": -0.5488281,
+        "special": false,
+        "text": " they"
+      },
+      {
+        "id": 3176,
+        "logprob": -0.48608398,
+        "special": false,
+        "text": " appear"
+      },
+      {
+        "id": 275,
         "logprob": 0.0,
         "special": false,
-        "text": " of"
-      },
-      {
-        "id": 697,
-        "logprob": -2.1015625,
-        "special": false,
-        "text": " its"
-      },
-      {
-        "id": 17433,
-        "logprob": -2.4296875,
-        "special": false,
-        "text": " unders"
+        "text": " in"
       }
     ],
     "top_tokens": null
   },
-  "generated_text": "blue, red, yellow, \number and white; the color of its unders"
+  "generated_text": "blue, red, yellow, \nand orange (in the order they appear in"
 }
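
A quick way to sanity-check the updated snapshot: the ten decode-token texts concatenate to the tail of generated_text, and the leading "blue, red, yellow, " is the echoed input (the old version of the test below still carried a "TODO: fix so the input is not included in the output" comment). A standalone check, with the token texts copied verbatim from the JSON above:

# Pure-Python sanity check; all values are copied from the snapshot above.
decode_token_texts = [
    "\n", "and", " orange", " (", "in",
    " the", " order", " they", " appear", " in",
]
generated_tail = "".join(decode_token_texts)

assert len(decode_token_texts) == 10  # matches generated_tokens == 10 in the test
assert generated_tail == "\nand orange (in the order they appear in"

# The input prompt is still echoed at the front of generated_text.
assert "blue, red, yellow, " + generated_tail == (
    "blue, red, yellow, \nand orange (in the order they appear in"
)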

View File

@@ -44,17 +44,14 @@ async def test_fused_kernel_mamba_all_params(fused_kernel_mamba, response_snapsh
     )
     assert response.details.generated_tokens == 10
-    # TODO: fix so the input is not included in the output
-    assert response.generated_text == "blue, red, yellow, \number and white; the color of its unders"
+    assert response.generated_text == "blue, red, yellow, \nand orange (in the order they appear in"
     assert response == response_snapshot

-# TODO: fix `Expected x0.dim() == 2 to be true, but got false.`
-# 94: `hidden_states, _ = self.layer_norm(hidden_states.squeeze(0))`
-# NOTE: the fast layer norm has strict requirements on the input shape
+# TODO: Fix batching
 # @pytest.mark.asyncio
 # @pytest.mark.private
 # async def test_fused_kernel_mamba_load(fused_kernel_mamba, generate_load, response_snapshot):
-# responses = await generate_load(fused_kernel_mamba, "Test request", max_new_tokens=10, n=4)
+# responses = await generate_load(fused_kernel_mamba, "Test request", max_new_tokens=10, n=2)
 # assert len(responses) == 4
 # assert all([r.generated_text == responses[0].generated_text for r in responses])
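
For reference, a hedged sketch of the disabled load test written out as it would presumably run once the "TODO: Fix batching" above is resolved. generate_load and fused_kernel_mamba are the repo's own fixtures; everything else mirrors the commented-out block, apart from the comments noting its internal inconsistency:

import pytest


@pytest.mark.asyncio
@pytest.mark.private
async def test_fused_kernel_mamba_load(
    fused_kernel_mamba, generate_load, response_snapshot
):
    # Sketch only: the commit lowers the concurrency from n=4 to n=2, but the
    # commented-out assertion still expects 4 responses, so one of the two
    # values has to change before this test can be re-enabled.
    responses = await generate_load(
        fused_kernel_mamba, "Test request", max_new_tokens=10, n=2
    )

    assert len(responses) == 2
    assert all([r.generated_text == responses[0].generated_text for r in responses])
    # response_snapshot is accepted as in the original signature, although the
    # visible assertions do not use it.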

View File

@@ -34,7 +34,6 @@ class MambaBatch(Batch):
     # Decoder values
     input_ids: torch.Tensor
-    past_input_ids: Optional[torch.Tensor]

     # All tokens
     all_input_ids: List[torch.Tensor]
@@ -132,7 +131,7 @@ class MambaBatch(Batch):
             requests=pb.requests,
             requests_idx_mapping=requests_idx_mapping,
             input_ids=input_ids,
-            past_input_ids=None,
+            # past_input_ids=None,
             all_input_ids=list(all_input_ids),
             input_lengths=input_lengths.tolist(),
             prefix_offsets=prefix_offsets,
@@ -198,7 +197,6 @@ class MambaBatch(Batch):
         # Apply indices to input_ids, attention mask, past key values and other items that need to be cached
         input_ids = self.input_ids[keep_indices]
-        position_ids = self.position_ids[keep_indices]

         top_n_tokens_tensor = self.top_n_tokens_tensor[keep_indices]
         max_tokens = len(request_ids) * max_input_length + total_remaining_decode_tokens
@@ -245,9 +243,6 @@
         # Batch tensors
         input_ids = None
-        attention_mask = None
-        position_ids = None
-        past_key_values = []
         top_n_tokens_tensor = None

         # Used for slicing correctly inside the tensors
@@ -273,10 +268,6 @@
             # Slicing end index for this batch
             end_index = start_index + len(batch)

-            # We only concatenate batches that did at least one step
-            if batch.past_key_values is None:
-                raise ValueError("only concatenate prefilled batches")

             # Create empty tensor
             # input_ids is always of shape [batch_size, 1]
             # We do not need to pad it
@@ -285,12 +276,6 @@
             # Copy to correct indices
             input_ids[start_index:end_index] = batch.input_ids

-            # Create padded tensor
-            if attention_mask is None:
-                attention_mask = batch.attention_mask.new_zeros(
-                    (total_batch_size, max_input_length + padding_right_offset),
-                )

             if top_n_tokens_tensor is None:
                 top_n_tokens_tensor = batches[0].top_n_tokens_tensor.new_zeros(
                     total_batch_size,
@@ -309,9 +294,6 @@
             requests=requests,
             requests_idx_mapping=requests_idx_mapping,
             input_ids=input_ids,
-            attention_mask=attention_mask,
-            position_ids=position_ids,
-            past_key_values=past_key_values,
             all_input_ids=all_input_ids,
             input_lengths=input_lengths,
             prefix_offsets=prefix_offsets,
@@ -555,8 +537,6 @@
                 )
             else:
                 prefill_tokens = None

-            past_input_ids = torch.cat([past_input_ids, next_token_id], dim=1)

             if top_n_tokens > 0:
                 toptoken_texts = self.tokenizer.batch_decode(
@@ -608,9 +588,6 @@
         # Slice unused values from prefill
         batch.input_ids = batch.input_ids[:, :1]
-        batch.past_input_ids = past_input_ids

         forward_ns = start_decode - start
         decode_ns = time.time_ns() - start_decode
         return generations, batch, (forward_ns, decode_ns)
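
The net effect on the batch state: the Mamba path does not use an attention KV cache, so the attention-era fields removed above have no equivalent to carry around. A minimal sketch of the batch limited to the fields visible in these hunks; the name MambaBatchSketch is hypothetical, and the real MambaBatch holds additional bookkeeping (stopping criteria, next-token choosers, read offsets) that this diff does not touch:

from dataclasses import dataclass
from typing import List, Optional

import torch


@dataclass
class MambaBatchSketch:
    # Hypothetical, trimmed-down view of MambaBatch restricted to fields shown
    # in the hunks above; attention_mask, position_ids, past_key_values and
    # past_input_ids are gone because there is no attention KV cache here (how
    # the recurrent state is cached is outside the scope of this diff).
    input_ids: torch.Tensor            # [batch_size, 1] once prefill is sliced
    all_input_ids: List[torch.Tensor]  # every token seen so far, per sequence
    input_lengths: List[int]
    prefix_offsets: List[int]
    top_n_tokens_tensor: Optional[torch.Tensor]
    max_tokens: int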