fix(server): decode buffer should be even

OlivierDehaene 2023-05-16 10:24:36 +02:00
parent dbdc587ddd
commit 92178b875e
17 changed files with 98 additions and 48 deletions
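
In short, the hunks below: bump the default decode_buffer from 3 to 4 and clamp the streaming token_offset at 0; move ownership of the underlying module into the Model base class, which now receives the raw model plus requires_padding and dtype and calls model.eval().to(device) itself; add a streaming-decode regression test; and split the test targets so that tests marked private only run in CI, where HUGGING_FACE_HUB_TOKEN is available.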

View File

@@ -66,7 +66,8 @@ jobs:
- name: Run server tests
run: |
pip install pytest
make python-server-tests
export HUGGING_FACE_HUB_TOKEN=${{ secrets.HUGGING_FACE_HUB_TOKEN }}
pytest -s -vv server/tests
- name: Run Rust fmt
run: |
cargo fmt --check

View File

@@ -31,7 +31,7 @@ update-integration-tests: install-integration-tests
pytest -s -vv --snapshot-update integration-tests
python-server-tests:
HF_HUB_ENABLE_HF_TRANSFER=1 pytest server/tests
HF_HUB_ENABLE_HF_TRANSFER=1 pytest -s -vv -m "not private" server/tests
python-client-tests:
pytest clients/python/tests

View File

@@ -2,7 +2,7 @@ include Makefile-transformers
include Makefile-flash-att
unit-tests:
python -m pytest tests
pytest -s -vv -m "not private" tests
gen-server:
# Compile protos

View File

@@ -0,0 +1,54 @@
import pytest
import torch
from transformers import AutoTokenizer
from text_generation_server.models import Model
@pytest.mark.private
def test_decode_streaming():
class TestModel(Model):
def batch_type(self):
raise NotImplementedError
def generate_token(self, batch):
raise NotImplementedError
tokenizer = AutoTokenizer.from_pretrained("huggingface/llama-7b")
model = TestModel(
torch.nn.Linear(1, 1), tokenizer, False, torch.float32, torch.device("cpu")
)
all_input_ids = [
30672,
232,
193,
139,
233,
135,
162,
235,
179,
165,
30919,
30210,
234,
134,
176,
30993,
]
truth = "我很感谢你的热情"
decoded_text = ""
offset = None
token_offset = None
for i in range(len(all_input_ids)):
text, offset, token_offset = model.decode_token(
all_input_ids[: i + 1], offset, token_offset
)
decoded_text += text
assert decoded_text == truth
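
The truth string is Chinese on purpose: the Llama tokenizer falls back to byte-level tokens for many CJK characters, so a single character can span several ids, and decoding ids one at a time would emit U+FFFD replacement characters. The decode_buffer exercised here decodes a small trailing window and returns only the newly added suffix. Below is a minimal sketch of that idea, not the repository's decode_token implementation; gpt2 is used only because it is a small public tokenizer and the window arithmetic is simplified.

from transformers import AutoTokenizer

def stream_decode(tokenizer, all_input_ids, buffer=4):
    # Decode the trailing window with and without the newest token and emit
    # only the suffix, so characters spanning several tokens are never cut.
    start = max(len(all_input_ids) - buffer, 0)
    previous = tokenizer.decode(all_input_ids[start:-1])
    current = tokenizer.decode(all_input_ids[start:])
    return current[len(previous):]

tokenizer = AutoTokenizer.from_pretrained("gpt2")
ids = tokenizer("streaming decode demo").input_ids
streamed = "".join(stream_decode(tokenizer, ids[: i + 1]) for i in range(len(ids)))
assert streamed == tokenizer.decode(ids)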

View File

@@ -104,9 +104,9 @@ class BLOOMSharded(BLOOM):
rank=rank,
world_size=world_size,
)
self.model = model.eval()
torch.distributed.barrier(group=self.process_group)
super(CausalLM, self).__init__(
model=model,
tokenizer=tokenizer,
requires_padding=True,
dtype=dtype,
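
This is the first instance of a pattern that repeats in every model file below: the subclass no longer stores self.model itself; it hands the raw model to the shared constructor along with requires_padding and dtype, and the base class (see the Model hunk further down) performs model.eval().to(device) once. A condensed before/after sketch, with argument lists abbreviated beyond what the hunks show:

# Condensed before/after of the recurring refactor (argument lists shortened):
#
#   before (per model class):
#       self.model = model.eval()              # some variants also called .to(device)
#       super(CausalLM, self).__init__(tokenizer=tokenizer, ...)
#
#   after:
#       super(CausalLM, self).__init__(
#           model=model,
#           tokenizer=tokenizer,
#           requires_padding=True,
#           dtype=dtype,
#           ...
#       )
#       # Model.__init__ now performs self.model = model.eval().to(device)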

View File

@@ -448,7 +448,7 @@ class CausalLM(Model):
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
decode_buffer: int = 3,
decode_buffer: int = 4,
):
if torch.cuda.is_available():
device = torch.device("cuda")
@@ -463,20 +463,21 @@ class CausalLM(Model):
tokenizer = AutoTokenizer.from_pretrained(
model_id, revision=revision, padding_side="left", truncation_side="left"
)
self.model = AutoModelForCausalLM.from_pretrained(
model = AutoModelForCausalLM.from_pretrained(
model_id,
revision=revision,
torch_dtype=dtype,
device_map="auto" if torch.cuda.is_available() else None,
load_in_8bit=quantize == "bitsandbytes",
).eval()
)
tokenizer.pad_token_id = (
self.model.config.pad_token_id
if self.model.config.pad_token_id is not None
else self.model.config.eos_token_id
model.config.pad_token_id
if model.config.pad_token_id is not None
else model.config.eos_token_id
)
super(CausalLM, self).__init__(
model=model,
tokenizer=tokenizer,
requires_padding=True,
dtype=dtype,

View File

@@ -394,7 +394,7 @@ class FlashCausalLM(Model):
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
decode_buffer: int = 3,
decode_buffer: int = 4,
):
if torch.cuda.is_available():
device = torch.device("cuda")
@@ -405,18 +405,15 @@ class FlashCausalLM(Model):
tokenizer = AutoTokenizer.from_pretrained(
model_id, revision=revision, padding_side="left", truncation_side="left"
)
self.model = (
model_cls.from_pretrained(
model_id,
revision=revision,
torch_dtype=dtype,
load_in_8bit=quantize == "bitsandbytes",
)
.eval()
.to(device)
model = model_cls.from_pretrained(
model_id,
revision=revision,
torch_dtype=dtype,
load_in_8bit=quantize == "bitsandbytes",
)
super(FlashCausalLM, self).__init__(
model=model,
tokenizer=tokenizer,
requires_padding=False,
dtype=dtype,

View File

@@ -64,9 +64,9 @@ class FlashLlama(FlashCausalLM):
model = FlashLlamaForCausalLM(config)
self.load_weights(model, filenames, quantize, device, dtype)
self.model = model.eval().to(device)
super(FlashCausalLM, self).__init__(
model=model,
tokenizer=tokenizer,
requires_padding=False,
dtype=dtype,
@@ -189,9 +189,9 @@ class FlashLlamaSharded(FlashLlama):
rank=rank,
world_size=world_size,
)
self.model = model.eval().to(device)
torch.distributed.barrier(group=self.process_group)
super(FlashCausalLM, self).__init__(
model=model,
tokenizer=tokenizer,
requires_padding=False,
dtype=dtype,

View File

@@ -73,9 +73,9 @@ class FlashNeoXSharded(FlashNeoX):
rank=rank,
world_size=world_size,
)
self.model = model.eval().to(device)
torch.distributed.barrier(group=self.process_group)
super(FlashCausalLM, self).__init__(
model=model,
tokenizer=tokenizer,
requires_padding=False,
dtype=dtype,

View File

@@ -67,9 +67,9 @@ class FlashSantacoder(FlashCausalLM):
dtype,
config.architectures[0].startswith("GPT2"),
)
self.model = model.eval().to(device)
super(FlashCausalLM, self).__init__(
model=model,
tokenizer=tokenizer,
requires_padding=False,
dtype=dtype,
@@ -213,9 +213,9 @@ class FlashSantacoderSharded(FlashSantacoder):
world_size=world_size,
transpose=config.architectures[0].startswith("GPT2"),
)
self.model = model.eval().to(device)
torch.distributed.barrier(group=self.process_group)
super(FlashCausalLM, self).__init__(
model=model,
tokenizer=tokenizer,
requires_padding=False,
dtype=dtype,

View File

@@ -231,9 +231,9 @@ class GalacticaSharded(Galactica):
rank=rank,
world_size=world_size,
)
self.model = model.eval()
torch.distributed.barrier(group=self.process_group)
super(CausalLM, self).__init__(
model=model,
tokenizer=tokenizer,
requires_padding=True,
dtype=dtype,

View File

@@ -70,9 +70,9 @@ class GPTNeoxSharded(CausalLM):
rank=rank,
world_size=world_size,
)
self.model = model.eval()
torch.distributed.barrier(group=self.process_group)
super(CausalLM, self).__init__(
model=model,
tokenizer=tokenizer,
requires_padding=True,
dtype=dtype,

View File

@@ -13,17 +13,19 @@ B = TypeVar("B", bound=Batch)
class Model(ABC):
def __init__(
self,
model: torch.nn.Module,
tokenizer: PreTrainedTokenizerBase,
requires_padding: bool,
dtype: torch.dtype,
device: torch.device,
decode_buffer: int = 3,
decode_buffer: int = 4,
rank: int = 0,
world_size: int = 1,
):
if decode_buffer < 1:
raise ValueError("decode_buffer must be >= 1")
self.model = model.eval().to(device)
self.tokenizer = tokenizer
self.all_special_ids = set(tokenizer.all_special_ids)
self.requires_padding = requires_padding
@@ -66,7 +68,7 @@ class Model(ABC):
)
if token_offset is None:
token_offset = len(all_input_ids) - self.decode_buffer
token_offset = max(len(all_input_ids) - self.decode_buffer, 0)
# left token buffer
if self.decode_buffer > 1:
# Decode token_offset token minus last one and token_offset tokens
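
The max(..., 0) clamp matters because token_offset is returned to the caller and fed back into later decode_token calls (see the new test above). If fewer tokens than decode_buffer exist when it is first computed, the unclamped value is negative, and a negative index into all_input_ids counts from the end of the growing sequence instead of marking a fixed start position. A small illustration, not repository code:

decode_buffer = 4
all_input_ids = [101, 102]                            # only two tokens generated so far
token_offset = len(all_input_ids) - decode_buffer     # -2, stored and reused later
all_input_ids += [103, 104, 105]
print(all_input_ids[token_offset:])                   # [104, 105]: the window slid with the end
token_offset = max(2 - decode_buffer, 0)              # 0 with the fix
print(all_input_ids[token_offset:])                   # [101, 102, 103, 104, 105]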

View File

@@ -86,9 +86,9 @@ class OPTSharded(OPT):
rank=rank,
world_size=world_size,
)
self.model = model.eval()
torch.distributed.barrier(group=self.process_group)
super(CausalLM, self).__init__(
model=model,
tokenizer=tokenizer,
requires_padding=True,
dtype=dtype,

View File

@@ -46,19 +46,16 @@ class SantaCoder(CausalLM):
}
)
self.model = (
AutoModelForCausalLM.from_pretrained(
model_id,
revision=revision,
torch_dtype=dtype,
load_in_8bit=quantize == "bitsandbytes",
trust_remote_code=True, # required
)
.to(device)
.eval()
model = AutoModelForCausalLM.from_pretrained(
model_id,
revision=revision,
torch_dtype=dtype,
load_in_8bit=quantize == "bitsandbytes",
trust_remote_code=True, # required
)
super(CausalLM, self).__init__(
model=model,
tokenizer=tokenizer,
requires_padding=True,
dtype=dtype,

View File

@@ -502,7 +502,7 @@ class Seq2SeqLM(Model):
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
decode_buffer: int = 3,
decode_buffer: int = 4,
):
if torch.cuda.is_available():
device = torch.device("cuda")
@@ -514,19 +514,20 @@ class Seq2SeqLM(Model):
device = torch.device("cpu")
dtype = torch.float32
self.model = AutoModelForSeq2SeqLM.from_pretrained(
model = AutoModelForSeq2SeqLM.from_pretrained(
model_id,
revision=revision,
torch_dtype=dtype,
device_map="auto" if torch.cuda.is_available() else None,
load_in_8bit=quantize == "bitsandbytes",
).eval()
)
tokenizer = AutoTokenizer.from_pretrained(
model_id, revision=revision, padding_side="left", truncation_side="left"
)
tokenizer.bos_token_id = self.model.config.decoder_start_token_id
tokenizer.bos_token_id = model.config.decoder_start_token_id
super(Seq2SeqLM, self).__init__(
model=model,
tokenizer=tokenizer,
requires_padding=True,
dtype=dtype,

View File

@@ -16,9 +16,6 @@ from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,
)
from text_generation_server.utils.layers import (
FastLinear,
)
from transformers.models.t5.parallel_layers import (
TensorParallelRowLinear,
TensorParallelColumnLinear,
@@ -73,9 +70,9 @@ class T5Sharded(Seq2SeqLM):
rank=rank,
world_size=world_size,
)
self.model = model.eval()
torch.distributed.barrier(group=self.process_group)
super(Seq2SeqLM, self).__init__(
model=model,
tokenizer=tokenizer,
requires_padding=True,
dtype=dtype,