feat(gaudi): add integration test

This commit is contained in:
baptiste 2025-03-28 08:58:28 +00:00
parent 3d059f91ab
commit 064c7bd03c
20 changed files with 511 additions and 3071 deletions

View File

@ -129,9 +129,9 @@ jobs:
export label_extension="-gaudi"
export docker_volume="/mnt/cache"
export docker_devices=""
export runs_on="ubuntu-latest"
export runs_on="aws-dl1-24xlarge"
export platform=""
export extra_pytest=""
export extra_pytest="--gaudi"
export target=""
esac
echo $dockerfile

View File

@ -1,6 +1,6 @@
mkfile_path := $(abspath $(lastword $(MAKEFILE_LIST)))
mkfile_dir := $(dir $(mkfile_path))
root_dir := "${mkfile_dir}/../.."
root_dir := ${mkfile_dir}/../..
HABANA_VERSION := 1.20.0
PYTORCH_VERSION := 2.6.0
@ -47,3 +47,10 @@ local-dev-install: install-dependencies
make install-server && \
make install-router && \
make install-launcher'
# In order to run the integration tests, you need to first build the image (make -C backends/gaudi image)
run-integration-tests:
uv pip install -r ${root_dir}/backends/gaudi/server/integration-tests/requirements.txt
DOCKER_VOLUME=${root_dir}/data \
HF_TOKEN=`cat ${HOME}/.cache/huggingface/token` \
uv run pytest --durations=0 -sv ${root_dir}/backends/gaudi/server/integration-tests

View File

@ -96,3 +96,42 @@ curl 127.0.0.1:8080/generate \
-d '{"inputs":"What is Deep Learning?","parameters":{"max_new_tokens":20}}' \
-H 'Content-Type: application/json'
```
### Integration tests
To run the integration tests, you need to first build the image:
```bash
make -C backends/gaudi image
```
Then run the following command to run the integration tests:
```bash
make -C backends/gaudi run-integration-tests
```
#### How the integration tests works
The integration tests works as follows:
1. Start a tgi server in a container, similar to the command:
```bash
docker run --runtime=habana --ipc=host --cap-add=sys_nice \
-p 8080:80 -v $volume:/data \
-e LOG_LEVEL=debug -e HF_TOKEN=$hf_token \
tgi-gaudi --model-id $model \
--max-input-tokens 512 --max-total-tokens 1024 --max-batch-size 4 --max-batch-prefill-tokens 2048
```
2. Do a /generate request to the server, similar to the command:
```bash
curl 127.0.0.1:8080/generate \
-X POST \
-d '{"inputs":"What is Deep Learning?","parameters":{"max_new_tokens":20}}' \
-H 'Content-Type: application/json'
```
3. Check the output of the server against the expected output:
```python
assert curl_output == expected_output
```
This is the repeated for a set of models and configurations.

View File

@ -0,0 +1,300 @@
import asyncio
import contextlib
import os
import shlex
import subprocess
import sys
import threading
import time
from tempfile import TemporaryDirectory
from typing import List
import docker
import pytest
from aiohttp import ClientConnectorError, ClientOSError, ServerDisconnectedError
from docker.errors import NotFound
from loguru import logger
from test_model import TEST_CONFIGS
from text_generation import AsyncClient
from text_generation.types import Response
# Use the latest image from the local docker build
DOCKER_IMAGE = os.getenv("DOCKER_IMAGE", "tgi-gaudi")
DOCKER_VOLUME = os.getenv("DOCKER_VOLUME", None)
HF_TOKEN = os.getenv("HF_TOKEN", None)
assert (
HF_TOKEN is not None
), "HF_TOKEN is not set, please set it as some models are gated and thus the test will fail without it"
if DOCKER_VOLUME is None:
logger.warning(
"DOCKER_VOLUME is not set, this will lead to the tests redownloading the models on each run, consider setting it to speed up testing"
)
LOG_LEVEL = os.getenv("LOG_LEVEL", "info")
BASE_ENV = {
"HF_HUB_ENABLE_HF_TRANSFER": "1",
"LOG_LEVEL": LOG_LEVEL,
"HF_TOKEN": os.getenv("HF_TOKEN", None),
}
HABANA_RUN_ARGS = {
"runtime": "habana",
"ipc_mode": "host",
"cap_add": ["sys_nice"],
}
logger.add(
sys.stderr,
format="<green>{time:YYYY-MM-DD HH:mm:ss}</green> | <level>{level: <8}</level> | <cyan>{name}</cyan>:<cyan>{function}</cyan>:<cyan>{line}</cyan> - <level>{message}</level>",
level="INFO",
)
# def cleanup_handler(signum, frame):
# logger.info("\nCleaning up containers due to shutdown, please wait...")
# try:
# client = docker.from_env()
# containers = client.containers.list(filters={"name": "tgi-tests-"})
# for container in containers:
# try:
# container.stop()
# container.remove()
# logger.info(f"Successfully cleaned up container {container.name}")
# except Exception as e:
# logger.error(f"Error cleaning up container {container.name}: {str(e)}")
# except Exception as e:
# logger.error(f"Error during cleanup: {str(e)}")
# sys.exit(1)
# signal.signal(signal.SIGINT, cleanup_handler)
# signal.signal(signal.SIGTERM, cleanup_handler)
def stream_container_logs(container):
"""Stream container logs in a separate thread."""
try:
for log in container.logs(stream=True, follow=True):
print(
"[TGI Server Logs] " + log.decode("utf-8"),
end="",
file=sys.stderr,
flush=True,
)
except Exception as e:
logger.error(f"Error streaming container logs: {str(e)}")
class LauncherHandle:
def __init__(self, port: int):
self.client = AsyncClient(f"http://localhost:{port}", timeout=3600)
def _inner_health(self):
raise NotImplementedError
async def health(self, timeout: int = 60):
assert timeout > 0
start_time = time.time()
logger.info(f"Starting health check with timeout of {timeout}s")
for attempt in range(timeout):
if not self._inner_health():
logger.error("Launcher crashed during health check")
raise RuntimeError("Launcher crashed")
try:
await self.client.generate("test")
elapsed = time.time() - start_time
logger.info(f"Health check passed after {elapsed:.1f}s")
return
except (ClientConnectorError, ClientOSError, ServerDisconnectedError) as e:
if attempt == timeout - 1:
logger.error(f"Health check failed after {timeout}s: {str(e)}")
raise RuntimeError(f"Health check failed: {str(e)}")
if attempt % 10 == 0 and attempt != 0: # Only log every 10th attempt
logger.debug(
f"Connection attempt {attempt}/{timeout} failed: {str(e)}"
)
time.sleep(1)
except Exception as e:
logger.error(f"Unexpected error during health check: {str(e)}")
# Get full traceback for debugging
import traceback
logger.error(f"Full traceback:\n{traceback.format_exc()}")
raise
class ContainerLauncherHandle(LauncherHandle):
def __init__(self, docker_client, container_name, port: int):
super(ContainerLauncherHandle, self).__init__(port)
self.docker_client = docker_client
self.container_name = container_name
def _inner_health(self) -> bool:
try:
container = self.docker_client.containers.get(self.container_name)
status = container.status
if status not in ["running", "created"]:
logger.warning(f"Container status is {status}")
# Get container logs for debugging
logs = container.logs().decode("utf-8")
logger.debug(f"Container logs:\n{logs}")
return status in ["running", "created"]
except Exception as e:
logger.error(f"Error checking container health: {str(e)}")
return False
class ProcessLauncherHandle(LauncherHandle):
def __init__(self, process, port: int):
super(ProcessLauncherHandle, self).__init__(port)
self.process = process
def _inner_health(self) -> bool:
return self.process.poll() is None
@pytest.fixture(scope="module")
def data_volume():
tmpdir = TemporaryDirectory()
yield tmpdir.name
try:
# Cleanup the temporary directory using sudo as it contains root files created by the container
subprocess.run(shlex.split(f"sudo rm -rf {tmpdir.name}"), check=True)
except subprocess.CalledProcessError as e:
logger.error(f"Error cleaning up temporary directory: {str(e)}")
@pytest.fixture(scope="module")
def launcher(data_volume):
@contextlib.contextmanager
def docker_launcher(
model_id: str,
test_name: str,
):
logger.info(
f"Starting docker launcher for model {model_id} and test {test_name}"
)
port = 8080
client = docker.from_env()
container_name = f"tgi-gaudi-test-{test_name}"
try:
container = client.containers.get(container_name)
logger.info(
f"Stopping existing container {container_name} for test {test_name}"
)
container.stop()
container.wait()
except NotFound:
pass
except Exception as e:
logger.error(f"Error handling existing container: {str(e)}")
model_name = next(
name for name, cfg in TEST_CONFIGS.items() if cfg["model_id"] == model_id
)
tgi_args = TEST_CONFIGS[model_name]["args"].copy()
env = BASE_ENV.copy()
# Add model_id to env
env["MODEL_ID"] = model_id
# Add env config that is definied in the fixture parameter
if "env_config" in TEST_CONFIGS[model_name]:
env.update(TEST_CONFIGS[model_name]["env_config"].copy())
volumes = [f"{DOCKER_VOLUME}:/data"]
logger.debug(f"Using volume {volumes}")
try:
logger.info(f"Creating container with name {container_name}")
# Log equivalent docker run command for debugging, this is not actually executed
container = client.containers.run(
DOCKER_IMAGE,
command=tgi_args,
name=container_name,
environment=env,
detach=True,
volumes=volumes,
ports={"80/tcp": 8080},
**HABANA_RUN_ARGS,
)
logger.info(f"Container {container_name} started successfully")
# Start log streaming in a background thread
log_thread = threading.Thread(
target=stream_container_logs,
args=(container,),
daemon=True, # This ensures the thread will be killed when the main program exits
)
log_thread.start()
# Add a small delay to allow container to initialize
time.sleep(2)
# Check container status after creation
status = container.status
logger.debug(f"Initial container status: {status}")
if status not in ["running", "created"]:
logs = container.logs().decode("utf-8")
logger.error(f"Container failed to start properly. Logs:\n{logs}")
yield ContainerLauncherHandle(client, container.name, port)
except Exception as e:
logger.error(f"Error starting container: {str(e)}")
# Get full traceback for debugging
import traceback
logger.error(f"Full traceback:\n{traceback.format_exc()}")
raise
finally:
try:
container = client.containers.get(container_name)
logger.info(f"Stopping container {container_name}")
container.stop()
container.wait()
container_output = container.logs().decode("utf-8")
print(container_output, file=sys.stderr)
container.remove()
logger.info(f"Container {container_name} removed successfully")
except NotFound:
pass
except Exception as e:
logger.warning(f"Error cleaning up container: {str(e)}")
return docker_launcher
@pytest.fixture(scope="module")
def generate_load():
async def generate_load_inner(
client: AsyncClient, prompt: str, max_new_tokens: int, n: int
) -> List[Response]:
try:
futures = [
client.generate(
prompt,
max_new_tokens=max_new_tokens,
decoder_input_details=True,
)
for _ in range(n)
]
return await asyncio.gather(*futures)
except Exception as e:
logger.error(f"Error generating load: {str(e)}")
raise
return generate_load_inner

View File

@ -0,0 +1,2 @@
[pytest]
asyncio_mode = auto

View File

@ -0,0 +1,20 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
pytest >= 8.3.5
pytest-asyncio >= 0.26.0
docker >= 7.1.0
Levenshtein >= 0.27.1
loguru >= 0.7.3
aiohttp >= 3.11.14
text-generation

View File

@ -0,0 +1,140 @@
from typing import Any, Dict
from text_generation import AsyncClient
import pytest
from Levenshtein import distance as levenshtein_distance
TEST_CONFIGS = {
"llama3-8b-shared-gaudi1": {
"model_id": "meta-llama/Llama-3.1-8B-Instruct",
"input": "What is Deep Learning?",
"expected_greedy_output": " A Beginners Guide\nDeep learning is a subset of machine learning that involves the use",
"expected_batch_output": " A Beginners Guide\nDeep learning is a subset of machine learning that involves the use",
"args": [
"--sharded",
"true",
"--num-shard",
"8",
"--max-input-tokens",
"512",
"--max-total-tokens",
"1024",
"--max-batch-size",
"8",
"--max-batch-prefill-tokens",
"2048",
],
},
"llama3-8b-gaudi1": {
"model_id": "meta-llama/Llama-3.1-8B-Instruct",
"input": "What is Deep Learning?",
"expected_greedy_output": " A Beginners Guide\nDeep learning is a subset of machine learning that involves the use of artificial neural networks to analyze and interpret data. It is a type of",
# "expected_sampling_output": " Part 2: Backpropagation\nIn my last post , I introduced the concept of deep learning and covered some high-level topics, including feedforward neural networks.",
"expected_batch_output": " A Beginners Guide\nDeep learning is a subset of machine learning that involves the use of artificial neural networks to analyze and interpret data. It is a type of",
"env_config": {},
"args": [
"--max-input-tokens",
"512",
"--max-total-tokens",
"1024",
"--max-batch-size",
"4",
"--max-batch-prefill-tokens",
"2048",
],
},
}
@pytest.fixture(scope="module", params=TEST_CONFIGS.keys())
def test_config(request) -> Dict[str, Any]:
"""Fixture that provides model configurations for testing."""
test_config = TEST_CONFIGS[request.param]
test_config["test_name"] = request.param
return test_config
@pytest.fixture(scope="module")
def model_id(test_config):
yield test_config["model_id"]
@pytest.fixture(scope="module")
def test_name(test_config):
yield test_config["test_name"]
@pytest.fixture(scope="module")
def expected_outputs(test_config):
return {
"greedy": test_config["expected_greedy_output"],
# "sampling": model_config["expected_sampling_output"],
"batch": test_config["expected_batch_output"],
}
@pytest.fixture(scope="module")
def input(test_config):
return test_config["input"]
@pytest.fixture(scope="module")
def tgi_service(launcher, model_id, test_name):
with launcher(model_id, test_name) as tgi_service:
yield tgi_service
@pytest.fixture(scope="module")
async def tgi_client(tgi_service) -> AsyncClient:
await tgi_service.health(1000)
return tgi_service.client
@pytest.mark.asyncio
async def test_model_single_request(
tgi_client: AsyncClient, expected_outputs: Dict[str, Any], input: str
):
# Bounded greedy decoding without input
response = await tgi_client.generate(
input,
max_new_tokens=32,
)
assert response.details.generated_tokens == 32
assert response.generated_text == expected_outputs["greedy"]
# TO FIX: Add sampling tests later, the gaudi backend seems to be flaky with the sampling even with a fixed seed
# Sampling
# response = await tgi_client.generate(
# "What is Deep Learning?",
# do_sample=True,
# top_k=50,
# top_p=0.9,
# repetition_penalty=1.2,
# max_new_tokens=100,
# seed=42,
# decoder_input_details=True,
# )
# assert expected_outputs["sampling"] in response.generated_text
@pytest.mark.asyncio
async def test_model_multiple_requests(
tgi_client, generate_load, expected_outputs, input
):
num_requests = 4
responses = await generate_load(
tgi_client,
input,
max_new_tokens=32,
n=num_requests,
)
assert len(responses) == 4
expected = expected_outputs["batch"]
for r in responses:
assert r.details.generated_tokens == 32
# Compute the similarity with the expectation using the levenshtein distance
# We should not have more than two substitutions or additions
assert levenshtein_distance(r.generated_text, expected) < 3

View File

@ -1,23 +0,0 @@
import pytest
import os
from text_generation_server.pb import generate_pb2
os.environ["USE_PREFIX_CACHING"] = "1"
os.environ["ATTENTION"] = "flashinfer"
@pytest.fixture
def default_pb_parameters():
return generate_pb2.NextTokenChooserParameters(
temperature=1.0,
repetition_penalty=1.0,
top_k=0,
top_p=1.0,
typical_p=1.0,
do_sample=False,
)
@pytest.fixture
def default_pb_stop_parameters():
return generate_pb2.StoppingCriteriaParameters(stop_sequences=[], max_new_tokens=10)

View File

@ -1,363 +0,0 @@
import pytest
import torch
from copy import copy
from transformers import AutoTokenizer
from text_generation_server.pb import generate_pb2
from text_generation_server.models.causal_lm import CausalLMBatch
from text_generation_server.utils import weight_hub_files, download_weights
from text_generation_server.models.bloom import BloomCausalLMBatch, BLOOMSharded
from text_generation_server.models.custom_modeling.bloom_modeling import (
BloomForCausalLM,
)
@pytest.fixture(scope="session")
def default_bloom():
model_id = "bigscience/bloom-560m"
revision = "main"
filenames = weight_hub_files(model_id, revision, ".safetensors")
download_weights(filenames, model_id, revision)
return BLOOMSharded(
model_id,
model_class=BloomForCausalLM,
)
@pytest.fixture(scope="session")
def bloom_560m_tokenizer():
return AutoTokenizer.from_pretrained("bigscience/bloom-560m", padding_side="left")
@pytest.fixture
def default_pb_request(default_pb_parameters, default_pb_stop_parameters):
return generate_pb2.Request(
id=0,
inputs="Test",
input_chunks=generate_pb2.Input(chunks=[generate_pb2.InputChunk(text="Test")]),
prefill_logprobs=True,
truncate=100,
parameters=default_pb_parameters,
stopping_parameters=default_pb_stop_parameters,
)
@pytest.fixture
def default_pb_batch(default_pb_request):
return generate_pb2.Batch(id=0, requests=[default_pb_request], size=1)
@pytest.fixture
def default_bloom_batch(default_pb_batch, bloom_560m_tokenizer):
return BloomCausalLMBatch.from_pb(
default_pb_batch, bloom_560m_tokenizer, torch.float32, torch.device("cpu")
)
@pytest.fixture
def default_multi_requests_bloom_batch(default_pb_request, bloom_560m_tokenizer):
req_0 = copy(default_pb_request)
req_0.id = 1
req_1 = default_pb_request
req_1.id = 2
req_1.stopping_parameters.max_new_tokens = 5
batch_pb = generate_pb2.Batch(id=0, requests=[req_0, req_1], size=2)
return BloomCausalLMBatch.from_pb(
batch_pb, bloom_560m_tokenizer, torch.float32, torch.device("cpu")
)
def test_batch_from_pb(default_pb_batch, default_bloom_batch):
batch = default_bloom_batch
assert batch.batch_id == default_pb_batch.id
assert batch.requests == default_pb_batch.requests
assert len(batch.input_ids) == default_pb_batch.size
assert batch.input_ids[0][-1] == 10264
assert torch.all(batch.input_ids[0][:-1] == 3)
assert batch.attention_mask[0][0] == 1
assert torch.all(batch.attention_mask[0][1:] == 0)
assert batch.past_key_values is None
assert all(
[
torch.equal(input_ids, all_input_ids[:, 0])
for input_ids, all_input_ids in zip(batch.input_ids, batch.all_input_ids)
]
)
assert batch.input_lengths == [1]
assert len(batch) == default_pb_batch.size
assert len(batch.next_token_choosers) == len(batch.stopping_criterias) == len(batch)
assert batch.max_input_length == batch.input_lengths[0]
def test_batch_concatenate_no_prefill(default_bloom_batch):
with pytest.raises(ValueError):
BloomCausalLMBatch.concatenate([default_bloom_batch, default_bloom_batch])
def test_causal_lm_batch_type(default_bloom):
assert default_bloom.batch_type == BloomCausalLMBatch
def test_causal_lm_generate_token(default_bloom, default_bloom_batch):
sequence_length = len(default_bloom_batch.all_input_ids[0])
generations, next_batch, _ = default_bloom.generate_token(default_bloom_batch)
assert len(generations) == len(default_bloom_batch)
assert isinstance(next_batch, CausalLMBatch)
assert not next_batch.keys_head_dim_last
assert len(next_batch.all_input_ids) == len(next_batch)
assert len(next_batch.all_input_ids[0]) == sequence_length + 1
assert len(next_batch.attention_mask[0]) == 11
assert torch.all(next_batch.all_input_ids[0][-2:] == 10264)
assert torch.all(next_batch.all_input_ids[0][:-2] == 3)
assert torch.all(next_batch.attention_mask[0][:2] == 1)
assert torch.all(next_batch.attention_mask[0][2:] == 0)
assert next_batch.input_ids.shape == (len(next_batch), 1)
assert next_batch.input_ids[0, 0] == 10264
assert next_batch.input_lengths == [2]
assert next_batch.max_input_length == next_batch.input_lengths[0]
assert next_batch.past_key_values is not None
assert all(
[p[0].shape == (16, 64, sequence_length) for p in next_batch.past_key_values]
)
assert all(
[p[1].shape == (16, sequence_length, 64) for p in next_batch.past_key_values]
)
assert all([generation.generated_text is None for generation in generations])
assert all([len(generation.prefill_tokens) == 1 for generation in generations])
assert all(
[
token_id.item() == 10264
for generation in generations
for token_id in generation.tokens.token_ids
]
)
assert all(
[
token_text == "Test"
for generation in generations
for token_text in generation.tokens.texts
]
)
assert generations[0].request_id == 0
def test_causal_lm_generate_token_completion(default_bloom, default_bloom_batch):
next_batch = default_bloom_batch
for _ in range(default_bloom_batch.stopping_criterias[0].max_new_tokens - 1):
generations, next_batch, _ = default_bloom.generate_token(next_batch)
assert len(generations) == len(default_bloom_batch)
generations, next_batch, _ = default_bloom.generate_token(next_batch)
assert next_batch is None
assert len(generations) == 1
assert (
generations[0].generated_text.text == "TestTestTestTestTestTestTestTestTestTest"
)
assert generations[0].request_id == default_bloom_batch.requests[0].id
assert (
generations[0].generated_text.generated_tokens
== default_bloom_batch.stopping_criterias[0].max_new_tokens
)
def test_causal_lm_generate_token_completion_multi(
default_bloom, default_multi_requests_bloom_batch
):
next_batch = default_multi_requests_bloom_batch
for i in range(
default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens - 1
):
generations, next_batch, _ = default_bloom.generate_token(next_batch)
assert len(generations) == len(default_multi_requests_bloom_batch)
generations, next_batch, _ = default_bloom.generate_token(next_batch)
assert next_batch is not None
assert len(generations) == 2
assert generations[1].generated_text.text == "TestTestTestTestTest"
assert (
generations[1].request_id == default_multi_requests_bloom_batch.requests[1].id
)
assert (
generations[1].generated_text.generated_tokens
== default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens
)
# Copy stopping_criterias before filtering
stopping_criterias = default_multi_requests_bloom_batch.stopping_criterias.copy()
next_batch = next_batch.filter([next_batch.requests[0].id])
for _ in range(
stopping_criterias[0].max_new_tokens - stopping_criterias[1].max_new_tokens - 1
):
generations, next_batch, _ = default_bloom.generate_token(next_batch)
assert len(generations) == len(next_batch)
generations, next_batch, _ = default_bloom.generate_token(next_batch)
assert next_batch is None
assert len(generations) == 1
assert (
generations[0].generated_text.text == "TestTestTestTestTestTestTestTestTestTest"
)
assert (
generations[0].request_id == default_multi_requests_bloom_batch.requests[0].id
)
assert (
generations[0].generated_text.generated_tokens
== default_multi_requests_bloom_batch.stopping_criterias[0].max_new_tokens
)
def test_batch_concatenate(
default_bloom, default_bloom_batch, default_multi_requests_bloom_batch
):
next_batch_0 = default_bloom_batch
_, next_batch_0, _ = default_bloom.generate_token(next_batch_0)
_, next_batch_0, _ = default_bloom.generate_token(next_batch_0)
next_batch_1 = default_multi_requests_bloom_batch
_, next_batch_1, _ = default_bloom.generate_token(next_batch_1)
# Clone past_key_values before concatenating to compare after,
# because they are removed from the concatenated batches
next_batch_0_past_key_values = [
(k.clone(), v.clone()) for (k, v) in next_batch_0.past_key_values
]
next_batch_1_past_key_values = [
(k.clone(), v.clone()) for (k, v) in next_batch_1.past_key_values
]
next_batch = BloomCausalLMBatch.concatenate([next_batch_0, next_batch_1])
assert torch.equal(next_batch.all_input_ids[0], next_batch_0.all_input_ids[0])
assert torch.equal(next_batch.all_input_ids[1], next_batch_1.all_input_ids[0])
assert torch.equal(next_batch.all_input_ids[2], next_batch_1.all_input_ids[1])
assert torch.all(
next_batch.attention_mask[0, : -next_batch.padding_right_offset] == 1
)
assert torch.all(
next_batch.attention_mask[1:, 1 : -next_batch.padding_right_offset] == 1
)
assert torch.all(next_batch.attention_mask[1:, 3:] == 0)
assert next_batch.batch_id == 0
assert torch.all(next_batch.input_ids == 10264)
assert next_batch.input_lengths == [3, 2, 2]
assert next_batch.max_input_length == 3
assert next_batch.requests[0] == next_batch_0.requests[0]
assert next_batch.requests[1:] == list(next_batch_1.requests)
assert next_batch.next_token_choosers[0] == next_batch_0.next_token_choosers[0]
assert next_batch.next_token_choosers[1:] == next_batch_1.next_token_choosers
assert next_batch.stopping_criterias[0] == next_batch_0.stopping_criterias[0]
assert next_batch.stopping_criterias[1:] == next_batch_1.stopping_criterias
assert next_batch.past_key_values is not None
assert all([p[0].shape == (3, 16, 64, 2) for p in next_batch.past_key_values])
assert all([p[1].shape == (3, 16, 2, 64) for p in next_batch.past_key_values])
for i, past in enumerate(next_batch.past_key_values):
assert torch.equal(next_batch_0_past_key_values[i][0][:, :, -2:], past[0][0])
assert torch.equal(
next_batch_1_past_key_values[i][0][:, :, -1:],
past[0][1:, :, :, -1].reshape(-1, 64, 1),
)
assert torch.equal(next_batch_0_past_key_values[i][1][:, -2:, :], past[1][0])
assert torch.equal(
next_batch_1_past_key_values[i][1][:, -1:, :],
past[1][1:, :, -1, :].reshape(-1, 1, 64),
)
for _ in range(
default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens - 2
):
generations, next_batch, _ = default_bloom.generate_token(next_batch)
assert len(generations) == len(next_batch)
generations, next_batch, _ = default_bloom.generate_token(next_batch)
assert next_batch is not None
assert len(generations) == 3
assert generations[2].generated_text.text == "TestTestTestTestTest"
assert (
generations[2].request_id == default_multi_requests_bloom_batch.requests[1].id
)
assert (
generations[2].generated_text.generated_tokens
== default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens
)
next_batch = next_batch.filter(
[next_batch.requests[0].id, next_batch.requests[1].id]
)
for _ in range(
default_bloom_batch.stopping_criterias[0].max_new_tokens
- default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens
- 2
):
generations, next_batch, _ = default_bloom.generate_token(next_batch)
assert len(generations) == len(next_batch)
generations, next_batch, _ = default_bloom.generate_token(next_batch)
assert next_batch is not None
assert len(generations) == 2
assert (
generations[0].generated_text.text == "TestTestTestTestTestTestTestTestTestTest"
)
assert generations[0].request_id == default_bloom_batch.requests[0].id
assert (
generations[0].generated_text.generated_tokens
== default_bloom_batch.stopping_criterias[0].max_new_tokens
)
next_batch = next_batch.filter([next_batch.requests[1].id])
for _ in range(
default_multi_requests_bloom_batch.stopping_criterias[0].max_new_tokens
- default_bloom_batch.stopping_criterias[0].max_new_tokens
- default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens
- 4
):
generations, next_batch, _ = default_bloom.generate_token(next_batch)
assert len(generations) == len(next_batch)
generations, next_batch, _ = default_bloom.generate_token(next_batch)
assert next_batch is None
assert len(generations) == 1
assert (
generations[0].generated_text.text == "TestTestTestTestTestTestTestTestTestTest"
)
assert (
generations[0].request_id == default_multi_requests_bloom_batch.requests[0].id
)
assert (
generations[0].generated_text.generated_tokens
== default_multi_requests_bloom_batch.stopping_criterias[0].max_new_tokens
)

View File

@ -1,354 +0,0 @@
import pytest
import torch
from copy import copy
from transformers import AutoTokenizer
from text_generation_server.pb import generate_pb2
from text_generation_server.models.causal_lm import CausalLM, CausalLMBatch
@pytest.fixture(scope="session")
def default_causal_lm():
return CausalLM.fallback("gpt2")
@pytest.fixture(scope="session")
def gpt2_tokenizer():
tokenizer = AutoTokenizer.from_pretrained("gpt2", padding_side="left")
tokenizer.pad_token_id = 50256
return tokenizer
@pytest.fixture
def default_pb_request(default_pb_parameters, default_pb_stop_parameters):
return generate_pb2.Request(
id=0,
inputs="Test",
input_chunks=generate_pb2.Input(chunks=[generate_pb2.InputChunk(text="Test")]),
prefill_logprobs=True,
truncate=100,
parameters=default_pb_parameters,
stopping_parameters=default_pb_stop_parameters,
)
@pytest.fixture
def default_pb_batch(default_pb_request):
return generate_pb2.Batch(id=0, requests=[default_pb_request], size=1)
@pytest.fixture
def default_causal_lm_batch(default_pb_batch, gpt2_tokenizer):
return CausalLMBatch.from_pb(
default_pb_batch, gpt2_tokenizer, torch.float32, torch.device("cpu")
)
@pytest.fixture
def default_multi_requests_causal_lm_batch(default_pb_request, gpt2_tokenizer):
req_0 = copy(default_pb_request)
req_0.id = 1
req_1 = default_pb_request
req_1.id = 2
req_1.stopping_parameters.max_new_tokens = 5
batch_pb = generate_pb2.Batch(id=1, requests=[req_0, req_1], size=2)
return CausalLMBatch.from_pb(
batch_pb, gpt2_tokenizer, torch.float32, torch.device("cpu")
)
def test_batch_from_pb(default_pb_batch, default_causal_lm_batch):
batch = default_causal_lm_batch
assert batch.batch_id == default_pb_batch.id
assert batch.requests == default_pb_batch.requests
assert len(batch.input_ids) == default_pb_batch.size
assert batch.input_ids[0][-1] == 14402
assert torch.all(batch.input_ids[0][:-1] == 50256)
assert batch.attention_mask[0, 0] == 1
assert torch.all(batch.attention_mask[0, 1:] == 0)
assert batch.past_key_values is None
assert all(
[
torch.equal(input_ids, all_input_ids[:, 0])
for input_ids, all_input_ids in zip(batch.input_ids, batch.all_input_ids)
]
)
assert batch.input_lengths == [1]
assert len(batch) == default_pb_batch.size
assert len(batch.next_token_choosers) == len(batch.stopping_criterias) == len(batch)
assert batch.max_input_length == batch.input_lengths[0]
def test_batch_concatenate_no_prefill(default_causal_lm_batch):
with pytest.raises(ValueError):
CausalLMBatch.concatenate([default_causal_lm_batch, default_causal_lm_batch])
def test_causal_lm_batch_type(default_causal_lm):
assert default_causal_lm.batch_type == CausalLMBatch
def test_causal_lm_generate_token(default_causal_lm, default_causal_lm_batch):
sequence_length = len(default_causal_lm_batch.all_input_ids[0])
generations, next_batch, _ = default_causal_lm.generate_token(
default_causal_lm_batch
)
assert len(generations) == len(next_batch)
assert isinstance(next_batch, CausalLMBatch)
assert len(next_batch.all_input_ids) == len(next_batch)
assert len(next_batch.all_input_ids[0]) == sequence_length + 1
assert len(next_batch.attention_mask[0]) == 11
assert next_batch.all_input_ids[0][-1] == 13
assert next_batch.all_input_ids[0][-2] == 14402
assert torch.all(next_batch.all_input_ids[0][:-2] == 50256)
assert torch.all(next_batch.attention_mask[0][0:2] == 1)
assert torch.all(next_batch.attention_mask[0][2:] == 0)
assert next_batch.input_ids.shape == (len(next_batch), 1)
assert next_batch.input_ids[0, 0] == 13
assert next_batch.input_lengths == [2]
assert next_batch.max_input_length == next_batch.input_lengths[0]
assert next_batch.past_key_values is not None
assert all(
[p[0].shape == (1, 12, sequence_length, 64) for p in next_batch.past_key_values]
)
assert all(
[p[1].shape == (1, 12, sequence_length, 64) for p in next_batch.past_key_values]
)
assert all([generation.generated_text is None for generation in generations])
assert all([len(generation.prefill_tokens) == 1 for generation in generations])
assert all(
[
token_id.item() == 13
for generation in generations
for token_id in generation.tokens.token_ids
]
)
assert all(
[
token_text == "."
for generation in generations
for token_text in generation.tokens.texts
]
)
assert generations[0].request_id == 0
def test_causal_lm_generate_token_completion(
default_causal_lm, default_causal_lm_batch
):
next_batch = default_causal_lm_batch
for _ in range(default_causal_lm_batch.stopping_criterias[0].max_new_tokens - 1):
generations, next_batch, _ = default_causal_lm.generate_token(next_batch)
assert len(generations) == len(next_batch)
generations, next_batch, _ = default_causal_lm.generate_token(next_batch)
assert next_batch is None
assert len(generations) == 1
assert generations[0].generated_text.text == ".java:784) at net.minecraft."
assert generations[0].request_id == default_causal_lm_batch.requests[0].id
assert (
generations[0].generated_text.generated_tokens
== default_causal_lm_batch.stopping_criterias[0].max_new_tokens
)
def test_causal_lm_generate_token_completion_multi(
default_causal_lm, default_multi_requests_causal_lm_batch
):
next_batch = default_multi_requests_causal_lm_batch
for i in range(
default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens - 1
):
generations, next_batch, _ = default_causal_lm.generate_token(next_batch)
assert len(generations) == len(next_batch)
generations, next_batch, _ = default_causal_lm.generate_token(next_batch)
assert next_batch is not None
assert len(generations) == 2
assert generations[1].generated_text.text == ".java:784)"
assert (
generations[1].request_id
== default_multi_requests_causal_lm_batch.requests[1].id
)
assert (
generations[1].generated_text.generated_tokens
== default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens
)
# Copy stopping_criterias before filtering
stopping_criterias = (
default_multi_requests_causal_lm_batch.stopping_criterias.copy()
)
next_batch = next_batch.filter([next_batch.requests[0].id])
for _ in range(
stopping_criterias[0].max_new_tokens - stopping_criterias[1].max_new_tokens - 1
):
generations, next_batch, _ = default_causal_lm.generate_token(next_batch)
assert len(generations) == len(next_batch)
generations, next_batch, _ = default_causal_lm.generate_token(next_batch)
assert next_batch is None
assert len(generations) == 1
assert generations[0].generated_text.text == ".java:784) at net.minecraft."
assert (
generations[0].request_id
== default_multi_requests_causal_lm_batch.requests[0].id
)
assert (
generations[0].generated_text.generated_tokens
== default_multi_requests_causal_lm_batch.stopping_criterias[0].max_new_tokens
)
def test_batch_concatenate(
default_causal_lm, default_causal_lm_batch, default_multi_requests_causal_lm_batch
):
next_batch_0 = default_causal_lm_batch
_, next_batch_0, _ = default_causal_lm.generate_token(next_batch_0)
_, next_batch_0, _ = default_causal_lm.generate_token(next_batch_0)
next_batch_1 = default_multi_requests_causal_lm_batch
_, next_batch_1, _ = default_causal_lm.generate_token(next_batch_1)
# Clone past_key_values before concatenating to compare after,
# because they are removed from the concatenated batches
next_batch_0_past_key_values = [
(k.clone(), v.clone()) for (k, v) in next_batch_0.past_key_values
]
next_batch_1_past_key_values = [
(k.clone(), v.clone()) for (k, v) in next_batch_1.past_key_values
]
next_batch = CausalLMBatch.concatenate([next_batch_0, next_batch_1])
assert torch.equal(next_batch.all_input_ids[0], next_batch_0.all_input_ids[0])
assert torch.equal(next_batch.all_input_ids[1], next_batch_1.all_input_ids[0])
assert torch.equal(next_batch.all_input_ids[2], next_batch_1.all_input_ids[1])
assert torch.all(
next_batch.attention_mask[0, : -next_batch.padding_right_offset] == 1
)
assert torch.all(
next_batch.attention_mask[1:, 1 : -next_batch.padding_right_offset] == 1
)
assert torch.all(next_batch.attention_mask[1:, 3:] == 0)
assert next_batch.batch_id == 0
assert next_batch.input_ids[0, 0] == 12355
assert torch.all(next_batch.input_ids[1:] == 13)
assert next_batch.input_lengths == [3, 2, 2]
assert next_batch.max_input_length == 3
assert next_batch.requests[0] == next_batch_0.requests[0]
assert next_batch.requests[1:] == list(next_batch_1.requests)
assert next_batch.next_token_choosers[0] == next_batch_0.next_token_choosers[0]
assert next_batch.next_token_choosers[1:] == next_batch_1.next_token_choosers
assert next_batch.stopping_criterias[0] == next_batch_0.stopping_criterias[0]
assert next_batch.stopping_criterias[1:] == next_batch_1.stopping_criterias
assert next_batch.past_key_values is not None
assert all([p[0].shape == (3, 12, 2, 64) for p in next_batch.past_key_values])
assert all([p[1].shape == (3, 12, 2, 64) for p in next_batch.past_key_values])
for i, past in enumerate(next_batch.past_key_values):
assert torch.equal(next_batch_0_past_key_values[i][0][0, :, -2:], past[0][0])
assert torch.equal(
next_batch_1_past_key_values[i][0][:, :, -1:], past[0][1:, :, -1:, :]
)
assert torch.equal(next_batch_0_past_key_values[i][1][0, :, -2:], past[1][0])
assert torch.equal(
next_batch_1_past_key_values[i][1][:, :, -1:], past[1][1:, :, -1:, :]
)
for _ in range(
default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens - 2
):
generations, next_batch, _ = default_causal_lm.generate_token(next_batch)
assert len(generations) == len(next_batch)
generations, next_batch, _ = default_causal_lm.generate_token(next_batch)
assert next_batch is not None
assert len(generations) == 3
assert generations[2].generated_text.text == ".java:784)"
assert (
generations[2].request_id
== default_multi_requests_causal_lm_batch.requests[1].id
)
assert (
generations[2].generated_text.generated_tokens
== default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens
)
next_batch = next_batch.filter(
[next_batch.requests[0].id, next_batch.requests[1].id]
)
for _ in range(
default_causal_lm_batch.stopping_criterias[0].max_new_tokens
- default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens
- 2
):
generations, next_batch, _ = default_causal_lm.generate_token(next_batch)
assert len(generations) == len(next_batch)
generations, next_batch, _ = default_causal_lm.generate_token(next_batch)
assert next_batch is not None
assert len(generations) == 2
assert generations[0].generated_text.text == ".java:784) at net.minecraft."
assert generations[0].request_id == default_causal_lm_batch.requests[0].id
assert (
generations[0].generated_text.generated_tokens
== default_causal_lm_batch.stopping_criterias[0].max_new_tokens
)
next_batch = next_batch.filter([next_batch.requests[1].id])
for _ in range(
default_multi_requests_causal_lm_batch.stopping_criterias[0].max_new_tokens
- default_causal_lm_batch.stopping_criterias[0].max_new_tokens
- default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens
- 4
):
generations, next_batch, _ = default_causal_lm.generate_token(next_batch)
assert len(generations) == len(next_batch)
generations, next_batch, _ = default_causal_lm.generate_token(next_batch)
assert next_batch is None
assert len(generations) == 1
assert generations[0].generated_text.text == ".java:784) at net.minecraft."
assert (
generations[0].request_id
== default_multi_requests_causal_lm_batch.requests[0].id
)
assert (
generations[0].generated_text.generated_tokens
== default_multi_requests_causal_lm_batch.stopping_criterias[0].max_new_tokens
)

View File

@ -1,83 +0,0 @@
import pytest
import torch
from transformers import AutoTokenizer
from text_generation_server.models import Model
def get_test_model():
class TestModel(Model):
def batch_type(self):
raise NotImplementedError
def generate_token(self, batch):
raise NotImplementedError
tokenizer = AutoTokenizer.from_pretrained("huggingface/llama-7b")
model = TestModel(
"test_model_id",
torch.nn.Linear(1, 1),
tokenizer,
False,
torch.float32,
torch.device("cpu"),
)
return model
@pytest.mark.private
def test_decode_streaming_english_spaces():
model = get_test_model()
truth = "Hello here, this is a simple test"
all_input_ids = [15043, 1244, 29892, 445, 338, 263, 2560, 1243]
assert (
all_input_ids == model.tokenizer(truth, add_special_tokens=False)["input_ids"]
)
decoded_text = ""
offset = 0
token_offset = 0
for i in range(len(all_input_ids)):
text, offset, token_offset = model.decode_token(
all_input_ids[: i + 1], offset, token_offset
)
decoded_text += text
assert decoded_text == truth
@pytest.mark.private
def test_decode_streaming_chinese_utf8():
model = get_test_model()
truth = "我很感谢你的热情"
all_input_ids = [
30672,
232,
193,
139,
233,
135,
162,
235,
179,
165,
30919,
30210,
234,
134,
176,
30993,
]
decoded_text = ""
offset = 0
token_offset = 0
for i in range(len(all_input_ids)):
text, offset, token_offset = model.decode_token(
all_input_ids[: i + 1], offset, token_offset
)
decoded_text += text
assert decoded_text == truth

View File

@ -1,108 +0,0 @@
import pytest
from text_generation_server.pb import generate_pb2
from text_generation_server.models.causal_lm import CausalLMBatch, CausalLM
@pytest.fixture(scope="session")
def default_santacoder():
return CausalLM.fallback(model_id="bigcode/santacoder")
@pytest.fixture
def default_pb_request(default_pb_parameters, default_pb_stop_parameters):
return generate_pb2.Request(
id=0,
inputs="def",
input_chunks=generate_pb2.Input(chunks=[generate_pb2.InputChunk(text="def")]),
prefill_logprobs=True,
truncate=100,
parameters=default_pb_parameters,
stopping_parameters=default_pb_stop_parameters,
)
@pytest.fixture
def default_pb_batch(default_pb_request):
return generate_pb2.Batch(id=0, requests=[default_pb_request], size=1)
@pytest.fixture
def default_fim_pb_request(default_pb_parameters, default_pb_stop_parameters):
return generate_pb2.Request(
id=0,
inputs="<fim-prefix>def<fim-suffix>world<fim-middle>",
input_chunks=generate_pb2.Input(
chunks=[
generate_pb2.InputChunk(
text="<fim-prefix>def<fim-suffix>world<fim-middle>"
)
]
),
prefill_logprobs=True,
truncate=100,
parameters=default_pb_parameters,
stopping_parameters=default_pb_stop_parameters,
)
@pytest.fixture
def default_fim_pb_batch(default_fim_pb_request):
return generate_pb2.Batch(id=0, requests=[default_fim_pb_request], size=1)
@pytest.mark.skip
def test_santacoder_generate_token_completion(default_santacoder, default_pb_batch):
batch = CausalLMBatch.from_pb(
default_pb_batch,
default_santacoder.tokenizer,
default_santacoder.dtype,
default_santacoder.device,
)
next_batch = batch
for _ in range(batch.stopping_criterias[0].max_new_tokens - 1):
generations, next_batch, _ = default_santacoder.generate_token(next_batch)
assert len(generations) == len(next_batch)
generations, next_batch, _ = default_santacoder.generate_token(next_batch)
assert next_batch is None
assert len(generations) == 1
assert generations[0].generated_text.text == " test_get_all_users_with_"
assert generations[0].request_id == batch.requests[0].id
assert (
generations[0].generated_text.generated_tokens
== batch.stopping_criterias[0].max_new_tokens
)
@pytest.mark.skip
def test_fim_santacoder_generate_token_completion(
default_santacoder, default_fim_pb_batch
):
batch = CausalLMBatch.from_pb(
default_fim_pb_batch,
default_santacoder.tokenizer,
default_santacoder.dtype,
default_santacoder.device,
)
next_batch = batch
for _ in range(batch.stopping_criterias[0].max_new_tokens - 1):
generations, next_batch, _ = default_santacoder.generate_token(next_batch)
assert len(generations) == len(next_batch)
generations, next_batch, _ = default_santacoder.generate_token(next_batch)
assert next_batch is None
assert len(generations) == 1
assert (
generations[0].generated_text.text
== """ineProperty(exports, "__esModule", { value"""
)
assert generations[0].request_id == batch.requests[0].id
assert (
generations[0].generated_text.generated_tokens
== batch.stopping_criterias[0].max_new_tokens
)

View File

@ -1,365 +0,0 @@
import pytest
import torch
from copy import copy
from transformers import AutoTokenizer
from text_generation_server.pb import generate_pb2
from text_generation_server.models.seq2seq_lm import Seq2SeqLM, Seq2SeqLMBatch
@pytest.fixture(scope="session")
def mt0_small_tokenizer():
tokenizer = AutoTokenizer.from_pretrained(
"bigscience/mt0-small", padding_side="left"
)
tokenizer.bos_token_id = 0
return tokenizer
@pytest.fixture(scope="session")
def default_seq2seq_lm():
return Seq2SeqLM.fallback("bigscience/mt0-small")
@pytest.fixture
def default_pb_request(default_pb_parameters, default_pb_stop_parameters):
return generate_pb2.Request(
id=0,
inputs="Test",
input_chunks=generate_pb2.Input(chunks=[generate_pb2.InputChunk(text="Test")]),
prefill_logprobs=True,
truncate=100,
parameters=default_pb_parameters,
stopping_parameters=default_pb_stop_parameters,
)
@pytest.fixture
def default_pb_batch(default_pb_request):
return generate_pb2.Batch(id=0, requests=[default_pb_request], size=1)
@pytest.fixture
def default_seq2seq_lm_batch(default_pb_batch, mt0_small_tokenizer):
return Seq2SeqLMBatch.from_pb(
default_pb_batch, mt0_small_tokenizer, torch.float32, torch.device("cpu")
)
@pytest.fixture
def default_multi_requests_seq2seq_lm_batch(default_pb_request, mt0_small_tokenizer):
req_0 = copy(default_pb_request)
req_0.id = 1
req_1 = default_pb_request
req_1.id = 2
req_1.stopping_parameters.max_new_tokens = 5
batch_pb = generate_pb2.Batch(id=0, requests=[req_0, req_1], size=2)
return Seq2SeqLMBatch.from_pb(
batch_pb, mt0_small_tokenizer, torch.float32, torch.device("cpu")
)
def test_batch_from_pb(default_pb_batch, default_seq2seq_lm_batch):
batch = default_seq2seq_lm_batch
sequence_length = len(default_seq2seq_lm_batch.input_ids[0])
assert batch.batch_id == default_pb_batch.id
assert batch.requests == default_pb_batch.requests
assert batch.input_ids.shape == (default_pb_batch.size, sequence_length)
assert batch.input_ids[0][-2] == 4268
assert batch.input_ids[0][-1] == 1
assert torch.all(batch.input_ids[0][:-2] == 0)
assert torch.all(batch.attention_mask[0][-2:] == 1)
assert torch.all(batch.attention_mask[0][:-2] == 0)
assert len(batch.decoder_input_ids) == default_pb_batch.size
assert batch.decoder_attention_mask is None
assert batch.encoder_last_hidden_state is None
assert batch.past_key_values is None
assert batch.input_lengths == [2]
assert batch.decoder_input_lengths == [1]
assert len(batch) == default_pb_batch.size
assert len(batch.next_token_choosers) == len(batch.stopping_criterias) == len(batch)
assert batch.max_input_length == batch.input_lengths[0]
assert batch.max_decoder_input_length == batch.decoder_input_lengths[0]
def test_batch_concatenate_no_prefill(default_seq2seq_lm_batch):
with pytest.raises(ValueError):
Seq2SeqLMBatch.concatenate([default_seq2seq_lm_batch, default_seq2seq_lm_batch])
def test_seq2seq_lm_batch_type(default_seq2seq_lm):
assert default_seq2seq_lm.batch_type == Seq2SeqLMBatch
def test_seq2seq_lm_generate_token(default_seq2seq_lm, default_seq2seq_lm_batch):
sequence_length = len(default_seq2seq_lm_batch.input_ids[0])
generations, next_batch, _ = default_seq2seq_lm.generate_token(
default_seq2seq_lm_batch
)
assert len(generations) == len(next_batch)
assert isinstance(next_batch, Seq2SeqLMBatch)
assert next_batch.input_ids is None
assert torch.equal(
next_batch.attention_mask, default_seq2seq_lm_batch.attention_mask
)
assert next_batch.input_lengths == default_seq2seq_lm_batch.input_lengths
assert next_batch.max_input_length == default_seq2seq_lm_batch.max_input_length
assert (
next_batch.next_token_choosers == default_seq2seq_lm_batch.next_token_choosers
)
assert next_batch.stopping_criterias == default_seq2seq_lm_batch.stopping_criterias
assert len(next_batch.decoder_input_ids) == len(next_batch)
assert next_batch.all_decoder_input_ids[0][0] == 0
assert next_batch.all_decoder_input_ids[0][1] == 259
assert next_batch.decoder_attention_mask is None
assert next_batch.encoder_last_hidden_state.shape == (1, sequence_length, 512)
assert next_batch.decoder_input_lengths == [2]
assert next_batch.max_decoder_input_length == 2
assert next_batch.past_key_values is not None
assert all(
[p[0].shape == (len(next_batch), 6, 1, 64) for p in next_batch.past_key_values]
)
assert all(
[p[1].shape == (len(next_batch), 6, 1, 64) for p in next_batch.past_key_values]
)
assert all(
[
p[2].shape == (len(next_batch), 6, sequence_length, 64)
for p in next_batch.past_key_values
]
)
assert all(
[
p[3].shape == (len(next_batch), 6, sequence_length, 64)
for p in next_batch.past_key_values
]
)
assert all([generation.generated_text is None for generation in generations])
assert all([len(generation.prefill_tokens) == 1 for generation in generations])
assert all(
[
token_id.item() == 259
for generation in generations
for token_id in generation.tokens.token_ids
]
)
assert all(
[
token_text == " "
for generation in generations
for token_text in generation.tokens.texts
]
)
assert generations[0].request_id == 0
def test_seq2seq_lm_generate_token_completion(
default_seq2seq_lm, default_seq2seq_lm_batch
):
next_batch = default_seq2seq_lm_batch
for _ in range(6):
generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch)
assert len(generations) == len(next_batch)
generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch)
assert next_batch is None
assert len(generations) == 1
assert generations[0].generated_text.text == "a few weeks"
assert generations[0].request_id == default_seq2seq_lm_batch.requests[0].id
assert generations[0].generated_text.generated_tokens == 7
def test_seq2seq_lm_generate_token_completion_multi(
default_seq2seq_lm, default_multi_requests_seq2seq_lm_batch
):
next_batch = default_multi_requests_seq2seq_lm_batch
for i in range(4):
generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch)
assert len(generations) == len(next_batch)
generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch)
assert next_batch is not None
assert len(generations) == 2
assert generations[1].generated_text.text == "a few "
assert (
generations[1].request_id
== default_multi_requests_seq2seq_lm_batch.requests[1].id
)
assert generations[1].generated_text.generated_tokens == 5
next_batch = next_batch.filter([next_batch.requests[0].id])
generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch)
assert len(generations) == len(next_batch)
generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch)
assert next_batch is None
assert len(generations) == 1
assert generations[0].generated_text.text == "a few weeks"
assert (
generations[0].request_id
== default_multi_requests_seq2seq_lm_batch.requests[0].id
)
assert generations[0].generated_text.generated_tokens == 7
def test_batch_concatenate(
default_seq2seq_lm,
default_seq2seq_lm_batch,
default_multi_requests_seq2seq_lm_batch,
):
next_batch_0 = default_seq2seq_lm_batch
_, next_batch_0, _ = default_seq2seq_lm.generate_token(next_batch_0)
_, next_batch_0, _ = default_seq2seq_lm.generate_token(next_batch_0)
next_batch_1 = default_multi_requests_seq2seq_lm_batch
_, next_batch_1, _ = default_seq2seq_lm.generate_token(next_batch_1)
# Copy hidden state because it is removed from the concatenated branches
next_batch_0_encoder_last_hidden_state = next_batch_0.encoder_last_hidden_state
next_batch_1_encoder_last_hidden_state = next_batch_1.encoder_last_hidden_state
# Clone past_key_values before concatenating to compare after,
# because they are removed from the concatenated batches
next_batch_0_past_key_values = [
[t.clone() for t in layer] for layer in next_batch_0.past_key_values
]
next_batch_1_past_key_values = [
[t.clone() for t in layer] for layer in next_batch_1.past_key_values
]
next_batch = Seq2SeqLMBatch.concatenate([next_batch_0, next_batch_1])
assert next_batch.batch_id == 0
assert torch.equal(
next_batch.decoder_input_ids[0], next_batch_0.decoder_input_ids[0]
)
assert next_batch.all_decoder_input_ids[1][0] == 0
assert next_batch.all_decoder_input_ids[2][0] == 0
assert torch.equal(
next_batch.decoder_input_ids[1:, -2:], next_batch_1.decoder_input_ids
)
assert torch.all(next_batch.decoder_attention_mask[0, :3] == 1)
assert torch.all(next_batch.decoder_attention_mask[0, 3:] == 0)
assert torch.all(next_batch.decoder_attention_mask[1:, 0] == 0)
assert torch.all(next_batch.decoder_attention_mask[1:, 1:3] == 1)
assert torch.equal(
next_batch.encoder_last_hidden_state[0],
next_batch_0_encoder_last_hidden_state[0, -2:],
)
assert torch.equal(
next_batch.encoder_last_hidden_state[1:],
next_batch_1_encoder_last_hidden_state[:, -2:],
)
assert next_batch.input_lengths == [2, 2, 2]
assert next_batch.decoder_input_lengths == [3, 2, 2]
assert next_batch.max_input_length == 2
assert next_batch.max_decoder_input_length == 3
assert next_batch.requests[0] == next_batch_0.requests[0]
assert next_batch.requests[1:] == list(next_batch_1.requests)
assert next_batch.next_token_choosers[0] == next_batch_0.next_token_choosers[0]
assert next_batch.next_token_choosers[1:] == next_batch_1.next_token_choosers
assert next_batch.stopping_criterias[0] == next_batch_0.stopping_criterias[0]
assert next_batch.stopping_criterias[1:] == next_batch_1.stopping_criterias
assert next_batch.past_key_values is not None
assert all(
[p[0].shape == (len(next_batch), 6, 2, 64) for p in next_batch.past_key_values]
)
assert all(
[p[1].shape == (len(next_batch), 6, 2, 64) for p in next_batch.past_key_values]
)
assert all(
[p[2].shape == (len(next_batch), 6, 2, 64) for p in next_batch.past_key_values]
)
assert all(
[p[3].shape == (len(next_batch), 6, 2, 64) for p in next_batch.past_key_values]
)
for i, past in enumerate(next_batch.past_key_values):
assert torch.equal(next_batch_0_past_key_values[i][0][0, :, -2:, :], past[0][0])
assert torch.equal(
next_batch_1_past_key_values[i][0][:, :, -1:, :], past[0][1:, :, -1:, :]
)
assert torch.equal(next_batch_0_past_key_values[i][1][0, :, -2:, :], past[1][0])
assert torch.equal(
next_batch_1_past_key_values[i][1][:, :, -1:, :], past[1][1:, :, -1:, :]
)
assert torch.equal(next_batch_0_past_key_values[i][2][0, :, -2:, :], past[2][0])
assert torch.equal(
next_batch_1_past_key_values[i][2][:, :, -2:, :], past[2][1:]
)
assert torch.equal(next_batch_0_past_key_values[i][3][0, :, -2:, :], past[3][0])
assert torch.equal(
next_batch_1_past_key_values[i][3][:, :, -2:, :], past[3][1:]
)
for _ in range(3):
generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch)
assert len(generations) == len(next_batch)
generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch)
assert next_batch is not None
assert len(generations) == 3
assert generations[2].generated_text.text == "a few "
assert (
generations[2].request_id
== default_multi_requests_seq2seq_lm_batch.requests[1].id
)
assert generations[2].generated_text.generated_tokens == 5
next_batch = next_batch.filter(
[next_batch.requests[0].id, next_batch.requests[1].id]
)
generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch)
assert next_batch is not None
assert len(generations) == 2
assert generations[0].generated_text.text == "a few weeks"
assert generations[0].request_id == default_seq2seq_lm_batch.requests[0].id
assert generations[0].generated_text.generated_tokens == 7
next_batch = next_batch.filter([next_batch.requests[1].id])
generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch)
assert next_batch is None
assert len(generations) == 1
assert generations[0].generated_text.text == "a few weeks"
assert (
generations[0].request_id
== default_multi_requests_seq2seq_lm_batch.requests[0].id
)
assert generations[0].generated_text.generated_tokens == 7

View File

@ -1,247 +0,0 @@
import pytest
from unittest.mock import Mock
from text_generation_server.utils.adapter import (
get_attn_weights,
get_mlp_weights,
parse_lora_adapters,
AdapterInfo,
)
def test_parse_lora_adapters_empty():
assert parse_lora_adapters(None) == []
assert parse_lora_adapters("") == []
def test_parse_lora_adapters_single():
result = parse_lora_adapters("adapter1")
assert result == [AdapterInfo(id="adapter1", path=None, revision=None)]
def test_parse_lora_adapters_with_path():
result = parse_lora_adapters("adapter1=path/to/adapter1")
assert result == [
AdapterInfo(id="adapter1", path="path/to/adapter1", revision=None)
]
def test_parse_lora_adapters_with_path_and_revision():
result = parse_lora_adapters("adapter1=path/to/adapter1@main")
assert result == [
AdapterInfo(id="adapter1", path="path/to/adapter1", revision="main")
]
def test_parse_lora_adapters_multiple():
result = parse_lora_adapters(
"adapter1,adapter2=path/to/adapter2,adapter3=path/to/adapter3@dev"
)
assert result == [
AdapterInfo(id="adapter1", path=None, revision=None),
AdapterInfo(id="adapter2", path="path/to/adapter2", revision=None),
AdapterInfo(id="adapter3", path="path/to/adapter3", revision="dev"),
]
def test_parse_lora_adapters_invalid_format():
try:
parse_lora_adapters("adapter1,invalid=format=test,adapter3")
assert False, "Should have raised ValueError"
except ValueError as e:
assert str(e) == "Invalid LoRA adapter format: invalid=format=test"
def test_get_attn_weights():
# create a mock layer
mock_layer = Mock()
mock_layer.self_attn.query_key_value = Mock()
mock_layer.self_attn.o_proj = Mock()
# call the function
result = get_attn_weights(2, mock_layer)
# assert the result
expected = {
(2, "q_proj"): (
"model.layers.2.self_attn.q_proj",
mock_layer.self_attn.query_key_value,
),
(2, "k_proj"): (
"model.layers.2.self_attn.k_proj",
mock_layer.self_attn.query_key_value,
),
(2, "qkv_proj"): (
"model.layers.2.self_attn.qkv_proj",
mock_layer.self_attn.query_key_value,
),
(2, "v_proj"): (
"model.layers.2.self_attn.v_proj",
mock_layer.self_attn.query_key_value,
),
(2, "o_proj"): ("model.layers.2.self_attn.o_proj", mock_layer.self_attn.o_proj),
}
assert result == expected
def test_get_mlp_weights_with_gate_up_proj():
# create a mock layer with gate_up_proj
mock_layer = Mock()
mock_layer.mlp.gate_up_proj = Mock()
mock_layer.mlp.down_proj = Mock()
# call the function
result = get_mlp_weights(3, mock_layer)
# assert the result
expected = {
(3, "gate_proj"): ("model.layers.3.mlp.gate_proj", mock_layer.mlp.gate_up_proj),
(3, "up_proj"): ("model.layers.3.mlp.up_proj", mock_layer.mlp.gate_up_proj),
(3, "down_proj"): ("model.layers.3.mlp.down_proj", mock_layer.mlp.down_proj),
}
assert result == expected
def test_get_mlp_weights_without_gate_up_proj():
# create a mock layer without gate_up_proj
mock_layer = Mock()
mock_layer.mlp = Mock(spec=[])
# call the function
result = get_mlp_weights(1, mock_layer)
# assert the result
assert result == {}
@pytest.mark.parametrize("layer_index", [0, 1, 5])
def test_get_attn_weights_different_layers(layer_index):
mock_layer = Mock()
mock_layer.self_attn.query_key_value = Mock()
mock_layer.self_attn.o_proj = Mock()
result = get_attn_weights(layer_index, mock_layer)
for k in ["q", "k", "v"]:
assert (layer_index, f"{k}_proj") in result
assert (
result[(layer_index, f"{k}_proj")][0]
== f"model.layers.{layer_index}.self_attn.{k}_proj"
)
assert (layer_index, "o_proj") in result
assert (
result[(layer_index, "o_proj")][0]
== f"model.layers.{layer_index}.self_attn.o_proj"
)
@pytest.mark.parametrize("layer_index", [0, 1, 5])
def test_get_mlp_weights_different_layers(layer_index):
mock_layer = Mock()
mock_layer.mlp.gate_up_proj = Mock()
mock_layer.mlp.down_proj = Mock()
result = get_mlp_weights(layer_index, mock_layer)
for k in ["gate", "up", "down"]:
assert (layer_index, f"{k}_proj") in result
assert (
result[(layer_index, f"{k}_proj")][0]
== f"model.layers.{layer_index}.mlp.{k}_proj"
)
def test_get_attn_weights_llama_compatibility():
mock_layer = Mock()
mock_layer.self_attn.query_key_value = Mock()
mock_layer.self_attn.o_proj = Mock()
result = get_attn_weights(2, mock_layer)
expected = {
(2, "q_proj"): (
"model.layers.2.self_attn.q_proj",
mock_layer.self_attn.query_key_value,
),
(2, "k_proj"): (
"model.layers.2.self_attn.k_proj",
mock_layer.self_attn.query_key_value,
),
(2, "qkv_proj"): (
"model.layers.2.self_attn.qkv_proj",
mock_layer.self_attn.query_key_value,
),
(2, "v_proj"): (
"model.layers.2.self_attn.v_proj",
mock_layer.self_attn.query_key_value,
),
(2, "o_proj"): ("model.layers.2.self_attn.o_proj", mock_layer.self_attn.o_proj),
}
assert result == expected
def test_get_mlp_weights_llama_compatibility():
mock_layer = Mock()
mock_layer.mlp.gate_up_proj = Mock()
mock_layer.mlp.down_proj = Mock()
result = get_mlp_weights(3, mock_layer)
expected = {
(3, "gate_proj"): ("model.layers.3.mlp.gate_proj", mock_layer.mlp.gate_up_proj),
(3, "up_proj"): ("model.layers.3.mlp.up_proj", mock_layer.mlp.gate_up_proj),
(3, "down_proj"): ("model.layers.3.mlp.down_proj", mock_layer.mlp.down_proj),
}
assert result == expected
def test_get_attn_weights_gemma_compatibility():
mock_layer = Mock()
mock_layer.self_attn.query_key_value = Mock()
mock_layer.self_attn.o_proj = Mock()
result = get_attn_weights(2, mock_layer)
expected = {
(2, "q_proj"): (
"model.layers.2.self_attn.q_proj",
mock_layer.self_attn.query_key_value,
),
(2, "k_proj"): (
"model.layers.2.self_attn.k_proj",
mock_layer.self_attn.query_key_value,
),
(2, "qkv_proj"): (
"model.layers.2.self_attn.qkv_proj",
mock_layer.self_attn.query_key_value,
),
(2, "v_proj"): (
"model.layers.2.self_attn.v_proj",
mock_layer.self_attn.query_key_value,
),
(2, "o_proj"): ("model.layers.2.self_attn.o_proj", mock_layer.self_attn.o_proj),
}
assert result == expected
def test_get_mlp_weights_gemma_compatibility():
mock_layer = Mock()
mock_layer.mlp.gate_proj = Mock()
mock_layer.mlp.up_proj = Mock()
mock_layer.mlp.down_proj = Mock()
# ensure that the mock_layer.mlp.gate_up_proj attribute does not exist.
# This is necessary because the use of `Mock` automatically creates any
# attributes that are accessed, even if they don't exist in the actual
# implementation. If `gate_up_proj` were created, `get_mlp_weights` might
# follow the wrong execution path and return an incorrect result.
del mock_layer.mlp.gate_up_proj
result = get_mlp_weights(3, mock_layer)
expected = {
(3, "gate_proj"): ("model.layers.3.mlp.gate_proj", mock_layer.mlp.gate_proj),
(3, "up_proj"): ("model.layers.3.mlp.up_proj", mock_layer.mlp.up_proj),
(3, "down_proj"): ("model.layers.3.mlp.down_proj", mock_layer.mlp.down_proj),
}
assert result == expected

View File

@ -1,21 +0,0 @@
from text_generation_server.utils.hub import (
download_weights,
weight_hub_files,
weight_files,
)
from text_generation_server.utils.convert import convert_files
def test_convert_files():
model_id = "bigscience/bloom-560m"
pt_filenames = weight_hub_files(model_id, extension=".bin")
local_pt_files = download_weights(pt_filenames, model_id)
local_st_files = [
p.parent / f"{p.stem.lstrip('pytorch_')}.safetensors" for p in local_pt_files
]
convert_files(local_pt_files, local_st_files, discard_names=[])
found_st_files = weight_files(model_id)
assert all([p in found_st_files for p in local_st_files])

View File

@ -1,103 +0,0 @@
import os
import tempfile
import pytest
import huggingface_hub.constants
import text_generation_server.utils.hub
from text_generation_server.utils.hub import (
weight_hub_files,
download_weights,
weight_files,
EntryNotFoundError,
LocalEntryNotFoundError,
RevisionNotFoundError,
)
@pytest.fixture()
def offline():
current_value = text_generation_server.utils.hub.HF_HUB_OFFLINE
text_generation_server.utils.hub.HF_HUB_OFFLINE = True
yield "offline"
text_generation_server.utils.hub.HF_HUB_OFFLINE = current_value
@pytest.fixture()
def fresh_cache():
with tempfile.TemporaryDirectory() as d:
current_value = huggingface_hub.constants.HUGGINGFACE_HUB_CACHE
huggingface_hub.constants.HUGGINGFACE_HUB_CACHE = d
text_generation_server.utils.hub.HUGGINGFACE_HUB_CACHE = d
os.environ["HUGGINGFACE_HUB_CACHE"] = d
yield
huggingface_hub.constants.HUGGINGFACE_HUB_CACHE = current_value
os.environ["HUGGINGFACE_HUB_CACHE"] = current_value
text_generation_server.utils.hub.HUGGINGFACE_HUB_CACHE = current_value
@pytest.fixture()
def prefetched():
model_id = "bert-base-uncased"
huggingface_hub.snapshot_download(
repo_id=model_id,
revision="main",
local_files_only=False,
repo_type="model",
allow_patterns=["*.safetensors"],
)
yield model_id
def test_weight_hub_files_offline_error(offline, fresh_cache):
# If the model is not prefetched then it will raise an error
with pytest.raises(EntryNotFoundError):
weight_hub_files("gpt2")
def test_weight_hub_files_offline_ok(prefetched, offline):
# If the model is prefetched then we should be able to get the weight files from local cache
filenames = weight_hub_files(prefetched)
root = None
assert len(filenames) == 1
for f in filenames:
curroot, filename = os.path.split(f)
if root is None:
root = curroot
else:
assert root == curroot
assert filename == "model.safetensors"
def test_weight_hub_files():
filenames = weight_hub_files("bigscience/bloom-560m")
assert filenames == ["model.safetensors"]
def test_weight_hub_files_llm():
filenames = weight_hub_files("bigscience/bloom")
assert filenames == [f"model_{i:05d}-of-00072.safetensors" for i in range(1, 73)]
def test_weight_hub_files_empty():
with pytest.raises(EntryNotFoundError):
weight_hub_files("bigscience/bloom", extension=".errors")
def test_download_weights():
model_id = "bigscience/bloom-560m"
filenames = weight_hub_files(model_id)
files = download_weights(filenames, model_id)
local_files = weight_files("bigscience/bloom-560m")
assert files == local_files
def test_weight_files_revision_error():
with pytest.raises(RevisionNotFoundError):
weight_files("bigscience/bloom-560m", revision="error")
def test_weight_files_not_cached_error(fresh_cache):
with pytest.raises(LocalEntryNotFoundError):
weight_files("bert-base-uncased")

View File

@ -1,82 +0,0 @@
import torch
from text_generation_server.layers import (
TensorParallelEmbedding,
)
class ProcessGroup:
def __init__(self, rank: int, world_size: int):
self._rank = rank
self.world_size = world_size
def size(self) -> int:
return self.world_size
def rank(self) -> int:
return self._rank
class Weights:
def __init__(self, rank: int, world_size: int, vocab_size: int, hidden_dim: int):
self.weight = (
torch.arange(vocab_size * hidden_dim).float().view(vocab_size, hidden_dim)
)
self.process_group = ProcessGroup(rank, world_size)
def get_partial_sharded(self, name: str, dim: int):
assert dim == 0
rank = self.process_group.rank()
world_size = self.process_group.size()
size = self.weight.shape[dim]
block_size = (size + world_size - 1) // world_size
start = rank * block_size
stop = (rank + 1) * block_size
return self.weight[start:stop]
def get_shape(self, name: str):
return self.weight.shape
def test_weight_hub_files_offline_error():
vocab_size = 17
weights = Weights(
rank=0,
world_size=1,
vocab_size=vocab_size,
hidden_dim=256,
)
embeddings = TensorParallelEmbedding("", weights)
input_ids = torch.arange(vocab_size)
output = embeddings.forward(input_ids)
assert embeddings.min_id == 0
assert embeddings.max_id == 17
torch.testing.assert_close(output, torch.arange(256 * 17).float().view(17, 256))
weights_0_2 = Weights(rank=0, world_size=2, vocab_size=vocab_size, hidden_dim=256)
weights_1_2 = Weights(rank=1, world_size=2, vocab_size=vocab_size, hidden_dim=256)
embeddings_0_2 = TensorParallelEmbedding("", weights_0_2, reduce=False)
assert embeddings_0_2.min_id == 0
assert embeddings_0_2.max_id == 9
torch.testing.assert_close(
embeddings_0_2.weight,
torch.cat([torch.arange(9 * 256), torch.zeros(256)], dim=0)
.view(10, 256)
.float(),
)
embeddings_1_2 = TensorParallelEmbedding("", weights_1_2, reduce=False)
assert embeddings_1_2.min_id == 9
assert embeddings_1_2.max_id == 17
torch.testing.assert_close(
embeddings_1_2.weight,
torch.cat([torch.arange(8 * 256) + 9 * 256, torch.zeros(256)], dim=0)
.view(9, 256)
.float(),
)
output_tp_0 = embeddings_0_2.forward(input_ids)
output_tp_1 = embeddings_1_2.forward(input_ids)
torch.testing.assert_close(output, output_tp_0 + output_tp_1)

View File

@ -1,88 +0,0 @@
import torch
from text_generation_server.utils.tokens import (
StopSequenceCriteria,
StoppingCriteria,
FinishReason,
batch_top_tokens,
)
def test_stop_sequence_criteria():
criteria = StopSequenceCriteria("/test;")
assert not criteria("/")
assert not criteria("/test")
assert criteria("/test;")
assert not criteria("/test; ")
def test_stop_sequence_criteria_escape():
criteria = StopSequenceCriteria("<|stop|>")
assert not criteria("<")
assert not criteria("<|stop")
assert criteria("<|stop|>")
assert not criteria("<|stop|> ")
def test_stopping_criteria():
criteria = StoppingCriteria(0, [StopSequenceCriteria("/test;")], max_new_tokens=5)
assert criteria(65827, "/test") == (False, None)
assert criteria(30, ";") == (True, FinishReason.FINISH_REASON_STOP_SEQUENCE)
def test_stopping_criteria_eos():
criteria = StoppingCriteria(0, [StopSequenceCriteria("/test;")], max_new_tokens=5)
assert criteria(1, "") == (False, None)
assert criteria(0, "") == (True, FinishReason.FINISH_REASON_EOS_TOKEN)
def test_stopping_criteria_max():
criteria = StoppingCriteria(0, [StopSequenceCriteria("/test;")], max_new_tokens=5)
assert criteria(1, "") == (False, None)
assert criteria(1, "") == (False, None)
assert criteria(1, "") == (False, None)
assert criteria(1, "") == (False, None)
assert criteria(1, "") == (True, FinishReason.FINISH_REASON_LENGTH)
def test_batch_top_tokens():
top_n_tokens = [0, 2, 3, 4, 5]
top_n_tokens_tensor = torch.tensor(top_n_tokens)
inp_logprobs = torch.tensor([[-1.0, -3.0, -4.0, -2.0, -3.0]] * 5)
accepted_ids = torch.ones_like(top_n_tokens_tensor)
topn_tok_ids, topn_tok_logprobs = batch_top_tokens(
top_n_tokens, top_n_tokens_tensor, inp_logprobs, accepted_ids
)
assert topn_tok_ids[0] == [[]]
assert topn_tok_ids[1] == [[0, 3]]
assert topn_tok_ids[2] == [[0, 3, 1, 4]]
assert topn_tok_ids[3] == [[0, 3, 1, 4]]
assert topn_tok_ids[4] == [[0, 3, 1, 4, 2]]
assert topn_tok_logprobs[0] == [[]]
assert topn_tok_logprobs[1] == [[-1, -2]]
assert topn_tok_logprobs[2] == [[-1, -2, -3, -3]]
assert topn_tok_logprobs[3] == [[-1, -2, -3, -3]]
assert topn_tok_logprobs[4] == [[-1, -2, -3, -3, -4]]
# Now let's make second member of the batch be speculated
inp_logprobs = torch.tensor([[-1.0, -3.0, -4.0, -2.0, -3.0]] * 5 * 2)
accepted_ids[1] = 2
topn_tok_ids, topn_tok_logprobs = batch_top_tokens(
top_n_tokens, top_n_tokens_tensor, inp_logprobs, accepted_ids
)
assert topn_tok_ids[0] == [[]]
assert topn_tok_ids[1] == [[0, 3], [0, 3]]
assert topn_tok_ids[2] == [[0, 3, 1, 4]]
assert topn_tok_ids[3] == [[0, 3, 1, 4]]
assert topn_tok_ids[4] == [[0, 3, 1, 4, 2]]
assert topn_tok_logprobs[0] == [[]]
assert topn_tok_logprobs[1] == [[-1, -2], [-1, -2]]
assert topn_tok_logprobs[2] == [[-1, -2, -3, -3]]
assert topn_tok_logprobs[3] == [[-1, -2, -3, -3]]
assert topn_tok_logprobs[4] == [[-1, -2, -3, -3, -4]]

View File

@ -1,54 +0,0 @@
# test_watermark_logits_processor.py
import os
import numpy as np
import torch
from text_generation_server.utils.watermark import WatermarkLogitsProcessor
GAMMA = os.getenv("WATERMARK_GAMMA", 0.5)
DELTA = os.getenv("WATERMARK_DELTA", 2.0)
def test_seed_rng():
input_ids = [101, 2036, 3731, 102, 2003, 103]
processor = WatermarkLogitsProcessor()
processor._seed_rng(input_ids)
assert isinstance(processor.rng, torch.Generator)
def test_get_greenlist_ids():
input_ids = [101, 2036, 3731, 102, 2003, 103]
processor = WatermarkLogitsProcessor()
result = processor._get_greenlist_ids(input_ids, 10, torch.device("cpu"))
assert max(result) <= 10
assert len(result) == int(10 * 0.5)
def test_calc_greenlist_mask():
processor = WatermarkLogitsProcessor()
scores = torch.tensor([[0.5, 0.3, 0.2, 0.8], [0.1, 0.2, 0.7, 0.9]])
greenlist_token_ids = torch.tensor([2, 3])
result = processor._calc_greenlist_mask(scores, greenlist_token_ids)
assert result.tolist() == [[False, False, False, False], [False, False, True, True]]
assert result.shape == scores.shape
def test_bias_greenlist_logits():
processor = WatermarkLogitsProcessor()
scores = torch.tensor([[0.5, 0.3, 0.2, 0.8], [0.1, 0.2, 0.7, 0.9]])
green_tokens_mask = torch.tensor(
[[False, False, True, True], [False, False, False, True]]
)
greenlist_bias = 2.0
result = processor._bias_greenlist_logits(scores, green_tokens_mask, greenlist_bias)
assert np.allclose(result.tolist(), [[0.5, 0.3, 2.2, 2.8], [0.1, 0.2, 0.7, 2.9]])
assert result.shape == scores.shape
def test_call():
input_ids = [101, 2036, 3731, 102, 2003, 103]
processor = WatermarkLogitsProcessor()
scores = torch.tensor([[0.5, 0.3, 0.2, 0.8], [0.1, 0.2, 0.7, 0.9]])
result = processor(input_ids, scores)
assert result.shape == scores.shape

File diff suppressed because it is too large Load Diff