diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml
index 79b3c777..9305b8e7 100644
--- a/.github/workflows/build.yaml
+++ b/.github/workflows/build.yaml
@@ -213,12 +213,13 @@ jobs:
sudo mount /dev/nvme1n1 ${{ env.DOCKER_VOLUME }}
- name: Install
run: |
+ pip install pytest-xdist
make install-integration-tests
- name: Run tests
run: |
export DOCKER_IMAGE=registry.internal.huggingface.tech/api-inference/community/text-generation-inference:sha-${{ env.GITHUB_SHA_SHORT }}
export HUGGING_FACE_HUB_TOKEN=${{ secrets.HUGGING_FACE_HUB_TOKEN }}
- pytest -s -vv integration-tests
+ pytest -s -vv -n 2 --dist loadfile integration-tests
stop-runner:
name: Stop self-hosted EC2 runner
diff --git a/integration-tests/conftest.py b/integration-tests/conftest.py
index ba1abca9..3086ecda 100644
--- a/integration-tests/conftest.py
+++ b/integration-tests/conftest.py
@@ -1,3 +1,4 @@
+import sys
import subprocess
import contextlib
import pytest
@@ -7,6 +8,7 @@ import docker
import json
import math
import time
+import random
from docker.errors import NotFound
from typing import Optional, List, Dict
@@ -205,10 +207,12 @@ def launcher(event_loop):
def local_launcher(
model_id: str, num_shard: Optional[int] = None, quantize: Optional[str] = None
):
- port = 9999
- master_port = 19999
+ port = random.randint(8000, 10_000)
+ master_port = random.randint(10_001, 20_000)
- shard_uds_path = f"/tmp/{model_id.replace('/', '--')}-server"
+ shard_uds_path = (
+ f"/tmp/tgi-tests-{model_id.split('/')[-1]}-{num_shard}-{quantize}-server"
+ )
args = [
"text-generation-launcher",
@@ -236,7 +240,7 @@ def launcher(event_loop):
process.wait(60)
launcher_output = process.stdout.read().decode("utf-8")
- print(launcher_output)
+ print(launcher_output, file=sys.stderr)
process.stdout.close()
process.stderr.close()
@@ -245,7 +249,7 @@ def launcher(event_loop):
def docker_launcher(
model_id: str, num_shard: Optional[int] = None, quantize: Optional[str] = None
):
- port = 9999
+ port = random.randint(8000, 10_000)
args = ["--model-id", model_id, "--env"]
@@ -298,7 +302,7 @@ def launcher(event_loop):
pass
container_output = container.logs().decode("utf-8")
- print(container_output)
+ print(container_output, file=sys.stderr)
container.remove()
diff --git a/integration-tests/models/__snapshots__/test_flash_llama/test_flash_llama_load.json b/integration-tests/models/__snapshots__/test_flash_llama/test_flash_llama_load.json
index 5a8ba217..9bbb5322 100644
--- a/integration-tests/models/__snapshots__/test_flash_llama/test_flash_llama_load.json
+++ b/integration-tests/models/__snapshots__/test_flash_llama/test_flash_llama_load.json
@@ -1,92 +1,4 @@
[
- {
- "details": {
- "best_of_sequences": null,
- "finish_reason": "length",
- "generated_tokens": 10,
- "prefill": [
- {
- "id": 1,
- "logprob": null,
- "text": ""
- },
- {
- "id": 4321,
- "logprob": -8.6875,
- "text": "Test"
- },
- {
- "id": 2009,
- "logprob": -11.5546875,
- "text": "request"
- }
- ],
- "seed": null,
- "tokens": [
- {
- "id": 363,
- "logprob": -1.5322266,
- "special": false,
- "text": " for"
- },
- {
- "id": 847,
- "logprob": -2.5585938,
- "special": false,
- "text": " /"
- },
- {
- "id": 2754,
- "logprob": -2.265625,
- "special": false,
- "text": "api"
- },
- {
- "id": 29914,
- "logprob": -0.034088135,
- "special": false,
- "text": "/"
- },
- {
- "id": 29894,
- "logprob": -0.96240234,
- "special": false,
- "text": "v"
- },
- {
- "id": 29896,
- "logprob": -0.36816406,
- "special": false,
- "text": "1"
- },
- {
- "id": 29914,
- "logprob": -0.013191223,
- "special": false,
- "text": "/"
- },
- {
- "id": 16418,
- "logprob": -3.15625,
- "special": false,
- "text": "projects"
- },
- {
- "id": 29914,
- "logprob": -0.43774414,
- "special": false,
- "text": "/"
- },
- {
- "id": 29896,
- "logprob": -1.9443359,
- "special": false,
- "text": "1"
- }
- ]
- },
- "generated_text": "for /api/v1/projects/1"
- },
{
"details": {
"best_of_sequences": null,
@@ -263,6 +175,94 @@
},
"generated_text": "for /api/v1/projects/1"
},
+ {
+ "details": {
+ "best_of_sequences": null,
+ "finish_reason": "length",
+ "generated_tokens": 10,
+ "prefill": [
+ {
+ "id": 1,
+ "logprob": null,
+ "text": ""
+ },
+ {
+ "id": 4321,
+ "logprob": -8.6875,
+ "text": "Test"
+ },
+ {
+ "id": 2009,
+ "logprob": -11.5546875,
+ "text": "request"
+ }
+ ],
+ "seed": null,
+ "tokens": [
+ {
+ "id": 363,
+ "logprob": -1.5322266,
+ "special": false,
+ "text": " for"
+ },
+ {
+ "id": 847,
+ "logprob": -2.5585938,
+ "special": false,
+ "text": " /"
+ },
+ {
+ "id": 2754,
+ "logprob": -2.265625,
+ "special": false,
+ "text": "api"
+ },
+ {
+ "id": 29914,
+ "logprob": -0.034088135,
+ "special": false,
+ "text": "/"
+ },
+ {
+ "id": 29894,
+ "logprob": -0.96240234,
+ "special": false,
+ "text": "v"
+ },
+ {
+ "id": 29896,
+ "logprob": -0.36816406,
+ "special": false,
+ "text": "1"
+ },
+ {
+ "id": 29914,
+ "logprob": -0.013191223,
+ "special": false,
+ "text": "/"
+ },
+ {
+ "id": 16418,
+ "logprob": -3.15625,
+ "special": false,
+ "text": "projects"
+ },
+ {
+ "id": 29914,
+ "logprob": -0.43774414,
+ "special": false,
+ "text": "/"
+ },
+ {
+ "id": 29896,
+ "logprob": -1.9443359,
+ "special": false,
+ "text": "1"
+ }
+ ]
+ },
+ "generated_text": "for /api/v1/projects/1"
+ },
{
"details": {
"best_of_sequences": null,
diff --git a/integration-tests/models/__snapshots__/test_mt0_base/test_mt0_base.json b/integration-tests/models/__snapshots__/test_mt0_base/test_mt0_base.json
index 2a26e3db..c1cd24cd 100644
--- a/integration-tests/models/__snapshots__/test_mt0_base/test_mt0_base.json
+++ b/integration-tests/models/__snapshots__/test_mt0_base/test_mt0_base.json
@@ -16,7 +16,7 @@
"id": 926,
"logprob": -4.3554688,
"special": false,
- "text": "To"
+ "text": " To"
},
{
"id": 18295,
diff --git a/integration-tests/models/__snapshots__/test_mt0_base/test_mt0_base_all_params.json b/integration-tests/models/__snapshots__/test_mt0_base/test_mt0_base_all_params.json
index fd77252d..3e9f3d73 100644
--- a/integration-tests/models/__snapshots__/test_mt0_base/test_mt0_base_all_params.json
+++ b/integration-tests/models/__snapshots__/test_mt0_base/test_mt0_base_all_params.json
@@ -16,7 +16,7 @@
"id": 16017,
"logprob": -1.3505859,
"special": false,
- "text": "blue"
+ "text": " blue"
},
{
"id": 20495,
diff --git a/integration-tests/models/__snapshots__/test_mt0_base/test_mt0_base_load.json b/integration-tests/models/__snapshots__/test_mt0_base/test_mt0_base_load.json
index c9e552b6..c0834ae1 100644
--- a/integration-tests/models/__snapshots__/test_mt0_base/test_mt0_base_load.json
+++ b/integration-tests/models/__snapshots__/test_mt0_base/test_mt0_base_load.json
@@ -1,58 +1,4 @@
[
- {
- "details": {
- "best_of_sequences": null,
- "finish_reason": "eos_token",
- "generated_tokens": 6,
- "prefill": [
- {
- "id": 0,
- "logprob": null,
- "text": ""
- }
- ],
- "seed": null,
- "tokens": [
- {
- "id": 259,
- "logprob": -1.3789062,
- "special": false,
- "text": ""
- },
- {
- "id": 39261,
- "logprob": -0.36279297,
- "special": false,
- "text": "Because"
- },
- {
- "id": 609,
- "logprob": -1.0966797,
- "special": false,
- "text": " it"
- },
- {
- "id": 339,
- "logprob": -0.8276367,
- "special": false,
- "text": " is"
- },
- {
- "id": 16017,
- "logprob": -1.6845703,
- "special": false,
- "text": " blue"
- },
- {
- "id": 1,
- "logprob": -0.72753906,
- "special": true,
- "text": ""
- }
- ]
- },
- "generated_text": "Because it is blue"
- },
{
"details": {
"best_of_sequences": null,
@@ -71,7 +17,7 @@
"id": 259,
"logprob": -1.3798828,
"special": false,
- "text": ""
+ "text": " "
},
{
"id": 39261,
@@ -125,7 +71,7 @@
"id": 259,
"logprob": -1.3789062,
"special": false,
- "text": ""
+ "text": " "
},
{
"id": 39261,
@@ -179,7 +125,61 @@
"id": 259,
"logprob": -1.3789062,
"special": false,
- "text": ""
+ "text": " "
+ },
+ {
+ "id": 39261,
+ "logprob": -0.36279297,
+ "special": false,
+ "text": "Because"
+ },
+ {
+ "id": 609,
+ "logprob": -1.0966797,
+ "special": false,
+ "text": " it"
+ },
+ {
+ "id": 339,
+ "logprob": -0.8276367,
+ "special": false,
+ "text": " is"
+ },
+ {
+ "id": 16017,
+ "logprob": -1.6845703,
+ "special": false,
+ "text": " blue"
+ },
+ {
+ "id": 1,
+ "logprob": -0.72753906,
+ "special": true,
+ "text": ""
+ }
+ ]
+ },
+ "generated_text": "Because it is blue"
+ },
+ {
+ "details": {
+ "best_of_sequences": null,
+ "finish_reason": "eos_token",
+ "generated_tokens": 6,
+ "prefill": [
+ {
+ "id": 0,
+ "logprob": null,
+ "text": ""
+ }
+ ],
+ "seed": null,
+ "tokens": [
+ {
+ "id": 259,
+ "logprob": -1.3789062,
+ "special": false,
+ "text": " "
},
{
"id": 39261,
diff --git a/router/src/main.rs b/router/src/main.rs
index 5ad49003..82bf6ba8 100644
--- a/router/src/main.rs
+++ b/router/src/main.rs
@@ -146,7 +146,7 @@ fn main() -> Result<(), std::io::Error> {
sha: None,
pipeline_tag: None,
},
- false => get_model_info(&tokenizer_name, &revision, authorization_token).await.unwrap_or({
+ false => get_model_info(&tokenizer_name, &revision, authorization_token).await.unwrap_or_else(|| {
tracing::warn!("Could not retrieve model info from the Hugging Face hub.");
HubModelInfo { model_id: tokenizer_name.to_string(), sha: None, pipeline_tag: None }
}),
diff --git a/server/text_generation_server/models/bloom.py b/server/text_generation_server/models/bloom.py
index e2a475c1..1f324f77 100644
--- a/server/text_generation_server/models/bloom.py
+++ b/server/text_generation_server/models/bloom.py
@@ -56,7 +56,7 @@ class BLOOM(CausalLM):
quantize: Optional[str] = None,
):
super(BLOOM, self).__init__(
- model_id=model_id, revision=revision, quantize=quantize, decode_buffer=1
+ model_id=model_id, revision=revision, quantize=quantize
)
@property
@@ -111,7 +111,6 @@ class BLOOMSharded(BLOOM):
requires_padding=True,
dtype=dtype,
device=device,
- decode_buffer=1,
rank=rank,
world_size=world_size,
)
diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py
index 7dc3e0aa..3a45ae06 100644
--- a/server/text_generation_server/models/causal_lm.py
+++ b/server/text_generation_server/models/causal_lm.py
@@ -81,8 +81,6 @@ class CausalLMBatch(Batch):
for i, r in enumerate(pb.requests):
requests_idx_mapping[r.id] = i
inputs.append(r.inputs)
- # offsets.append(None)
- # token_offsets.append(None)
next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
stopping_criteria = StoppingCriteria.from_pb(
r.stopping_parameters, tokenizer
@@ -102,7 +100,7 @@ class CausalLMBatch(Batch):
truncation=True,
max_length=max_truncation,
).to(device)
- for i, r in enumerate(pb.requests):
+ for _ in pb.requests:
input_len = tokenized_inputs["input_ids"].shape[1]
offsets.append(0)
token_offsets.append(input_len)
@@ -452,7 +450,6 @@ class CausalLM(Model):
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
- decode_buffer: int = 4,
):
if torch.cuda.is_available():
device = torch.device("cuda")
@@ -486,7 +483,6 @@ class CausalLM(Model):
requires_padding=True,
dtype=dtype,
device=device,
- decode_buffer=decode_buffer,
)
@property
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index c16cc19b..28376729 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -108,8 +108,8 @@ class FlashCausalLMBatch(Batch):
max_seqlen = max(max_seqlen, input_length)
input_lengths.append(input_length)
- offsets.append(None)
- token_offsets.append(None)
+ offsets.append(0)
+ token_offsets.append(input_length)
all_input_ids.append(tokenized_input)
@@ -394,7 +394,6 @@ class FlashCausalLM(Model):
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
- decode_buffer: int = 4,
):
if torch.cuda.is_available():
device = torch.device("cuda")
@@ -410,7 +409,7 @@ class FlashCausalLM(Model):
revision=revision,
torch_dtype=dtype,
load_in_8bit=quantize == "bitsandbytes",
- )
+ ).to(device)
super(FlashCausalLM, self).__init__(
model=model,
@@ -418,7 +417,6 @@ class FlashCausalLM(Model):
requires_padding=False,
dtype=dtype,
device=device,
- decode_buffer=decode_buffer,
)
@property
diff --git a/server/text_generation_server/models/flash_llama.py b/server/text_generation_server/models/flash_llama.py
index 3fd8774e..ebdbe206 100644
--- a/server/text_generation_server/models/flash_llama.py
+++ b/server/text_generation_server/models/flash_llama.py
@@ -66,7 +66,7 @@ class FlashLlama(FlashCausalLM):
self.load_weights(model, filenames, quantize, device, dtype)
super(FlashCausalLM, self).__init__(
- model=model,
+ model=model.to(device),
tokenizer=tokenizer,
requires_padding=False,
dtype=dtype,
@@ -191,7 +191,7 @@ class FlashLlamaSharded(FlashLlama):
)
torch.distributed.barrier(group=self.process_group)
super(FlashCausalLM, self).__init__(
- model=model,
+ model=model.to(device),
tokenizer=tokenizer,
requires_padding=False,
dtype=dtype,
diff --git a/server/text_generation_server/models/flash_neox.py b/server/text_generation_server/models/flash_neox.py
index c322ecbc..cac40bab 100644
--- a/server/text_generation_server/models/flash_neox.py
+++ b/server/text_generation_server/models/flash_neox.py
@@ -75,7 +75,7 @@ class FlashNeoXSharded(FlashNeoX):
)
torch.distributed.barrier(group=self.process_group)
super(FlashCausalLM, self).__init__(
- model=model,
+ model=model.to(device),
tokenizer=tokenizer,
requires_padding=False,
dtype=dtype,
diff --git a/server/text_generation_server/models/flash_santacoder.py b/server/text_generation_server/models/flash_santacoder.py
index 6824118d..5dc31309 100644
--- a/server/text_generation_server/models/flash_santacoder.py
+++ b/server/text_generation_server/models/flash_santacoder.py
@@ -69,12 +69,11 @@ class FlashSantacoder(FlashCausalLM):
)
super(FlashCausalLM, self).__init__(
- model=model,
+ model=model.to(device),
tokenizer=tokenizer,
requires_padding=False,
dtype=dtype,
device=device,
- decode_buffer=1,
)
@staticmethod
@@ -215,14 +214,13 @@ class FlashSantacoderSharded(FlashSantacoder):
)
torch.distributed.barrier(group=self.process_group)
super(FlashCausalLM, self).__init__(
- model=model,
+ model=model.to(device),
tokenizer=tokenizer,
requires_padding=False,
dtype=dtype,
device=device,
rank=rank,
world_size=world_size,
- decode_buffer=1,
)
@staticmethod
diff --git a/server/text_generation_server/models/model.py b/server/text_generation_server/models/model.py
index 3fa50678..29bad321 100644
--- a/server/text_generation_server/models/model.py
+++ b/server/text_generation_server/models/model.py
@@ -18,20 +18,15 @@ class Model(ABC):
requires_padding: bool,
dtype: torch.dtype,
device: torch.device,
- decode_buffer: int = 4,
rank: int = 0,
world_size: int = 1,
):
- if decode_buffer < 1:
- raise ValueError("decode_buffer must be >= 1")
-
- self.model = model.eval().to(device)
+ self.model = model.eval()
self.tokenizer = tokenizer
self.all_special_ids = set(tokenizer.all_special_ids)
self.requires_padding = requires_padding
self.dtype = dtype
self.device = device
- self.decode_buffer = decode_buffer
self.rank = rank
self.world_size = world_size
self.check_initialized()
@@ -61,12 +56,6 @@ class Model(ABC):
) -> Tuple[str, int, int]:
"""Hack to hopefully support generate_stream for the maximum number of tokenizers"""
- # Compatibility layer for old None values.
- if prefix_offset is None:
- prefix_offset = 0
- if read_offset is None:
- read_offset = 0
-
# The prefix text is necessary only to defeat cleanup algorithms in the decode
# which decide to add a space or not depending on the surrounding ids.
prefix_text = self.tokenizer.decode(
diff --git a/server/text_generation_server/models/santacoder.py b/server/text_generation_server/models/santacoder.py
index 4368ed60..23f89f48 100644
--- a/server/text_generation_server/models/santacoder.py
+++ b/server/text_generation_server/models/santacoder.py
@@ -52,7 +52,7 @@ class SantaCoder(CausalLM):
torch_dtype=dtype,
load_in_8bit=quantize == "bitsandbytes",
trust_remote_code=True, # required
- )
+ ).to(device)
super(CausalLM, self).__init__(
model=model,
@@ -60,7 +60,6 @@ class SantaCoder(CausalLM):
requires_padding=True,
dtype=dtype,
device=device,
- decode_buffer=1,
)
def decode(self, generated_ids: List[int]) -> str:
diff --git a/server/text_generation_server/models/seq2seq_lm.py b/server/text_generation_server/models/seq2seq_lm.py
index 42201e9f..ac7b9cdd 100644
--- a/server/text_generation_server/models/seq2seq_lm.py
+++ b/server/text_generation_server/models/seq2seq_lm.py
@@ -91,8 +91,6 @@ class Seq2SeqLMBatch(Batch):
inputs.append(r.inputs)
requests_idx_mapping[r.id] = i
decoder_input_lengths.append(1)
- # offsets.append(None)
- # token_offsets.append(None)
next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
stopping_criteria = StoppingCriteria.from_pb(
r.stopping_parameters, tokenizer
@@ -123,7 +121,7 @@ class Seq2SeqLMBatch(Batch):
.repeat(len(pb.requests))
.view(-1, 1)
)
- for i, r in enumerate(pb.requests):
+ for _ in pb.requests:
offsets.append(0)
token_offsets.append(1)
all_decoder_input_ids = decoder_input_ids.view(-1).split(1)
@@ -505,7 +503,6 @@ class Seq2SeqLM(Model):
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
- decode_buffer: int = 4,
):
if torch.cuda.is_available():
device = torch.device("cuda")
@@ -535,7 +532,6 @@ class Seq2SeqLM(Model):
requires_padding=True,
dtype=dtype,
device=device,
- decode_buffer=decode_buffer,
)
@property