fix tests for causal lm

OlivierDehaene 2023-04-18 18:56:51 +02:00
parent 2ad7a63761
commit d9578153cb
3 changed files with 56 additions and 30 deletions

View File

@@ -44,11 +44,12 @@ def default_causal_lm_batch(default_pb_batch, gpt2_tokenizer):
 @pytest.fixture
 def default_multi_requests_causal_lm_batch(default_pb_request, gpt2_tokenizer):
     req_0 = copy(default_pb_request)
+    req_0.id = 1
     req_1 = default_pb_request
-    req_1.id = 1
+    req_1.id = 2
     req_1.stopping_parameters.max_new_tokens = 5
-    batch_pb = generate_pb2.Batch(id=0, requests=[req_0, req_1], size=2)
+    batch_pb = generate_pb2.Batch(id=1, requests=[req_0, req_1], size=2)
     return CausalLMBatch.from_pb(batch_pb, gpt2_tokenizer, torch.device("cpu"))
@@ -67,12 +68,12 @@ def test_batch_from_pb(default_pb_batch, default_causal_lm_batch):
     assert batch.past_key_values is None
-    assert torch.equal(batch.input_ids, batch.all_input_ids[:, :, 0])
+    assert all([torch.equal(input_ids, all_input_ids[:, 0]) for input_ids, all_input_ids in zip(batch.input_ids, batch.all_input_ids)])
     assert batch.input_lengths == [1]
-    assert batch.size == default_pb_batch.size
-    assert len(batch.next_token_choosers) == len(batch.stopping_criterias) == batch.size
+    assert len(batch) == default_pb_batch.size
+    assert len(batch.next_token_choosers) == len(batch.stopping_criterias) == len(batch)
     assert batch.max_input_length == batch.input_lengths[0]
@@ -93,7 +94,7 @@ def test_causal_lm_generate_token(default_causal_lm, default_causal_lm_batch):
     assert len(generations) == len(next_batch)
     assert isinstance(next_batch, CausalLMBatch)
-    assert len(next_batch.all_input_ids) == next_batch.size
+    assert len(next_batch.all_input_ids) == len(next_batch)
     assert len(next_batch.all_input_ids[0]) == sequence_length + 1
     assert len(next_batch.attention_mask[0]) == 11
     assert next_batch.all_input_ids[0][-1] == 13
@@ -103,7 +104,7 @@ def test_causal_lm_generate_token(default_causal_lm, default_causal_lm_batch):
     assert torch.all(next_batch.attention_mask[0][0:2] == 1)
     assert torch.all(next_batch.attention_mask[0][2:] == 0)
-    assert next_batch.input_ids.shape == (next_batch.size, 1)
+    assert next_batch.input_ids.shape == (len(next_batch), 1)
     assert next_batch.input_ids[0, 0] == 13
     assert next_batch.input_lengths == [2]
@@ -168,6 +169,8 @@ def test_causal_lm_generate_token_completion_multi(
         == default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens
     )
+    next_batch = next_batch.filter([next_batch.requests[0]])
     for _ in range(
         default_multi_requests_causal_lm_batch.stopping_criterias[0].max_new_tokens
         - default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens
@@ -266,6 +269,8 @@ def test_batch_concatenate(
         == default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens
     )
+    next_batch = next_batch.filter([next_batch.requests[0], next_batch.requests[1]])
     for _ in range(
         default_causal_lm_batch.stopping_criterias[0].max_new_tokens
         - default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens
@@ -285,6 +290,8 @@ def test_batch_concatenate(
         == default_causal_lm_batch.stopping_criterias[0].max_new_tokens
     )
+    next_batch = next_batch.filter([next_batch.requests[1]])
     for _ in range(
         default_multi_requests_causal_lm_batch.stopping_criterias[0].max_new_tokens
         - default_causal_lm_batch.stopping_criterias[0].max_new_tokens

View File

@@ -147,10 +147,13 @@ class CausalLMBatch(Batch):
         all_input_ids = []
         max_input_length = 0
+        next_token_choosers = []
+        stopping_criterias = []
         for i, r in enumerate(requests):
             idx = self.requests_idx_mapping[r.id]
-            keep_indices.append(idx)
+            requests_idx_mapping[r.id] = i
+            keep_indices.append(idx)
             offsets.append(self.offsets[idx])
             token_offsets.append(self.token_offsets[idx])
@@ -160,28 +163,37 @@ class CausalLMBatch(Batch):
             input_lengths.append(request_input_length)
             max_input_length = max(max_input_length, request_input_length)
-        # Replace metadata
-        self.requests_idx_mapping = requests_idx_mapping
-        self.input_lengths = input_lengths
-        self.offsets = offsets
-        self.token_offsets = token_offsets
-        self.all_input_ids = all_input_ids
-        self.max_input_length = max_input_length
+            next_token_choosers.append(self.next_token_choosers[idx])
+            stopping_criterias.append(self.stopping_criterias[idx])
         # Apply indices to input_ids, attention mask, past key values and other items that need to be cached
-        self.input_ids = self.input_ids[keep_indices]
-        self.attention_mask = self.attention_mask[keep_indices]
-        self.position_ids = self.position_ids[keep_indices]
+        input_ids = self.input_ids[keep_indices]
+        attention_mask = self.attention_mask[keep_indices]
+        position_ids = self.position_ids[keep_indices]
         # Force past to be of dim [self_size, num_heads, ...] for easy indexing
-        self.past_key_values = [
+        past_key_values = [
             [t.view(len(self), -1, *t.shape[-2:])[keep_indices] for t in layer]
             for layer in self.past_key_values
         ]
-        self.requests = requests
-        self.next_token_choosers = [self.next_token_choosers[i] for i in keep_indices]
-        self.stopping_criterias = [self.stopping_criterias[i] for i in keep_indices]
-        return self
+        return CausalLMBatch(
+            batch_id=self.batch_id,
+            requests=requests,
+            requests_idx_mapping=requests_idx_mapping,
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            all_input_ids=all_input_ids,
+            input_lengths=input_lengths,
+            offsets=offsets,
+            token_offsets=token_offsets,
+            next_token_choosers=next_token_choosers,
+            stopping_criterias=stopping_criterias,
+            max_input_length=max_input_length,
+            padding_right_offset=self.padding_right_offset,
+            keys_head_dim_last=self.keys_head_dim_last,
+        )
     @classmethod
     @tracer.start_as_current_span("concatenate")
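Note: the hunk above changes CausalLMBatch.filter from mutating self to building and returning a new batch, which is why the updated tests rebind the result (next_batch = next_batch.filter(...)). Below is a self-contained toy sketch of that pattern; ToyBatch and its fields are invented for illustration and are not part of the repository, and the real filter receives request protos and also gathers the cached tensors with keep_indices.

from dataclasses import dataclass
from typing import Dict, List

# Toy stand-in for the batch state (invented, not the real CausalLMBatch).
@dataclass
class ToyBatch:
    requests: List[int]                   # request ids
    requests_idx_mapping: Dict[int, int]  # request id -> row index in the batch
    input_lengths: List[int]

    def filter(self, request_ids: List[int]) -> "ToyBatch":
        # Roughly the shape of the new filter: collect the indices to keep,
        # rebuild the id -> position mapping, and return a *new* batch.
        keep_indices, requests_idx_mapping, input_lengths = [], {}, []
        for i, rid in enumerate(request_ids):
            idx = self.requests_idx_mapping[rid]
            requests_idx_mapping[rid] = i
            keep_indices.append(idx)
            input_lengths.append(self.input_lengths[idx])
        return ToyBatch(
            requests=list(request_ids),
            requests_idx_mapping=requests_idx_mapping,
            input_lengths=input_lengths,
        )

batch = ToyBatch(requests=[1, 2], requests_idx_mapping={1: 0, 2: 1}, input_lengths=[5, 3])
smaller = batch.filter([2])
assert smaller.requests_idx_mapping == {2: 0} and smaller.input_lengths == [3]
assert batch.input_lengths == [5, 3]  # the original batch is left untouched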
@@ -224,8 +236,9 @@ class CausalLMBatch(Batch):
             stopping_criterias.extend(batch.stopping_criterias)
             if i == 0:
-                requests_idx_mapping = requests_idx_mapping
+                requests_idx_mapping = batch.requests_idx_mapping
             else:
                 # We need to offset the mapping for each batch by the cumulative batch size
                 for k, v in batch.requests_idx_mapping.items():
                     requests_idx_mapping[k] = v + start_index
@@ -525,17 +538,20 @@ class CausalLM(Model):
             generations.append(generation)
             # Update values
-            batch.input_ids[i] = next_token_id
+            batch.input_ids[i, 0] = next_token_id
             batch.all_input_ids[i] = all_input_ids
             batch.input_lengths[i] = new_input_length
             batch.offsets[i] = offset
             batch.token_offsets[i] = token_offset
+            batch.max_input_length = max(batch.max_input_length, new_input_length)
-        # Decrease right offset
-        batch.padding_right_offset -= 1
+        # Slice unused values from prefill
+        batch.input_ids = batch.input_ids[:, :1]
         # Update attention_mask as we added a new token to input_ids
         batch.attention_mask[:, -batch.padding_right_offset] = 1
+        # Decrease right offset
+        batch.padding_right_offset -= 1
         # Update position_ids
         batch.position_ids = batch.position_ids[:, -1:] + 1
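Note on the last hunk: padding_right_offset is now decremented only after the attention-mask slot for the freshly generated token has been set, so -batch.padding_right_offset still points at that slot. A toy sketch with made-up shapes and values (the tensor below is not taken from the repo):

import torch

# One sequence with 3 prompt tokens and 3 right-padding slots reserved for new tokens.
attention_mask = torch.tensor([[1, 1, 1, 0, 0, 0]])
padding_right_offset = 3

# Order used by the updated code: unmask the slot of the token just appended at the
# *current* offset, then shrink the offset for the next step.
attention_mask[:, -padding_right_offset] = 1
padding_right_offset -= 1

assert attention_mask.tolist() == [[1, 1, 1, 1, 0, 0]]
assert padding_right_offset == 2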

View File

@@ -244,9 +244,12 @@ class FlashCausalLMBatch(Batch):
         for i, batch in enumerate(batches):
             requests.extend(batch.requests)
-            # We need to offset the mapping for each batch by the cumulative batch size
-            for k, v in batch.requests_idx_mapping.items():
-                requests_idx_mapping[k] = v + cumulative_batch_size
+            if i == 0:
+                requests_idx_mapping = batch.requests_idx_mapping
+            else:
+                # We need to offset the mapping for each batch by the cumulative batch size
+                for k, v in batch.requests_idx_mapping.items():
+                    requests_idx_mapping[k] = v + cumulative_batch_size
             input_ids.extend(batch.input_ids)
             position_ids.extend(batch.position_ids)
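The hunk above brings FlashCausalLMBatch.concatenate in line with the causal LM version: the first batch's requests_idx_mapping is reused as-is and every later batch's indices are shifted by the number of rows already merged. A small self-contained sketch of that merge, with invented request ids and plain dicts standing in for the real batch objects:

# Invented example data: two batches holding request ids 10, 11 and 12.
batches = [
    {"requests_idx_mapping": {10: 0, 11: 1}, "size": 2},
    {"requests_idx_mapping": {12: 0}, "size": 1},
]

requests_idx_mapping = {}
cumulative_batch_size = 0
for i, batch in enumerate(batches):
    if i == 0:
        requests_idx_mapping = batch["requests_idx_mapping"]
    else:
        # Offset the mapping of each later batch by the cumulative batch size.
        for k, v in batch["requests_idx_mapping"].items():
            requests_idx_mapping[k] = v + cumulative_batch_size
    cumulative_batch_size += batch["size"]

assert requests_idx_mapping == {10: 0, 11: 1, 12: 2}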