mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-09-09 19:34:53 +00:00

commit d9578153cb (parent 2ad7a63761)

    fix tests for causal lm
@@ -44,11 +44,12 @@ def default_causal_lm_batch(default_pb_batch, gpt2_tokenizer):
 @pytest.fixture
 def default_multi_requests_causal_lm_batch(default_pb_request, gpt2_tokenizer):
     req_0 = copy(default_pb_request)
+    req_0.id = 1
     req_1 = default_pb_request
-    req_1.id = 1
+    req_1.id = 2
     req_1.stopping_parameters.max_new_tokens = 5
 
-    batch_pb = generate_pb2.Batch(id=0, requests=[req_0, req_1], size=2)
+    batch_pb = generate_pb2.Batch(id=1, requests=[req_0, req_1], size=2)
     return CausalLMBatch.from_pb(batch_pb, gpt2_tokenizer, torch.device("cpu"))
 
 
@@ -67,12 +68,12 @@ def test_batch_from_pb(default_pb_batch, default_causal_lm_batch):
 
     assert batch.past_key_values is None
 
-    assert torch.equal(batch.input_ids, batch.all_input_ids[:, :, 0])
+    assert all([torch.equal(input_ids, all_input_ids[:, 0]) for input_ids, all_input_ids in zip(batch.input_ids, batch.all_input_ids)])
 
     assert batch.input_lengths == [1]
 
-    assert batch.size == default_pb_batch.size
-    assert len(batch.next_token_choosers) == len(batch.stopping_criterias) == batch.size
+    assert len(batch) == default_pb_batch.size
+    assert len(batch.next_token_choosers) == len(batch.stopping_criterias) == len(batch)
 
     assert batch.max_input_length == batch.input_lengths[0]
 
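Note: the assertions above replace the old `batch.size` attribute with `len(batch)`. A minimal, self-contained sketch of the idea, assuming the batch derives its length from the requests it currently holds (the `DemoBatch` class below is illustrative, not the repository's `CausalLMBatch`):

    from dataclasses import dataclass
    from typing import List


    @dataclass
    class DemoBatch:
        # Illustrative stand-in for a batch whose size is derived, not stored
        requests: List[int]

        def __len__(self) -> int:
            # Length always matches the kept requests, so it cannot drift
            # out of sync after a request is filtered out.
            return len(self.requests)


    batch = DemoBatch(requests=[0, 1, 2])
    assert len(batch) == 3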
@@ -93,7 +94,7 @@ def test_causal_lm_generate_token(default_causal_lm, default_causal_lm_batch):
     assert len(generations) == len(next_batch)
     assert isinstance(next_batch, CausalLMBatch)
 
-    assert len(next_batch.all_input_ids) == next_batch.size
+    assert len(next_batch.all_input_ids) == len(next_batch)
     assert len(next_batch.all_input_ids[0]) == sequence_length + 1
     assert len(next_batch.attention_mask[0]) == 11
     assert next_batch.all_input_ids[0][-1] == 13
@@ -103,7 +104,7 @@ def test_causal_lm_generate_token(default_causal_lm, default_causal_lm_batch):
     assert torch.all(next_batch.attention_mask[0][0:2] == 1)
     assert torch.all(next_batch.attention_mask[0][2:] == 0)
 
-    assert next_batch.input_ids.shape == (next_batch.size, 1)
+    assert next_batch.input_ids.shape == (len(next_batch), 1)
     assert next_batch.input_ids[0, 0] == 13
 
     assert next_batch.input_lengths == [2]
@@ -168,6 +169,8 @@ def test_causal_lm_generate_token_completion_multi(
         == default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens
     )
 
+    next_batch = next_batch.filter([next_batch.requests[0]])
+
     for _ in range(
         default_multi_requests_causal_lm_batch.stopping_criterias[0].max_new_tokens
         - default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens
@@ -266,6 +269,8 @@ def test_batch_concatenate(
         == default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens
     )
 
+    next_batch = next_batch.filter([next_batch.requests[0], next_batch.requests[1]])
+
     for _ in range(
         default_causal_lm_batch.stopping_criterias[0].max_new_tokens
         - default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens
@@ -285,6 +290,8 @@ def test_batch_concatenate(
         == default_causal_lm_batch.stopping_criterias[0].max_new_tokens
     )
 
+    next_batch = next_batch.filter([next_batch.requests[1]])
+
     for _ in range(
         default_multi_requests_causal_lm_batch.stopping_criterias[0].max_new_tokens
        - default_causal_lm_batch.stopping_criterias[0].max_new_tokens
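Note: the three test hunks above add an explicit `next_batch = next_batch.filter(...)` rebind once a request finishes. Since `filter` now returns a batch rather than editing one in place (see the `CausalLMBatch.filter` hunk below), forgetting the rebind would leave the old, unfiltered batch in use. A hedged, self-contained illustration of that pattern (`DemoBatch` is illustrative, not the repository's class):

    from dataclasses import dataclass, replace
    from typing import List


    @dataclass(frozen=True)
    class DemoBatch:
        # Illustrative stand-in, not the repository's CausalLMBatch
        requests: List[str]

        def filter(self, requests: List[str]) -> "DemoBatch":
            # Non-mutating: the caller must keep the returned batch
            return replace(self, requests=list(requests))


    next_batch = DemoBatch(requests=["req_0", "req_1"])

    next_batch.filter([next_batch.requests[0]])  # result discarded: still 2 requests
    assert len(next_batch.requests) == 2

    next_batch = next_batch.filter([next_batch.requests[0]])  # rebound: 1 request left
    assert next_batch.requests == ["req_0"]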
@@ -147,10 +147,13 @@ class CausalLMBatch(Batch):
         all_input_ids = []
         max_input_length = 0
 
+        next_token_choosers = []
+        stopping_criterias = []
+
         for i, r in enumerate(requests):
             idx = self.requests_idx_mapping[r.id]
-            keep_indices.append(idx)
             requests_idx_mapping[r.id] = i
+            keep_indices.append(idx)
 
             offsets.append(self.offsets[idx])
             token_offsets.append(self.token_offsets[idx])
@@ -160,28 +163,37 @@ class CausalLMBatch(Batch):
             input_lengths.append(request_input_length)
             max_input_length = max(max_input_length, request_input_length)
 
-        # Replace metadata
-        self.requests_idx_mapping = requests_idx_mapping
-        self.input_lengths = input_lengths
-        self.offsets = offsets
-        self.token_offsets = token_offsets
-        self.all_input_ids = all_input_ids
-        self.max_input_length = max_input_length
+            next_token_choosers.append(self.next_token_choosers[idx])
+            stopping_criterias.append(self.stopping_criterias[idx])
 
         # Apply indices to input_ids, attention mask, past key values and other items that need to be cached
-        self.input_ids = self.input_ids[keep_indices]
-        self.attention_mask = self.attention_mask[keep_indices]
-        self.position_ids = self.position_ids[keep_indices]
+        input_ids = self.input_ids[keep_indices]
+        attention_mask = self.attention_mask[keep_indices]
+        position_ids = self.position_ids[keep_indices]
         # Force past to be of dim [self_size, num_heads, ...] for easy indexing
-        self.past_key_values = [
+        past_key_values = [
             [t.view(len(self), -1, *t.shape[-2:])[keep_indices] for t in layer]
             for layer in self.past_key_values
         ]
-        self.requests = requests
-        self.next_token_choosers = [self.next_token_choosers[i] for i in keep_indices]
-        self.stopping_criterias = [self.stopping_criterias[i] for i in keep_indices]
 
-        return self
+        return CausalLMBatch(
+            batch_id=self.batch_id,
+            requests=requests,
+            requests_idx_mapping=requests_idx_mapping,
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            all_input_ids=all_input_ids,
+            input_lengths=input_lengths,
+            offsets=offsets,
+            token_offsets=token_offsets,
+            next_token_choosers=next_token_choosers,
+            stopping_criterias=stopping_criterias,
+            max_input_length=max_input_length,
+            padding_right_offset=self.padding_right_offset,
+            keys_head_dim_last=self.keys_head_dim_last,
+        )
 
     @classmethod
     @tracer.start_as_current_span("concatenate")
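Note: the `filter` hunks above stop mutating `self` and instead gather the kept rows once and build a fresh `CausalLMBatch`. A minimal, self-contained sketch of that gather-and-rebuild pattern with plain torch tensors; the `MiniBatch` class and its fields are illustrative assumptions, not the repository's API:

    from dataclasses import dataclass
    from typing import Dict, List

    import torch


    @dataclass
    class MiniBatch:
        # Illustrative stand-in for the cached per-request state
        requests: List[int]
        requests_idx_mapping: Dict[int, int]
        input_ids: torch.Tensor       # [batch_size, 1]
        attention_mask: torch.Tensor  # [batch_size, seq_len]

        def filter(self, kept_request_ids: List[int]) -> "MiniBatch":
            # Map kept request ids to their current row indices and
            # rebuild the id -> row mapping for the smaller batch.
            keep_indices = []
            requests_idx_mapping = {}
            for i, r in enumerate(kept_request_ids):
                keep_indices.append(self.requests_idx_mapping[r])
                requests_idx_mapping[r] = i

            # Index the cached tensors once and return a new batch,
            # leaving the original untouched.
            return MiniBatch(
                requests=list(kept_request_ids),
                requests_idx_mapping=requests_idx_mapping,
                input_ids=self.input_ids[keep_indices],
                attention_mask=self.attention_mask[keep_indices],
            )


    batch = MiniBatch(
        requests=[1, 2],
        requests_idx_mapping={1: 0, 2: 1},
        input_ids=torch.tensor([[11], [12]]),
        attention_mask=torch.ones(2, 4, dtype=torch.long),
    )
    smaller = batch.filter([2])
    assert smaller.requests_idx_mapping == {2: 0}
    assert torch.equal(smaller.input_ids, torch.tensor([[12]]))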
@@ -224,8 +236,9 @@ class CausalLMBatch(Batch):
             stopping_criterias.extend(batch.stopping_criterias)
 
             if i == 0:
-                requests_idx_mapping = requests_idx_mapping
+                requests_idx_mapping = batch.requests_idx_mapping
             else:
+                # We need to offset the mapping for each batch by the cumulative batch size
                 for k, v in batch.requests_idx_mapping.items():
                     requests_idx_mapping[k] = v + start_index
 
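Note: in `concatenate`, every `requests_idx_mapping` holds row indices local to its own batch, so each mapping after the first has to be shifted by the rows already accumulated (`start_index`). A small, standalone worked example of that offsetting with made-up request ids (copying the first mapping here for safety, whereas the code above reuses it directly):

    # Two batches whose mappings use per-batch row indices
    batch_a_mapping = {10: 0, 11: 1}          # 2 requests -> rows 0 and 1
    batch_b_mapping = {20: 0, 21: 1, 22: 2}   # 3 requests -> rows 0..2

    requests_idx_mapping = {}
    start_index = 0
    for i, mapping in enumerate([batch_a_mapping, batch_b_mapping]):
        if i == 0:
            # The first batch keeps its indices as-is
            requests_idx_mapping = dict(mapping)
        else:
            # Later batches are offset by the rows already in the merged batch
            for k, v in mapping.items():
                requests_idx_mapping[k] = v + start_index
        start_index += len(mapping)

    assert requests_idx_mapping == {10: 0, 11: 1, 20: 2, 21: 3, 22: 4}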
@@ -525,17 +538,20 @@ class CausalLM(Model):
             generations.append(generation)
 
             # Update values
-            batch.input_ids[i] = next_token_id
+            batch.input_ids[i, 0] = next_token_id
             batch.all_input_ids[i] = all_input_ids
             batch.input_lengths[i] = new_input_length
             batch.offsets[i] = offset
             batch.token_offsets[i] = token_offset
             batch.max_input_length = max(batch.max_input_length, new_input_length)
 
-        # Decrease right offset
-        batch.padding_right_offset -= 1
+        # Slice unused values from prefill
+        batch.input_ids = batch.input_ids[:, :1]
 
         # Update attention_mask as we added a new token to input_ids
         batch.attention_mask[:, -batch.padding_right_offset] = 1
+        # Decrease right offset
+        batch.padding_right_offset -= 1
 
         # Update position_ids
         batch.position_ids = batch.position_ids[:, -1:] + 1
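Note: after prefill, only the latest token per sequence is fed back to the model, so the hunk above writes the new token into column 0, slices `input_ids` down to that single column, and only decrements `padding_right_offset` after marking the newly used attention position. A self-contained sketch of that bookkeeping with plain tensors; shapes and variable names are illustrative:

    import torch

    batch_size, prompt_len, padding_right_offset = 2, 3, 4

    # State right after prefill: full prompts plus right padding reserved for new tokens
    input_ids = torch.randint(0, 100, (batch_size, prompt_len))
    attention_mask = torch.zeros(batch_size, prompt_len + padding_right_offset, dtype=torch.long)
    attention_mask[:, :prompt_len] = 1

    # Per-sequence update: store each newly sampled token in column 0 ...
    next_token_ids = torch.tensor([7, 8])
    for i in range(batch_size):
        input_ids[i, 0] = next_token_ids[i]

    # ... then slice away the now-unused prompt columns so decode steps
    # only carry the single most recent token per sequence
    input_ids = input_ids[:, :1]

    # Attend to the slot the new token occupies, then shrink the right offset
    attention_mask[:, -padding_right_offset] = 1
    padding_right_offset -= 1

    assert input_ids.shape == (batch_size, 1)
    assert torch.equal(input_ids[:, 0], next_token_ids)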
@@ -244,9 +244,12 @@ class FlashCausalLMBatch(Batch):
         for i, batch in enumerate(batches):
             requests.extend(batch.requests)
 
-            # We need to offset the mapping for each batch by the cumulative batch size
-            for k, v in batch.requests_idx_mapping.items():
-                requests_idx_mapping[k] = v + cumulative_batch_size
+            if i == 0:
+                requests_idx_mapping = batch.requests_idx_mapping
+            else:
+                # We need to offset the mapping for each batch by the cumulative batch size
+                for k, v in batch.requests_idx_mapping.items():
+                    requests_idx_mapping[k] = v + cumulative_batch_size
 
             input_ids.extend(batch.input_ids)
             position_ids.extend(batch.position_ids)