mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-10 03:44:54 +00:00
Update tests, fix bugs, format
This commit is contained in:
parent
7f81c25d07
commit
ba5b79bb5e
@ -5,7 +5,10 @@ from copy import copy
|
|||||||
from transformers import AutoTokenizer
|
from transformers import AutoTokenizer
|
||||||
|
|
||||||
from text_generation_server.pb import generate_pb2
|
from text_generation_server.pb import generate_pb2
|
||||||
from text_generation_server.models.vectorized_causal_lm import VectorizedCausalLM, VectorizedCausalLMBatch
|
from text_generation_server.models.vectorized_causal_lm import (
|
||||||
|
VectorizedCausalLM,
|
||||||
|
VectorizedCausalLMBatch,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
@pytest.fixture(scope="session")
|
||||||
@ -38,7 +41,9 @@ def default_pb_batch(default_pb_request):
|
|||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def default_causal_lm_batch(default_pb_batch, gpt2_tokenizer):
|
def default_causal_lm_batch(default_pb_batch, gpt2_tokenizer):
|
||||||
return VectorizedCausalLMBatch.from_pb(default_pb_batch, gpt2_tokenizer, torch.device("cpu"))
|
return VectorizedCausalLMBatch.from_pb(
|
||||||
|
default_pb_batch, gpt2_tokenizer, torch.device("cpu")
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
@ -50,7 +55,9 @@ def default_multi_requests_causal_lm_batch(default_pb_request, gpt2_tokenizer):
|
|||||||
req_1.stopping_parameters.max_new_tokens = 5
|
req_1.stopping_parameters.max_new_tokens = 5
|
||||||
|
|
||||||
batch_pb = generate_pb2.Batch(id=1, requests=[req_0, req_1], size=2)
|
batch_pb = generate_pb2.Batch(id=1, requests=[req_0, req_1], size=2)
|
||||||
return VectorizedCausalLMBatch.from_pb(batch_pb, gpt2_tokenizer, torch.device("cpu"))
|
return VectorizedCausalLMBatch.from_pb(
|
||||||
|
batch_pb, gpt2_tokenizer, torch.device("cpu")
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_batch_from_pb(default_pb_batch, default_causal_lm_batch):
|
def test_batch_from_pb(default_pb_batch, default_causal_lm_batch):
|
||||||
@ -59,33 +66,29 @@ def test_batch_from_pb(default_pb_batch, default_causal_lm_batch):
|
|||||||
assert batch.batch_id == default_pb_batch.id
|
assert batch.batch_id == default_pb_batch.id
|
||||||
assert batch.requests == default_pb_batch.requests
|
assert batch.requests == default_pb_batch.requests
|
||||||
|
|
||||||
assert len(batch.input_ids) == default_pb_batch.size
|
assert batch.input_ids.shape == (1, 11)
|
||||||
assert batch.input_ids[0][-1] == 14402
|
assert batch.input_ids[0, 0] == 14402
|
||||||
assert torch.all(batch.input_ids[0][:-1] == 50256)
|
|
||||||
|
|
||||||
|
assert batch.attention_mask.shape == (1, 11)
|
||||||
assert batch.attention_mask[0, 0] == 1
|
assert batch.attention_mask[0, 0] == 1
|
||||||
assert torch.all(batch.attention_mask[0, 1:] == 0)
|
assert batch.attention_mask.all()
|
||||||
|
|
||||||
|
assert batch.position_ids.shape == (1, 11)
|
||||||
assert batch.past_key_values is None
|
assert batch.past_key_values is None
|
||||||
|
|
||||||
assert all(
|
|
||||||
[
|
|
||||||
torch.equal(input_ids, all_input_ids[:, 0])
|
|
||||||
for input_ids, all_input_ids in zip(batch.input_ids, batch.all_input_ids)
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
assert batch.input_lengths == [1]
|
assert batch.input_lengths == [1]
|
||||||
|
|
||||||
assert len(batch) == default_pb_batch.size
|
assert len(batch) == 1
|
||||||
assert len(batch.next_token_choosers) == len(batch.stopping_criterias) == len(batch)
|
assert len(batch.stopping_criterias) == 1
|
||||||
|
|
||||||
assert batch.max_input_length == batch.input_lengths[0]
|
assert batch.max_input_length == 1
|
||||||
|
|
||||||
|
|
||||||
def test_batch_concatenate_no_prefill(default_causal_lm_batch):
|
def test_batch_concatenate_no_prefill(default_causal_lm_batch):
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
VectorizedCausalLMBatch.concatenate([default_causal_lm_batch, default_causal_lm_batch])
|
VectorizedCausalLMBatch.concatenate(
|
||||||
|
[default_causal_lm_batch, default_causal_lm_batch]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_causal_lm_batch_type(default_causal_lm):
|
def test_causal_lm_batch_type(default_causal_lm):
|
||||||
@ -93,39 +96,29 @@ def test_causal_lm_batch_type(default_causal_lm):
|
|||||||
|
|
||||||
|
|
||||||
def test_causal_lm_generate_token(default_causal_lm, default_causal_lm_batch):
|
def test_causal_lm_generate_token(default_causal_lm, default_causal_lm_batch):
|
||||||
sequence_length = len(default_causal_lm_batch.all_input_ids[0])
|
|
||||||
generations, next_batch = default_causal_lm.generate_token(default_causal_lm_batch)
|
generations, next_batch = default_causal_lm.generate_token(default_causal_lm_batch)
|
||||||
|
|
||||||
assert len(generations) == len(next_batch)
|
assert len(generations) == len(next_batch) == 1
|
||||||
assert isinstance(next_batch, VectorizedCausalLMBatch)
|
assert isinstance(next_batch, VectorizedCausalLMBatch)
|
||||||
|
|
||||||
assert len(next_batch.all_input_ids) == len(next_batch)
|
assert next_batch.input_ids.shape == (1, 11)
|
||||||
assert len(next_batch.all_input_ids[0]) == sequence_length + 1
|
assert next_batch.input_ids[0, 0] == 14402
|
||||||
assert len(next_batch.attention_mask[0]) == 11
|
assert next_batch.input_ids[0, 1] == 13
|
||||||
assert next_batch.all_input_ids[0][-1] == 13
|
assert next_batch.max_input_length == 2
|
||||||
assert next_batch.all_input_ids[0][-2] == 14402
|
assert next_batch.attention_mask.shape == (1, 11)
|
||||||
assert torch.all(next_batch.all_input_ids[0][:-2] == 50256)
|
|
||||||
|
|
||||||
assert torch.all(next_batch.attention_mask[0][0:2] == 1)
|
assert next_batch.attention_mask.all()
|
||||||
assert torch.all(next_batch.attention_mask[0][2:] == 0)
|
|
||||||
|
|
||||||
assert next_batch.input_ids.shape == (len(next_batch), 1)
|
|
||||||
assert next_batch.input_ids[0, 0] == 13
|
|
||||||
|
|
||||||
assert next_batch.input_lengths == [2]
|
assert next_batch.input_lengths == [2]
|
||||||
assert next_batch.max_input_length == next_batch.input_lengths[0]
|
|
||||||
|
|
||||||
assert next_batch.past_key_values is not None
|
assert next_batch.past_key_values is not None
|
||||||
assert all(
|
assert all([p[0].shape == (1, 12, 1, 64) for p in next_batch.past_key_values])
|
||||||
[p[0].shape == (1, 12, sequence_length, 64) for p in next_batch.past_key_values]
|
assert all([p[1].shape == (1, 12, 1, 64) for p in next_batch.past_key_values])
|
||||||
)
|
|
||||||
assert all(
|
assert generations[0].generated_text is None
|
||||||
[p[1].shape == (1, 12, sequence_length, 64) for p in next_batch.past_key_values]
|
assert len(generations[0].prefill_tokens) == 1
|
||||||
)
|
assert generations[0].token_id.item() == 13
|
||||||
assert all([generation.generated_text is None for generation in generations])
|
assert generations[0].token_text == "."
|
||||||
assert all([len(generation.prefill_tokens) == 1 for generation in generations])
|
|
||||||
assert all([generation.token_id.item() == 13 for generation in generations])
|
|
||||||
assert all([generation.token_text == "." for generation in generations])
|
|
||||||
assert generations[0].request_id == 0
|
assert generations[0].request_id == 0
|
||||||
|
|
||||||
|
|
||||||
@ -222,21 +215,53 @@ def test_batch_concatenate(
|
|||||||
|
|
||||||
next_batch = VectorizedCausalLMBatch.concatenate([next_batch_0, next_batch_1])
|
next_batch = VectorizedCausalLMBatch.concatenate([next_batch_0, next_batch_1])
|
||||||
|
|
||||||
assert torch.equal(next_batch.all_input_ids[0], next_batch_0.all_input_ids[0])
|
assert torch.equal(
|
||||||
assert torch.equal(next_batch.all_input_ids[1], next_batch_1.all_input_ids[0])
|
next_batch.input_ids[
|
||||||
assert torch.equal(next_batch.all_input_ids[2], next_batch_1.all_input_ids[1])
|
0,
|
||||||
|
next_batch.max_input_length
|
||||||
|
- next_batch.input_lengths[0] : next_batch.max_input_length,
|
||||||
|
],
|
||||||
|
next_batch_0.input_ids[
|
||||||
|
0,
|
||||||
|
next_batch_0.max_input_length
|
||||||
|
- next_batch_0.input_lengths[0] : next_batch_0.max_input_length,
|
||||||
|
],
|
||||||
|
)
|
||||||
|
assert torch.equal(
|
||||||
|
next_batch.input_ids[
|
||||||
|
1,
|
||||||
|
next_batch.max_input_length
|
||||||
|
- next_batch.input_lengths[1] : next_batch.max_input_length,
|
||||||
|
],
|
||||||
|
next_batch_1.input_ids[
|
||||||
|
0,
|
||||||
|
next_batch_1.max_input_length
|
||||||
|
- next_batch_1.input_lengths[0] : next_batch_1.max_input_length,
|
||||||
|
],
|
||||||
|
)
|
||||||
|
assert torch.equal(
|
||||||
|
next_batch.input_ids[
|
||||||
|
2,
|
||||||
|
next_batch.max_input_length
|
||||||
|
- next_batch.input_lengths[2] : next_batch.max_input_length,
|
||||||
|
],
|
||||||
|
next_batch_1.input_ids[
|
||||||
|
1,
|
||||||
|
next_batch_1.max_input_length
|
||||||
|
- next_batch_1.input_lengths[1] : next_batch_1.max_input_length,
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
assert torch.all(
|
assert next_batch.attention_mask[0].all()
|
||||||
next_batch.attention_mask[0, : -next_batch.padding_right_offset] == 1
|
assert next_batch.attention_mask[1:, 1:].all()
|
||||||
)
|
assert next_batch.attention_mask[1:, :1].logical_not().all()
|
||||||
assert torch.all(
|
|
||||||
next_batch.attention_mask[1:, 1 : -next_batch.padding_right_offset] == 1
|
|
||||||
)
|
|
||||||
assert torch.all(next_batch.attention_mask[1:, 3:] == 0)
|
|
||||||
|
|
||||||
assert next_batch.batch_id == 0
|
assert next_batch.batch_id == 0
|
||||||
assert next_batch.input_ids[0, 0] == 12355
|
assert next_batch.input_ids[:, next_batch.max_input_length - 1].tolist() == [
|
||||||
assert torch.all(next_batch.input_ids[1:] == 13)
|
12355,
|
||||||
|
13,
|
||||||
|
13,
|
||||||
|
]
|
||||||
|
|
||||||
assert next_batch.input_lengths == [3, 2, 2]
|
assert next_batch.input_lengths == [3, 2, 2]
|
||||||
assert next_batch.max_input_length == 3
|
assert next_batch.max_input_length == 3
|
||||||
@ -244,9 +269,6 @@ def test_batch_concatenate(
|
|||||||
assert next_batch.requests[0] == next_batch_0.requests[0]
|
assert next_batch.requests[0] == next_batch_0.requests[0]
|
||||||
assert next_batch.requests[1:] == next_batch_1.requests
|
assert next_batch.requests[1:] == next_batch_1.requests
|
||||||
|
|
||||||
assert next_batch.next_token_choosers[0] == next_batch_0.next_token_choosers[0]
|
|
||||||
assert next_batch.next_token_choosers[1:] == next_batch_1.next_token_choosers
|
|
||||||
|
|
||||||
assert next_batch.stopping_criterias[0] == next_batch_0.stopping_criterias[0]
|
assert next_batch.stopping_criterias[0] == next_batch_0.stopping_criterias[0]
|
||||||
assert next_batch.stopping_criterias[1:] == next_batch_1.stopping_criterias
|
assert next_batch.stopping_criterias[1:] == next_batch_1.stopping_criterias
|
||||||
|
|
||||||
|
@ -212,7 +212,7 @@ class VectorizedCausalLMBatch(Batch):
|
|||||||
]
|
]
|
||||||
for tensor in tensors_to_update:
|
for tensor in tensors_to_update:
|
||||||
# Update tensors in-place to allow incremental garbage collection
|
# Update tensors in-place to allow incremental garbage collection
|
||||||
tensors_to_update.data = tensor[kv_cache_slice]
|
tensor.data = tensor[kv_cache_slice]
|
||||||
|
|
||||||
return self
|
return self
|
||||||
|
|
||||||
@ -320,7 +320,12 @@ class VectorizedCausalLMBatch(Batch):
|
|||||||
past_key_values = []
|
past_key_values = []
|
||||||
for i, kv_format in enumerate(kv_formats):
|
for i, kv_format in enumerate(kv_formats):
|
||||||
for j in range(1 if kv_format is None else kv_format):
|
for j in range(1 if kv_format is None else kv_format):
|
||||||
tensors_to_merge = [batch.past_key_values[i] for batch in batches]
|
tensors_to_merge = [
|
||||||
|
batch.past_key_values[i]
|
||||||
|
if kv_format is None
|
||||||
|
else batch.past_key_values[i][j]
|
||||||
|
for batch in batches
|
||||||
|
]
|
||||||
# Generally `max_input_length`, unless the model allocates more than needed.
|
# Generally `max_input_length`, unless the model allocates more than needed.
|
||||||
right_indices = [
|
right_indices = [
|
||||||
left_index + tensor.size(kv_cache_seq_dim)
|
left_index + tensor.size(kv_cache_seq_dim)
|
||||||
@ -461,7 +466,8 @@ class VectorizedCausalLM(Model):
|
|||||||
.squeeze(1)
|
.squeeze(1)
|
||||||
.tolist()
|
.tolist()
|
||||||
)
|
)
|
||||||
if query_length > 1:
|
is_prefill = batch.past_key_values is None
|
||||||
|
if is_prefill:
|
||||||
prefill_token_ids = batch.input_ids[:, :key_length].tolist()
|
prefill_token_ids = batch.input_ids[:, :key_length].tolist()
|
||||||
prefill_logprobs = (
|
prefill_logprobs = (
|
||||||
logprobs.gather(2, batch.input_ids[:, 1:key_length, None])
|
logprobs.gather(2, batch.input_ids[:, 1:key_length, None])
|
||||||
@ -509,7 +515,11 @@ class VectorizedCausalLM(Model):
|
|||||||
# Decode generated tokens
|
# Decode generated tokens
|
||||||
# TODO: Same as stopping_criteria.current_output?
|
# TODO: Same as stopping_criteria.current_output?
|
||||||
output_text = self.decode(
|
output_text = self.decode(
|
||||||
batch.input_ids[i, -stopping_criterias.current_tokens :]
|
batch.input_ids[
|
||||||
|
i,
|
||||||
|
batch.max_input_length
|
||||||
|
- stopping_criterias.current_tokens : batch.max_input_length,
|
||||||
|
]
|
||||||
)
|
)
|
||||||
# TODO: Seed
|
# TODO: Seed
|
||||||
generated_text = GeneratedText(
|
generated_text = GeneratedText(
|
||||||
@ -522,7 +532,7 @@ class VectorizedCausalLM(Model):
|
|||||||
|
|
||||||
generation = Generation(
|
generation = Generation(
|
||||||
batch.requests[i].id,
|
batch.requests[i].id,
|
||||||
prefill_tokens[i] if batch.details and query_length > 1 else None,
|
prefill_tokens[i] if batch.details and is_prefill else None,
|
||||||
next_token_id,
|
next_token_id,
|
||||||
token_logprobs[i] if batch.details else 0.0,
|
token_logprobs[i] if batch.details else 0.0,
|
||||||
next_token_text,
|
next_token_text,
|
||||||
|
@ -214,6 +214,7 @@ class HeterogeneousTypicalLogitsWarper(LogitsWarper):
|
|||||||
scores = scores.masked_fill_(indices_to_remove, self.filter_value)
|
scores = scores.masked_fill_(indices_to_remove, self.filter_value)
|
||||||
return scores
|
return scores
|
||||||
|
|
||||||
|
|
||||||
class HeterogeneousProcessorWrapper(LogitsProcessor):
|
class HeterogeneousProcessorWrapper(LogitsProcessor):
|
||||||
r"""
|
r"""
|
||||||
A wrapper for logit warpers or processors without heterogeneous parameter support.
|
A wrapper for logit warpers or processors without heterogeneous parameter support.
|
||||||
@ -275,9 +276,15 @@ class HeterogeneousNextTokenChooser:
|
|||||||
|
|
||||||
watermark = self._standardize(watermark, batch_size, False)
|
watermark = self._standardize(watermark, batch_size, False)
|
||||||
if any(watermark):
|
if any(watermark):
|
||||||
warpers.append(HeterogeneousProcessorWrapper(
|
warpers.append(
|
||||||
{i:WatermarkLogitsProcessor(device=device) for i, x in watermark if x}
|
HeterogeneousProcessorWrapper(
|
||||||
))
|
{
|
||||||
|
i: WatermarkLogitsProcessor(device=device)
|
||||||
|
for i, x in watermark
|
||||||
|
if x
|
||||||
|
}
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
repetition_penalty = self._standardize(repetition_penalty, batch_size, 1.0)
|
repetition_penalty = self._standardize(repetition_penalty, batch_size, 1.0)
|
||||||
if any([x != 1.0 for x in repetition_penalty]):
|
if any([x != 1.0 for x in repetition_penalty]):
|
||||||
|
Loading…
Reference in New Issue
Block a user