From 92178b875e30a36bf6e3cdd38a30d6781d276ad9 Mon Sep 17 00:00:00 2001
From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com>
Date: Tue, 16 May 2023 10:24:36 +0200
Subject: [PATCH] fix(server): decode buffer should be even

---
 .github/workflows/tests.yaml                  |  3 +-
 Makefile                                      |  2 +-
 server/Makefile                               |  2 +-
 server/tests/models/test_model.py             | 54 +++++++++++++++++++
 server/text_generation_server/models/bloom.py |  2 +-
 .../models/causal_lm.py                       | 13 ++---
 .../models/flash_causal_lm.py                 | 17 +++---
 .../models/flash_llama.py                     |  4 +-
 .../models/flash_neox.py                      |  2 +-
 .../models/flash_santacoder.py                |  4 +-
 .../models/galactica.py                       |  2 +-
 .../text_generation_server/models/gpt_neox.py |  2 +-
 server/text_generation_server/models/model.py |  6 ++-
 server/text_generation_server/models/opt.py   |  2 +-
 .../models/santacoder.py                      | 17 +++---
 .../models/seq2seq_lm.py                      |  9 ++--
 server/text_generation_server/models/t5.py    |  5 +-
 17 files changed, 98 insertions(+), 48 deletions(-)
 create mode 100644 server/tests/models/test_model.py

diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml
index d9858a3b..7e5ba52c 100644
--- a/.github/workflows/tests.yaml
+++ b/.github/workflows/tests.yaml
@@ -66,7 +66,8 @@ jobs:
       - name: Run server tests
         run: |
           pip install pytest
-          make python-server-tests
+          export HUGGING_FACE_HUB_TOKEN=${{ secrets.HUGGING_FACE_HUB_TOKEN }}
+          pytest -s -vv server/tests
       - name: Run Rust fmt
         run: |
           cargo fmt --check
diff --git a/Makefile b/Makefile
index 29c318fa..7309aaee 100644
--- a/Makefile
+++ b/Makefile
@@ -31,7 +31,7 @@ update-integration-tests: install-integration-tests
 	pytest -s -vv --snapshot-update integration-tests
 
 python-server-tests:
-	HF_HUB_ENABLE_HF_TRANSFER=1 pytest server/tests
+	HF_HUB_ENABLE_HF_TRANSFER=1 pytest -s -vv -m "not private" server/tests
 
 python-client-tests:
 	pytest clients/python/tests
diff --git a/server/Makefile b/server/Makefile
index 150d7e4a..6eb56c75 100644
--- a/server/Makefile
+++ b/server/Makefile
@@ -2,7 +2,7 @@ include Makefile-transformers
 include Makefile-flash-att
 
 unit-tests:
-	python -m pytest tests
+	pytest -s -vv -m "not private" tests
 
 gen-server:
 	# Compile protos
diff --git a/server/tests/models/test_model.py b/server/tests/models/test_model.py
new file mode 100644
index 00000000..fbcf873d
--- /dev/null
+++ b/server/tests/models/test_model.py
@@ -0,0 +1,54 @@
+import pytest
+import torch
+
+from transformers import AutoTokenizer
+
+from text_generation_server.models import Model
+
+
+@pytest.mark.private
+def test_decode_streaming():
+    class TestModel(Model):
+        def batch_type(self):
+            raise NotImplementedError
+
+        def generate_token(self, batch):
+            raise NotImplementedError
+
+    tokenizer = AutoTokenizer.from_pretrained("huggingface/llama-7b")
+
+    model = TestModel(
+        torch.nn.Linear(1, 1), tokenizer, False, torch.float32, torch.device("cpu")
+    )
+
+    all_input_ids = [
+        30672,
+        232,
+        193,
+        139,
+        233,
+        135,
+        162,
+        235,
+        179,
+        165,
+        30919,
+        30210,
+        234,
+        134,
+        176,
+        30993,
+    ]
+
+    truth = "我很感谢你的热情"
+
+    decoded_text = ""
+    offset = None
+    token_offset = None
+    for i in range(len(all_input_ids)):
+        text, offset, token_offset = model.decode_token(
+            all_input_ids[: i + 1], offset, token_offset
+        )
+        decoded_text += text
+
+    assert decoded_text == truth
diff --git a/server/text_generation_server/models/bloom.py b/server/text_generation_server/models/bloom.py
index 9029e954..e2a475c1 100644
--- a/server/text_generation_server/models/bloom.py
+++ b/server/text_generation_server/models/bloom.py
@@ -104,9 +104,9 @@ class BLOOMSharded(BLOOM):
             rank=rank,
             world_size=world_size,
         )
-        self.model = model.eval()
         torch.distributed.barrier(group=self.process_group)
         super(CausalLM, self).__init__(
+            model=model,
             tokenizer=tokenizer,
             requires_padding=True,
             dtype=dtype,
diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py
index 0d521ac4..870d261f 100644
--- a/server/text_generation_server/models/causal_lm.py
+++ b/server/text_generation_server/models/causal_lm.py
@@ -448,7 +448,7 @@ class CausalLM(Model):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
-        decode_buffer: int = 3,
+        decode_buffer: int = 4,
     ):
         if torch.cuda.is_available():
             device = torch.device("cuda")
@@ -463,20 +463,21 @@ class CausalLM(Model):
         tokenizer = AutoTokenizer.from_pretrained(
             model_id, revision=revision, padding_side="left", truncation_side="left"
         )
-        self.model = AutoModelForCausalLM.from_pretrained(
+        model = AutoModelForCausalLM.from_pretrained(
             model_id,
             revision=revision,
             torch_dtype=dtype,
             device_map="auto" if torch.cuda.is_available() else None,
             load_in_8bit=quantize == "bitsandbytes",
-        ).eval()
+        )
         tokenizer.pad_token_id = (
-            self.model.config.pad_token_id
-            if self.model.config.pad_token_id is not None
-            else self.model.config.eos_token_id
+            model.config.pad_token_id
+            if model.config.pad_token_id is not None
+            else model.config.eos_token_id
         )
 
         super(CausalLM, self).__init__(
+            model=model,
             tokenizer=tokenizer,
             requires_padding=True,
             dtype=dtype,
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index 0a9fccca..c16cc19b 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -394,7 +394,7 @@ class FlashCausalLM(Model):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
-        decode_buffer: int = 3,
+        decode_buffer: int = 4,
     ):
         if torch.cuda.is_available():
             device = torch.device("cuda")
@@ -405,18 +405,15 @@ class FlashCausalLM(Model):
         tokenizer = AutoTokenizer.from_pretrained(
             model_id, revision=revision, padding_side="left", truncation_side="left"
         )
-        self.model = (
-            model_cls.from_pretrained(
-                model_id,
-                revision=revision,
-                torch_dtype=dtype,
-                load_in_8bit=quantize == "bitsandbytes",
-            )
-            .eval()
-            .to(device)
+        model = model_cls.from_pretrained(
+            model_id,
+            revision=revision,
+            torch_dtype=dtype,
+            load_in_8bit=quantize == "bitsandbytes",
         )
 
         super(FlashCausalLM, self).__init__(
+            model=model,
             tokenizer=tokenizer,
             requires_padding=False,
             dtype=dtype,
diff --git a/server/text_generation_server/models/flash_llama.py b/server/text_generation_server/models/flash_llama.py
index b775bd79..3fd8774e 100644
--- a/server/text_generation_server/models/flash_llama.py
+++ b/server/text_generation_server/models/flash_llama.py
@@ -64,9 +64,9 @@ class FlashLlama(FlashCausalLM):
         model = FlashLlamaForCausalLM(config)
 
         self.load_weights(model, filenames, quantize, device, dtype)
-        self.model = model.eval().to(device)
 
         super(FlashCausalLM, self).__init__(
+            model=model,
             tokenizer=tokenizer,
             requires_padding=False,
             dtype=dtype,
@@ -189,9 +189,9 @@ class FlashLlamaSharded(FlashLlama):
             rank=rank,
             world_size=world_size,
         )
-        self.model = model.eval().to(device)
         torch.distributed.barrier(group=self.process_group)
         super(FlashCausalLM, self).__init__(
+            model=model,
             tokenizer=tokenizer,
             requires_padding=False,
             dtype=dtype,
diff --git a/server/text_generation_server/models/flash_neox.py b/server/text_generation_server/models/flash_neox.py
index 0924f107..c322ecbc 100644
--- a/server/text_generation_server/models/flash_neox.py
+++ b/server/text_generation_server/models/flash_neox.py
@@ -73,9 +73,9 @@ class FlashNeoXSharded(FlashNeoX):
             rank=rank,
             world_size=world_size,
         )
-        self.model = model.eval().to(device)
         torch.distributed.barrier(group=self.process_group)
         super(FlashCausalLM, self).__init__(
+            model=model,
             tokenizer=tokenizer,
             requires_padding=False,
             dtype=dtype,
diff --git a/server/text_generation_server/models/flash_santacoder.py b/server/text_generation_server/models/flash_santacoder.py
index 031a67eb..6824118d 100644
--- a/server/text_generation_server/models/flash_santacoder.py
+++ b/server/text_generation_server/models/flash_santacoder.py
@@ -67,9 +67,9 @@ class FlashSantacoder(FlashCausalLM):
             dtype,
             config.architectures[0].startswith("GPT2"),
         )
-        self.model = model.eval().to(device)
 
         super(FlashCausalLM, self).__init__(
+            model=model,
             tokenizer=tokenizer,
             requires_padding=False,
             dtype=dtype,
@@ -213,9 +213,9 @@ class FlashSantacoderSharded(FlashSantacoder):
             world_size=world_size,
             transpose=config.architectures[0].startswith("GPT2"),
         )
-        self.model = model.eval().to(device)
         torch.distributed.barrier(group=self.process_group)
         super(FlashCausalLM, self).__init__(
+            model=model,
             tokenizer=tokenizer,
             requires_padding=False,
             dtype=dtype,
diff --git a/server/text_generation_server/models/galactica.py b/server/text_generation_server/models/galactica.py
index d1e5e841..b34489d8 100644
--- a/server/text_generation_server/models/galactica.py
+++ b/server/text_generation_server/models/galactica.py
@@ -231,9 +231,9 @@ class GalacticaSharded(Galactica):
             rank=rank,
             world_size=world_size,
         )
-        self.model = model.eval()
         torch.distributed.barrier(group=self.process_group)
         super(CausalLM, self).__init__(
+            model=model,
             tokenizer=tokenizer,
             requires_padding=True,
             dtype=dtype,
diff --git a/server/text_generation_server/models/gpt_neox.py b/server/text_generation_server/models/gpt_neox.py
index f95e5be2..a10dfcb8 100644
--- a/server/text_generation_server/models/gpt_neox.py
+++ b/server/text_generation_server/models/gpt_neox.py
@@ -70,9 +70,9 @@ class GPTNeoxSharded(CausalLM):
             rank=rank,
             world_size=world_size,
         )
-        self.model = model.eval()
         torch.distributed.barrier(group=self.process_group)
         super(CausalLM, self).__init__(
+            model=model,
             tokenizer=tokenizer,
             requires_padding=True,
             dtype=dtype,
diff --git a/server/text_generation_server/models/model.py b/server/text_generation_server/models/model.py
index 03f14013..f19fecb8 100644
--- a/server/text_generation_server/models/model.py
+++ b/server/text_generation_server/models/model.py
@@ -13,17 +13,19 @@ B = TypeVar("B", bound=Batch)
 class Model(ABC):
     def __init__(
         self,
+        model: torch.nn.Module,
         tokenizer: PreTrainedTokenizerBase,
         requires_padding: bool,
         dtype: torch.dtype,
         device: torch.device,
-        decode_buffer: int = 3,
+        decode_buffer: int = 4,
         rank: int = 0,
         world_size: int = 1,
     ):
         if decode_buffer < 1:
             raise ValueError("decode_buffer must be >= 1")
 
+        self.model = model.eval().to(device)
         self.tokenizer = tokenizer
         self.all_special_ids = set(tokenizer.all_special_ids)
         self.requires_padding = requires_padding
@@ -66,7 +68,7 @@ class Model(ABC):
         )
 
         if token_offset is None:
-            token_offset = len(all_input_ids) - self.decode_buffer
+            token_offset = max(len(all_input_ids) - self.decode_buffer, 0)
         # left token buffer
         if self.decode_buffer > 1:
             # Decode token_offset token minus last one and token_offset tokens
diff --git a/server/text_generation_server/models/opt.py b/server/text_generation_server/models/opt.py
index 093cf70a..fdae795b 100644
--- a/server/text_generation_server/models/opt.py
+++ b/server/text_generation_server/models/opt.py
@@ -86,9 +86,9 @@ class OPTSharded(OPT):
             rank=rank,
             world_size=world_size,
         )
-        self.model = model.eval()
         torch.distributed.barrier(group=self.process_group)
         super(CausalLM, self).__init__(
+            model=model,
             tokenizer=tokenizer,
             requires_padding=True,
             dtype=dtype,
diff --git a/server/text_generation_server/models/santacoder.py b/server/text_generation_server/models/santacoder.py
index 4bd56de1..4368ed60 100644
--- a/server/text_generation_server/models/santacoder.py
+++ b/server/text_generation_server/models/santacoder.py
@@ -46,19 +46,16 @@ class SantaCoder(CausalLM):
             }
         )
 
-        self.model = (
-            AutoModelForCausalLM.from_pretrained(
-                model_id,
-                revision=revision,
-                torch_dtype=dtype,
-                load_in_8bit=quantize == "bitsandbytes",
-                trust_remote_code=True,  # required
-            )
-            .to(device)
-            .eval()
+        model = AutoModelForCausalLM.from_pretrained(
+            model_id,
+            revision=revision,
+            torch_dtype=dtype,
+            load_in_8bit=quantize == "bitsandbytes",
+            trust_remote_code=True,  # required
         )
 
         super(CausalLM, self).__init__(
+            model=model,
             tokenizer=tokenizer,
             requires_padding=True,
             dtype=dtype,
diff --git a/server/text_generation_server/models/seq2seq_lm.py b/server/text_generation_server/models/seq2seq_lm.py
index 84854f5d..98be4c71 100644
--- a/server/text_generation_server/models/seq2seq_lm.py
+++ b/server/text_generation_server/models/seq2seq_lm.py
@@ -502,7 +502,7 @@ class Seq2SeqLM(Model):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
-        decode_buffer: int = 3,
+        decode_buffer: int = 4,
     ):
         if torch.cuda.is_available():
             device = torch.device("cuda")
@@ -514,19 +514,20 @@ class Seq2SeqLM(Model):
             device = torch.device("cpu")
             dtype = torch.float32
 
-        self.model = AutoModelForSeq2SeqLM.from_pretrained(
+        model = AutoModelForSeq2SeqLM.from_pretrained(
             model_id,
             revision=revision,
             torch_dtype=dtype,
             device_map="auto" if torch.cuda.is_available() else None,
             load_in_8bit=quantize == "bitsandbytes",
-        ).eval()
+        )
         tokenizer = AutoTokenizer.from_pretrained(
             model_id, revision=revision, padding_side="left", truncation_side="left"
         )
-        tokenizer.bos_token_id = self.model.config.decoder_start_token_id
+        tokenizer.bos_token_id = model.config.decoder_start_token_id
 
         super(Seq2SeqLM, self).__init__(
+            model=model,
             tokenizer=tokenizer,
             requires_padding=True,
             dtype=dtype,
diff --git a/server/text_generation_server/models/t5.py b/server/text_generation_server/models/t5.py
index 8e3826a4..b1ba2432 100644
--- a/server/text_generation_server/models/t5.py
+++ b/server/text_generation_server/models/t5.py
@@ -16,9 +16,6 @@ from text_generation_server.utils import (
     initialize_torch_distributed,
     weight_files,
 )
-from text_generation_server.utils.layers import (
-    FastLinear,
-)
 from transformers.models.t5.parallel_layers import (
     TensorParallelRowLinear,
     TensorParallelColumnLinear,
@@ -73,9 +70,9 @@ class T5Sharded(Seq2SeqLM):
             rank=rank,
             world_size=world_size,
        )
-        self.model = model.eval()
         torch.distributed.barrier(group=self.process_group)
         super(Seq2SeqLM, self).__init__(
+            model=model,
             tokenizer=tokenizer,
             requires_padding=True,
             dtype=dtype,
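
For context on what the new test_decode_streaming test guards against: LLaMA-style tokenizers fall back to byte-level tokens for most non-ASCII text, which is why the 8-character truth string above needs 16 token ids, and decoding a prefix that stops partway through a character yields U+FFFD replacement characters. The sketch below is a minimal, deliberately naive illustration of that streaming concern; it is not the decode_token implementation touched by this patch, and it reuses the gated placeholder repo id from the test, so substitute any byte-fallback tokenizer you have access to.

    from typing import Iterator, List

    from transformers import AutoTokenizer


    def naive_stream_decode(tokenizer, all_input_ids: List[int]) -> Iterator[str]:
        """Re-decode the full prefix after every new token and emit only the newly
        completed text. Assumes the decoded string only grows as tokens arrive,
        which holds for byte-fallback continuations like the ones in the test."""
        emitted = ""
        for i in range(1, len(all_input_ids) + 1):
            text = tokenizer.decode(all_input_ids[:i], skip_special_tokens=False)
            if text.endswith("\ufffd"):
                # The latest token left a multi-byte character incomplete; hold the
                # output back until the remaining byte tokens arrive.
                continue
            yield text[len(emitted):]
            emitted = text


    if __name__ == "__main__":
        # Placeholder repo id reused from the new test; it is gated, so swap in any
        # tokenizer with byte fallback if you do not have access.
        tok = AutoTokenizer.from_pretrained("huggingface/llama-7b")
        ids = tok.encode("我很感谢你的热情", add_special_tokens=False)
        print("".join(naive_stream_decode(tok, ids)))

The decode_token path exercised by the test gets the same effect without re-decoding the whole prefix on every step, by re-decoding only a small look-back window of decode_buffer tokens, which is presumably why its default grows from 3 to 4 here: the window has to stay wide enough to cover a character that is still being assembled from byte tokens.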