Mirror of https://github.com/huggingface/text-generation-inference.git
fix tests for causal lm
commit d9578153cb
parent 2ad7a63761
@@ -44,11 +44,12 @@ def default_causal_lm_batch(default_pb_batch, gpt2_tokenizer):
 @pytest.fixture
 def default_multi_requests_causal_lm_batch(default_pb_request, gpt2_tokenizer):
     req_0 = copy(default_pb_request)
+    req_0.id = 1
     req_1 = default_pb_request
-    req_1.id = 1
+    req_1.id = 2
     req_1.stopping_parameters.max_new_tokens = 5
 
-    batch_pb = generate_pb2.Batch(id=0, requests=[req_0, req_1], size=2)
+    batch_pb = generate_pb2.Batch(id=1, requests=[req_0, req_1], size=2)
     return CausalLMBatch.from_pb(batch_pb, gpt2_tokenizer, torch.device("cpu"))
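Note on the fixture change above: the batch code now tracks requests through requests_idx_mapping, a dict keyed by request id (see the CausalLMBatch hunks further down), so the fixture presumably gives req_0 and req_1 ids that do not clash with the ids used by the other fixtures they are combined with in the tests. A tiny, self-contained illustration of why an id-keyed mapping needs distinct ids; plain ints stand in for the protobuf requests:

```python
# Requests that share an id overwrite each other in an id -> row-index mapping.
duplicate_ids = [1, 1]
mapping = {req_id: row for row, req_id in enumerate(duplicate_ids)}
assert mapping == {1: 1}  # the first request's row can no longer be looked up

# With distinct ids (as in the updated fixture: req_0.id = 1, req_1.id = 2),
# every request keeps its own entry.
unique_ids = [1, 2]
mapping = {req_id: row for row, req_id in enumerate(unique_ids)}
assert mapping == {1: 0, 2: 1}
```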
@@ -67,12 +68,12 @@ def test_batch_from_pb(default_pb_batch, default_causal_lm_batch):
 
     assert batch.past_key_values is None
 
-    assert torch.equal(batch.input_ids, batch.all_input_ids[:, :, 0])
+    assert all([torch.equal(input_ids, all_input_ids[:, 0]) for input_ids, all_input_ids in zip(batch.input_ids, batch.all_input_ids)])
 
     assert batch.input_lengths == [1]
 
-    assert batch.size == default_pb_batch.size
-    assert len(batch.next_token_choosers) == len(batch.stopping_criterias) == batch.size
+    assert len(batch) == default_pb_batch.size
+    assert len(batch.next_token_choosers) == len(batch.stopping_criterias) == len(batch)
 
     assert batch.max_input_length == batch.input_lengths[0]
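The rewritten assertions call len(batch) instead of reading a size attribute, so the batch type is expected to implement __len__. A minimal sketch of what that could look like, assuming it simply counts the requests held by the batch; BatchSketch is a hypothetical stand-in, not the real CausalLMBatch:

```python
from dataclasses import dataclass
from typing import Any, List


@dataclass
class BatchSketch:
    # Hypothetical stand-in for a batch object; only the piece needed for len().
    requests: List[Any]

    def __len__(self) -> int:
        # One entry per request currently in the batch.
        return len(self.requests)


# Mirrors the updated assertions: len(batch) equals the number of requests.
batch = BatchSketch(requests=["req_0", "req_1"])
assert len(batch) == 2
assert len(batch) == len(batch.requests)
```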
@@ -93,7 +94,7 @@ def test_causal_lm_generate_token(default_causal_lm, default_causal_lm_batch):
     assert len(generations) == len(next_batch)
     assert isinstance(next_batch, CausalLMBatch)
 
-    assert len(next_batch.all_input_ids) == next_batch.size
+    assert len(next_batch.all_input_ids) == len(next_batch)
     assert len(next_batch.all_input_ids[0]) == sequence_length + 1
     assert len(next_batch.attention_mask[0]) == 11
     assert next_batch.all_input_ids[0][-1] == 13
@@ -103,7 +104,7 @@ def test_causal_lm_generate_token(default_causal_lm, default_causal_lm_batch):
     assert torch.all(next_batch.attention_mask[0][0:2] == 1)
     assert torch.all(next_batch.attention_mask[0][2:] == 0)
 
-    assert next_batch.input_ids.shape == (next_batch.size, 1)
+    assert next_batch.input_ids.shape == (len(next_batch), 1)
     assert next_batch.input_ids[0, 0] == 13
 
     assert next_batch.input_lengths == [2]
@@ -168,6 +169,8 @@ def test_causal_lm_generate_token_completion_multi(
         == default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens
     )
 
+    next_batch = next_batch.filter([next_batch.requests[0]])
+
     for _ in range(
         default_multi_requests_causal_lm_batch.stopping_criterias[0].max_new_tokens
         - default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens
@@ -266,6 +269,8 @@ def test_batch_concatenate(
         == default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens
     )
 
+    next_batch = next_batch.filter([next_batch.requests[0], next_batch.requests[1]])
+
     for _ in range(
         default_causal_lm_batch.stopping_criterias[0].max_new_tokens
         - default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens
@@ -285,6 +290,8 @@ def test_batch_concatenate(
         == default_causal_lm_batch.stopping_criterias[0].max_new_tokens
     )
 
+    next_batch = next_batch.filter([next_batch.requests[1]])
+
     for _ in range(
         default_multi_requests_causal_lm_batch.stopping_criterias[0].max_new_tokens
         - default_causal_lm_batch.stopping_criterias[0].max_new_tokens
@@ -147,10 +147,13 @@ class CausalLMBatch(Batch):
         all_input_ids = []
         max_input_length = 0
 
+        next_token_choosers = []
+        stopping_criterias = []
+
         for i, r in enumerate(requests):
             idx = self.requests_idx_mapping[r.id]
-            keep_indices.append(idx)
+            requests_idx_mapping[r.id] = i
+            keep_indices.append(idx)
 
             offsets.append(self.offsets[idx])
             token_offsets.append(self.token_offsets[idx])
@@ -160,28 +163,37 @@ class CausalLMBatch(Batch):
             input_lengths.append(request_input_length)
             max_input_length = max(max_input_length, request_input_length)
 
-        # Replace metadata
-        self.requests_idx_mapping = requests_idx_mapping
-        self.input_lengths = input_lengths
-        self.offsets = offsets
-        self.token_offsets = token_offsets
-        self.all_input_ids = all_input_ids
-        self.max_input_length = max_input_length
+            next_token_choosers.append(self.next_token_choosers[idx])
+            stopping_criterias.append(self.stopping_criterias[idx])
 
         # Apply indices to input_ids, attention mask, past key values and other items that need to be cached
-        self.input_ids = self.input_ids[keep_indices]
-        self.attention_mask = self.attention_mask[keep_indices]
-        self.position_ids = self.position_ids[keep_indices]
+        input_ids = self.input_ids[keep_indices]
+        attention_mask = self.attention_mask[keep_indices]
+        position_ids = self.position_ids[keep_indices]
         # Force past to be of dim [self_size, num_heads, ...] for easy indexing
-        self.past_key_values = [
+        past_key_values = [
             [t.view(len(self), -1, *t.shape[-2:])[keep_indices] for t in layer]
             for layer in self.past_key_values
         ]
-        self.requests = requests
-        self.next_token_choosers = [self.next_token_choosers[i] for i in keep_indices]
-        self.stopping_criterias = [self.stopping_criterias[i] for i in keep_indices]
 
-        return self
+        return CausalLMBatch(
+            batch_id=self.batch_id,
+            requests=requests,
+            requests_idx_mapping=requests_idx_mapping,
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            all_input_ids=all_input_ids,
+            input_lengths=input_lengths,
+            offsets=offsets,
+            token_offsets=token_offsets,
+            next_token_choosers=next_token_choosers,
+            stopping_criterias=stopping_criterias,
+            max_input_length=max_input_length,
+            padding_right_offset=self.padding_right_offset,
+            keys_head_dim_last=self.keys_head_dim_last,
+        )
 
     @classmethod
     @tracer.start_as_current_span("concatenate")
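With this hunk, filter builds and returns a fresh CausalLMBatch instead of mutating self, which is why the updated tests rebind the result (next_batch = next_batch.filter([...])). A small, self-contained sketch of that calling pattern; ToyBatch and its fields are made up and only mimic the shape of the real class:

```python
from dataclasses import dataclass
from typing import List


@dataclass
class ToyBatch:
    # Made-up miniature of a batch: request ids plus one per-request value.
    requests: List[int]
    input_lengths: List[int]

    def filter(self, requests: List[int]) -> "ToyBatch":
        # Re-index every per-request field to the kept rows and return a new
        # object, mirroring the "return CausalLMBatch(...)" shape of the diff.
        keep = [self.requests.index(r) for r in requests]
        return ToyBatch(
            requests=[self.requests[i] for i in keep],
            input_lengths=[self.input_lengths[i] for i in keep],
        )


batch = ToyBatch(requests=[0, 1, 2], input_lengths=[5, 7, 3])
# Callers must rebind the result, as the updated tests do.
batch = batch.filter([batch.requests[0], batch.requests[2]])
assert batch.requests == [0, 2]
assert batch.input_lengths == [5, 3]
```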
@@ -224,8 +236,9 @@ class CausalLMBatch(Batch):
             stopping_criterias.extend(batch.stopping_criterias)
 
             if i == 0:
-                requests_idx_mapping = requests_idx_mapping
+                requests_idx_mapping = batch.requests_idx_mapping
             else:
                 # We need to offset the mapping for each batch by the cumulative batch size
                 for k, v in batch.requests_idx_mapping.items():
                     requests_idx_mapping[k] = v + start_index
@@ -525,17 +538,20 @@ class CausalLM(Model):
             generations.append(generation)
 
             # Update values
-            batch.input_ids[i] = next_token_id
+            batch.input_ids[i, 0] = next_token_id
             batch.all_input_ids[i] = all_input_ids
             batch.input_lengths[i] = new_input_length
             batch.offsets[i] = offset
             batch.token_offsets[i] = token_offset
             batch.max_input_length = max(batch.max_input_length, new_input_length)
 
-        # Decrease right offset
-        batch.padding_right_offset -= 1
+        # Slice unused values from prefill
+        batch.input_ids = batch.input_ids[:, :1]
+
         # Update attention_mask as we added a new token to input_ids
         batch.attention_mask[:, -batch.padding_right_offset] = 1
+        # Decrease right offset
+        batch.padding_right_offset -= 1
 
         # Update position_ids
         batch.position_ids = batch.position_ids[:, -1:] + 1
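The reordering above writes the attention mask at column -padding_right_offset before decrementing the offset, so the column that gets unmasked is the one holding the token that was just appended. A toy example of that indexing; the tensor sizes are invented, and only the two statements on the batch come from the diff:

```python
import torch

# One sequence: 2 prompt tokens followed by room for 3 generated tokens (right padding).
attention_mask = torch.tensor([[1, 1, 0, 0, 0]])
padding_right_offset = 3

# Order used by the new code: unmask the slot of the token just added,
# then shrink the right offset for the next decode step.
attention_mask[:, -padding_right_offset] = 1
padding_right_offset -= 1

assert attention_mask.tolist() == [[1, 1, 1, 0, 0]]
assert padding_right_offset == 2
```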
@@ -244,9 +244,12 @@ class FlashCausalLMBatch(Batch):
         for i, batch in enumerate(batches):
             requests.extend(batch.requests)
 
-            # We need to offset the mapping for each batch by the cumulative batch size
-            for k, v in batch.requests_idx_mapping.items():
-                requests_idx_mapping[k] = v + cumulative_batch_size
+            if i == 0:
+                requests_idx_mapping = batch.requests_idx_mapping
+            else:
+                # We need to offset the mapping for each batch by the cumulative batch size
+                for k, v in batch.requests_idx_mapping.items():
+                    requests_idx_mapping[k] = v + cumulative_batch_size
 
             input_ids.extend(batch.input_ids)
             position_ids.extend(batch.position_ids)