feat(server): optimize decode for sane tokenizers

This commit is contained in:
OlivierDehaene 2023-04-12 11:24:02 +02:00
parent 6f0f1d70f6
commit 2aa5004482
8 changed files with 64 additions and 36 deletions

4
benchmark/Cargo.lock generated
View File

@ -853,7 +853,7 @@ checksum = "d2fabcfbdc87f4758337ca535fb41a6d701b65693ce38287d856d1674551ec9b"
[[package]] [[package]]
name = "grpc-metadata" name = "grpc-metadata"
version = "0.4.1" version = "0.1.0"
dependencies = [ dependencies = [
"opentelemetry", "opentelemetry",
"tonic", "tonic",
@ -2140,7 +2140,7 @@ dependencies = [
[[package]] [[package]]
name = "text-generation-client" name = "text-generation-client"
version = "0.4.3" version = "0.5.0"
dependencies = [ dependencies = [
"futures", "futures",
"grpc-metadata", "grpc-metadata",

View File

@ -49,6 +49,11 @@ class BloomCausalLMBatch(CausalLMBatch):
class BLOOM(CausalLM): class BLOOM(CausalLM):
def __init__(self, model_id: str, revision: Optional[str] = None, quantize=False):
super(BLOOM, self).__init__(
model_id=model_id, revision=revision, quantize=quantize, decode_buffer=1
)
@property @property
def batch_type(self) -> Type[CausalLMBatch]: def batch_type(self) -> Type[CausalLMBatch]:
return BloomCausalLMBatch return BloomCausalLMBatch
@ -94,8 +99,7 @@ class BLOOMSharded(BLOOM):
self.model = model.eval().to(dtype) self.model = model.eval().to(dtype)
torch.distributed.barrier(group=self.process_group) torch.distributed.barrier(group=self.process_group)
super(CausalLM, self).__init__( super(CausalLM, self).__init__(
tokenizer=tokenizer, tokenizer=tokenizer, device=device, decode_buffer=1
device=device,
) )
@staticmethod @staticmethod

View File

@ -291,7 +291,13 @@ class CausalLMBatch(Batch):
class CausalLM(Model): class CausalLM(Model):
def __init__(self, model_id: str, revision: Optional[str] = None, quantize=False): def __init__(
self,
model_id: str,
revision: Optional[str] = None,
quantize: bool = False,
decode_buffer: int = 3,
):
if torch.cuda.is_available(): if torch.cuda.is_available():
device = torch.device("cuda") device = torch.device("cuda")
dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32 dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
@ -319,8 +325,7 @@ class CausalLM(Model):
) )
super(CausalLM, self).__init__( super(CausalLM, self).__init__(
tokenizer=tokenizer, tokenizer=tokenizer, device=device, decode_buffer=decode_buffer
device=device,
) )
@property @property

View File

@ -212,7 +212,8 @@ class FlashCausalLM(Model):
model_cls: Type[PreTrainedModel], model_cls: Type[PreTrainedModel],
model_id: str, model_id: str,
revision: Optional[str] = None, revision: Optional[str] = None,
quantize=False, quantize: bool = False,
decode_buffer: int = 3,
): ):
if torch.cuda.is_available(): if torch.cuda.is_available():
device = torch.device("cuda") device = torch.device("cuda")
@ -237,8 +238,7 @@ class FlashCausalLM(Model):
) )
super(FlashCausalLM, self).__init__( super(FlashCausalLM, self).__init__(
tokenizer=tokenizer, tokenizer=tokenizer, device=device, decode_buffer=decode_buffer
device=device,
) )
@property @property

View File

@ -62,8 +62,7 @@ class FlashSantacoder(FlashCausalLM):
self.model = model.eval().to(device).to(dtype) self.model = model.eval().to(device).to(dtype)
super(FlashCausalLM, self).__init__( super(FlashCausalLM, self).__init__(
tokenizer=tokenizer, tokenizer=tokenizer, device=device, decode_buffer=1
device=device,
) )
@staticmethod @staticmethod

View File

@ -10,10 +10,14 @@ B = TypeVar("B", bound=Batch)
class Model(ABC): class Model(ABC):
def __init__(self, tokenizer: PreTrainedTokenizerBase, device: torch.device): def __init__(self, tokenizer: PreTrainedTokenizerBase, device: torch.device, decode_buffer: int = 3):
if decode_buffer < 1:
raise ValueError("decode_buffer must be >= 1")
self.tokenizer = tokenizer self.tokenizer = tokenizer
self.all_special_ids = set(tokenizer.all_special_ids) self.all_special_ids = set(tokenizer.all_special_ids)
self.device = device self.device = device
self.decode_buffer = decode_buffer
@property @property
@abstractmethod @abstractmethod
@ -25,10 +29,10 @@ class Model(ABC):
raise NotImplementedError raise NotImplementedError
def decode_token( def decode_token(
self, self,
all_input_ids: List[int], all_input_ids: List[int],
offset: Optional[int] = None, offset: Optional[int] = None,
token_offset: Optional[int] = None, token_offset: Optional[int] = None,
) -> Tuple[str, Optional[int], Optional[int]]: ) -> Tuple[str, Optional[int], Optional[int]]:
"""Hack to hopefully support generate_stream for the maximum number of tokenizers""" """Hack to hopefully support generate_stream for the maximum number of tokenizers"""
if all_input_ids[-1] in self.all_special_ids: if all_input_ids[-1] in self.all_special_ids:
@ -39,23 +43,35 @@ class Model(ABC):
) )
if token_offset is None: if token_offset is None:
token_offset = len(all_input_ids) - 3 token_offset = len(all_input_ids) - self.decode_buffer
# left token buffer
if self.decode_buffer > 1:
# Decode token_offset token minus last one and token_offset tokens
raw_texts = self.tokenizer.batch_decode(
[all_input_ids[token_offset:-1], all_input_ids[token_offset:]],
skip_special_tokens=False,
)
# Decode token_offset token minus last one and token_offset tokens # default offset is only the last token
results = self.tokenizer.batch_decode( offset = len(raw_texts[0])
[all_input_ids[token_offset:-1], all_input_ids[token_offset:]], sequence_text = raw_texts[1]
skip_special_tokens=False, else:
) # Only decode the last token without using a token buffer
sequence_text = self.tokenizer.decode(all_input_ids[-1], skip_special_tokens=False)
# default offset is only the last token # no offset in this case
if offset is None: offset = 0
offset = len(results[0]) else:
assert offset is not None
sequence_text = self.tokenizer.decode(
all_input_ids[token_offset:],
skip_special_tokens=False,
)
# get text # get text
text = results[1][offset:] token_text = sequence_text[offset:]
# if text is utf-8 # if text is utf-8
if text and text[-1] != "<EFBFBD>": if token_text and token_text[-1] != "<EFBFBD>":
return text, None, None return token_text, None, None
else: else:
return "", offset, token_offset return "", offset, token_offset

View File

@ -54,8 +54,7 @@ class SantaCoder(CausalLM):
) )
super(CausalLM, self).__init__( super(CausalLM, self).__init__(
tokenizer=tokenizer, tokenizer=tokenizer, device=device, decode_buffer=1
device=device,
) )
def decode(self, generated_ids: List[int]) -> str: def decode(self, generated_ids: List[int]) -> str:

View File

@ -330,7 +330,13 @@ class Seq2SeqLMBatch(Batch):
class Seq2SeqLM(Model): class Seq2SeqLM(Model):
def __init__(self, model_id: str, revision: Optional[str] = None, quantize=False): def __init__(
self,
model_id: str,
revision: Optional[str] = None,
quantize: bool = False,
decode_buffer: int = 3,
):
if torch.cuda.is_available(): if torch.cuda.is_available():
device = torch.device("cuda") device = torch.device("cuda")
dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32 dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
@ -354,8 +360,7 @@ class Seq2SeqLM(Model):
tokenizer.bos_token_id = self.model.config.decoder_start_token_id tokenizer.bos_token_id = self.model.config.decoder_start_token_id
super(Seq2SeqLM, self).__init__( super(Seq2SeqLM, self).__init__(
tokenizer=tokenizer, tokenizer=tokenizer, device=device, decode_buffer=decode_buffer
device=device,
) )
@property @property
@ -496,7 +501,7 @@ class Seq2SeqLM(Model):
if stop: if stop:
# Slice with decoder_input_length to remove padding # Slice with decoder_input_length to remove padding
# Decode all tokens # Decode all tokens
output_text = self.decode(decoder_input_ids[-new_decoder_input_length:]) output_text = self.decode(decoder_input_ids[-decoder_input_length:])
# Get seed # Get seed
if isinstance(next_token_chooser.choice, Sampling): if isinstance(next_token_chooser.choice, Sampling):