Skip server tests for models that are not enabled (#125)

Co-authored-by: Karol Damaszke <kdamaszke@habana.ai>
Authored by Karol Damaszke on 2024-04-09 14:15:41 +02:00, committed by GitHub
parent c6739526c6
commit 30cc78773e
3 changed files with 34 additions and 19 deletions

File: server/tests/models/test_bloom.py

@@ -1,3 +1,5 @@
+# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company.
 import pytest
 import torch
@@ -5,9 +7,9 @@ from copy import copy
 from transformers import AutoTokenizer
 from text_generation_server.pb import generate_pb2
-from text_generation_server.models.causal_lm import CausalLMBatch
+from text_generation_server.models.causal_lm import CausalLMBatch, PAD_SEQUENCE_TO_MULTIPLE_OF
 from text_generation_server.utils import weight_hub_files, download_weights
-from text_generation_server.models.bloom import BloomCausalLMBatch, BLOOMSharded
+from text_generation_server.models.bloom import BloomCausalLMBatch, BLOOM
 @pytest.fixture(scope="session")
@@ -16,7 +18,7 @@ def default_bloom():
     revision = "main"
     filenames = weight_hub_files(model_id, revision, ".safetensors")
     download_weights(filenames, model_id, revision)
-    return BLOOMSharded(model_id)
+    return BLOOM(model_id)
 @pytest.fixture(scope="session")
@@ -30,7 +32,7 @@ def default_pb_request(default_pb_parameters, default_pb_stop_parameters):
         id=0,
         inputs="Test",
         prefill_logprobs=True,
-        truncate=100,
+        truncate=PAD_SEQUENCE_TO_MULTIPLE_OF,
         parameters=default_pb_parameters,
         stopping_parameters=default_pb_stop_parameters,
     )
@@ -44,7 +46,7 @@ def default_pb_batch(default_pb_request):
 @pytest.fixture
 def default_bloom_batch(default_pb_batch, bloom_560m_tokenizer):
     return BloomCausalLMBatch.from_pb(
-        default_pb_batch, bloom_560m_tokenizer, torch.float32, torch.device("cpu")
+        default_pb_batch, bloom_560m_tokenizer, torch.float32, torch.device("hpu")
     )
@@ -58,7 +60,7 @@ def default_multi_requests_bloom_batch(default_pb_request, bloom_560m_tokenizer)
     batch_pb = generate_pb2.Batch(id=0, requests=[req_0, req_1], size=2)
     return BloomCausalLMBatch.from_pb(
-        batch_pb, bloom_560m_tokenizer, torch.float32, torch.device("cpu")
+        batch_pb, bloom_560m_tokenizer, torch.float32, torch.device("hpu")
     )
@@ -66,30 +68,29 @@ def test_batch_from_pb(default_pb_batch, default_bloom_batch):
     batch = default_bloom_batch
     assert batch.batch_id == default_pb_batch.id
-    assert batch.requests == default_pb_batch.requests
+    assert len(batch.requests) == len(default_pb_batch.requests) == default_pb_batch.size
+    for request, pb_request in zip(batch.requests, default_pb_batch.requests):
+        assert request.data == pb_request
     assert len(batch.input_ids) == default_pb_batch.size
-    assert batch.input_ids[0][-1] == 10264
-    assert torch.all(batch.input_ids[0][:-1] == 3)
+    assert batch.input_ids[0][-1] == 3
+    assert batch.input_ids[0][-2] == 10264
+    assert torch.all(batch.input_ids[0][:-2] == 3)
-    assert batch.attention_mask[0][0] == 1
-    assert torch.all(batch.attention_mask[0][1:] == 0)
+    assert batch.attention_mask[0][-1] == 0
+    assert batch.attention_mask[0][-2] == 1
+    assert torch.all(batch.attention_mask[0][:-2] == 0)
     assert batch.past_key_values is None
     assert all(
         [
-            torch.equal(input_ids, all_input_ids[:, 0])
-            for input_ids, all_input_ids in zip(batch.input_ids, batch.all_input_ids)
+            torch.equal(input_ids, request.all_input_ids[:batch.input_length+1, 0])
+            for input_ids, request in zip(batch.input_ids, batch.requests)
         ]
     )
-    assert batch.input_lengths == [1]
     assert len(batch) == default_pb_batch.size
     assert len(batch.next_token_choosers) == len(batch.stopping_criterias) == len(batch)
-    assert batch.max_input_length == batch.input_lengths[0]
+    assert batch.max_input_length == batch.input_length == PAD_SEQUENCE_TO_MULTIPLE_OF - 1
 def test_batch_concatenate_no_prefill(default_bloom_batch):
@@ -101,6 +102,7 @@ def test_causal_lm_batch_type(default_bloom):
     assert default_bloom.batch_type == BloomCausalLMBatch
+@pytest.mark.skip
 def test_causal_lm_generate_token(default_bloom, default_bloom_batch):
     sequence_length = len(default_bloom_batch.all_input_ids[0])
     generations, next_batch = default_bloom.generate_token(default_bloom_batch)
@@ -138,6 +140,7 @@ def test_causal_lm_generate_token(default_bloom, default_bloom_batch):
     assert generations[0].request_id == 0
+@pytest.mark.skip
 def test_causal_lm_generate_token_completion(default_bloom, default_bloom_batch):
     next_batch = default_bloom_batch
     for _ in range(default_bloom_batch.stopping_criterias[0].max_new_tokens - 1):
@@ -158,6 +161,7 @@ def test_causal_lm_generate_token_completion(default_bloom, default_bloom_batch)
     )
+@pytest.mark.skip
 def test_causal_lm_generate_token_completion_multi(
     default_bloom, default_multi_requests_bloom_batch
 ):
@@ -208,6 +212,7 @@ def test_causal_lm_generate_token_completion_multi(
     )
+@pytest.mark.skip
 def test_batch_concatenate(
     default_bloom, default_bloom_batch, default_multi_requests_bloom_batch
 ):

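The updated assertions in test_batch_from_pb above check that the single-token prompt is padded out to a PAD_SEQUENCE_TO_MULTIPLE_OF-sized buffer rather than kept at its natural length. A minimal sketch of the layout those assertions imply; the constant's value (128) and the exact buffer width are assumptions for illustration, not taken from the diff:

```python
# Sketch only: PAD_SEQUENCE_TO_MULTIPLE_OF comes from causal_lm.py; 128 is an assumed value.
PAD_SEQUENCE_TO_MULTIPLE_OF = 128
PAD_TOKEN_ID = 3        # BLOOM pad token id used in the assertions above
PROMPT_IDS = [10264]    # "Test" tokenizes to a single id

# Left-pad to PAD_SEQUENCE_TO_MULTIPLE_OF - 1 and keep one trailing slot free,
# which matches input_length == PAD_SEQUENCE_TO_MULTIPLE_OF - 1 in the test.
input_length = PAD_SEQUENCE_TO_MULTIPLE_OF - 1
left_pad = input_length - len(PROMPT_IDS)
input_ids = [PAD_TOKEN_ID] * left_pad + PROMPT_IDS + [PAD_TOKEN_ID]
attention_mask = [0] * left_pad + [1] * len(PROMPT_IDS) + [0]

assert input_ids[-1] == 3 and input_ids[-2] == 10264
assert attention_mask[-1] == 0 and attention_mask[-2] == 1
assert len(input_ids) == PAD_SEQUENCE_TO_MULTIPLE_OF  # assumed total buffer width
```
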
File: server/tests/models/test_seq2seq_lm.py

@@ -61,6 +61,7 @@ def default_multi_requests_seq2seq_lm_batch(default_pb_request, mt0_small_tokeni
     )
+@pytest.mark.skip("seq2seq model not enabled on HPU yet")
 def test_batch_from_pb(default_pb_batch, default_seq2seq_lm_batch):
     batch = default_seq2seq_lm_batch
     sequence_length = len(default_seq2seq_lm_batch.input_ids[0])
@@ -92,15 +93,18 @@ def test_batch_from_pb(default_pb_batch, default_seq2seq_lm_batch):
     assert batch.max_decoder_input_length == batch.decoder_input_lengths[0]
+@pytest.mark.skip("seq2seq model not enabled on HPU yet")
 def test_batch_concatenate_no_prefill(default_seq2seq_lm_batch):
     with pytest.raises(ValueError):
         Seq2SeqLMBatch.concatenate([default_seq2seq_lm_batch, default_seq2seq_lm_batch])
+@pytest.mark.skip("seq2seq model not enabled on HPU yet")
 def test_seq2seq_lm_batch_type(default_seq2seq_lm):
     assert default_seq2seq_lm.batch_type == Seq2SeqLMBatch
+@pytest.mark.skip("seq2seq model not enabled on HPU yet")
 def test_seq2seq_lm_generate_token(default_seq2seq_lm, default_seq2seq_lm_batch):
     sequence_length = len(default_seq2seq_lm_batch.input_ids[0])
     generations, next_batch = default_seq2seq_lm.generate_token(
@@ -156,6 +160,7 @@ def test_seq2seq_lm_generate_token(default_seq2seq_lm, default_seq2seq_lm_batch)
     assert generations[0].request_id == 0
+@pytest.mark.skip("seq2seq model not enabled on HPU yet")
 def test_seq2seq_lm_generate_token_completion(
     default_seq2seq_lm, default_seq2seq_lm_batch
 ):
@@ -173,6 +178,7 @@ def test_seq2seq_lm_generate_token_completion(
     assert generations[0].generated_text.generated_tokens == 7
+@pytest.mark.skip("seq2seq model not enabled on HPU yet")
 def test_seq2seq_lm_generate_token_completion_multi(
     default_seq2seq_lm, default_multi_requests_seq2seq_lm_batch
 ):
@@ -210,6 +216,7 @@ def test_seq2seq_lm_generate_token_completion_multi(
     assert generations[0].generated_text.generated_tokens == 7
+@pytest.mark.skip("seq2seq model not enabled on HPU yet")
 def test_batch_concatenate(
     default_seq2seq_lm,
     default_seq2seq_lm_batch,

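The seq2seq tests above are disabled with pytest's stock skip marker rather than any custom infrastructure. A minimal sketch of the pattern; the conditional variant and its environment check are an assumption for illustration, not part of this commit:

```python
import os
import pytest

# Unconditional skip with a reason, as applied to the seq2seq tests above.
@pytest.mark.skip("seq2seq model not enabled on HPU yet")
def test_seq2seq_placeholder():
    ...

# A conditional variant, if the tests are ever re-enabled per device; the
# ENABLE_SEQ2SEQ_TESTS flag is hypothetical.
@pytest.mark.skipif(os.environ.get("ENABLE_SEQ2SEQ_TESTS") != "1",
                    reason="seq2seq model not enabled on HPU yet")
def test_seq2seq_conditional_placeholder():
    ...
```
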
File: server/text_generation_server/models/causal_lm.py

@@ -340,6 +340,9 @@ class CausalLMBatch(Batch):
     @classmethod
     def recombine(cls, batches: List["CausalLMBatch"], pad_token_id: int) -> "CausalLMBatch":
+        if not all(b.past_key_values is not None for b in batches):
+            raise ValueError("KV cache not allocated! Cannot recombine before prefill!")
         total_requests = sum(len(b) for b in batches)
         new_bs = round_up(total_requests, BATCH_BUCKET_SIZE)
         batch_id = batches[0].batch_id
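
For context on the guarded code path above: after the new KV-cache check, recombine() rounds the combined request count up to a BATCH_BUCKET_SIZE bucket before merging. A minimal sketch of that bucketing; round_up and BATCH_BUCKET_SIZE live in causal_lm.py, and the value and implementation shown here are assumptions for illustration:

```python
# Sketch only: BATCH_BUCKET_SIZE and round_up are defined in causal_lm.py;
# the value and implementation below are assumed, not copied from the source.
BATCH_BUCKET_SIZE = 8

def round_up(value: int, multiple: int) -> int:
    # Round value up to the nearest multiple, e.g. round_up(10, 8) == 16.
    return (value + multiple - 1) // multiple * multiple

total_requests = 10
new_bs = round_up(total_requests, BATCH_BUCKET_SIZE)
assert new_bs == 16  # the merged batch is padded out to a full bucket
```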