fix(server): fix generate_stream by forcing tokens to be decoded correctly (#100)

OlivierDehaene 2023-03-06 13:22:58 +01:00 committed by GitHub
parent 1c19b0934e
commit 9b205d33cc
6 changed files with 45 additions and 29 deletions
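Context for the change: with some tokenizer families (notably SentencePiece-based ones), `tokenizer.decode` on a single token id strips the leading-space marker of word-initial tokens, so `generate_stream` emitted pieces with the spaces between words missing. A minimal sketch of the symptom, assuming a `transformers` checkpoint chosen purely for illustration (not part of this commit):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("t5-small")  # illustrative; any SentencePiece tokenizer shows the effect
ids = tokenizer("hello world")["input_ids"]

# Decoding the whole sequence keeps the inter-word space ...
print(tokenizer.decode(ids, skip_special_tokens=True))  # typically "hello world"
# ... but decoding token ids one at a time, as a streaming loop does, tends to drop it.
print("".join(tokenizer.decode([i], skip_special_tokens=True) for i in ids))  # typically "helloworld"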


@@ -385,10 +385,8 @@ class CausalLM(Model):
             # Generated token
             next_token_logprob = logprobs[-1, next_token_id]
             next_token_id_squeezed = next_token_id.squeeze()
-            next_token_text = self.tokenizer.decode(
+            next_token_text = self.decode_token(
                 next_token_id_squeezed,
-                clean_up_tokenization_spaces=False,
-                skip_special_tokens=False,
             )

             # Evaluate stopping criteria


@@ -15,6 +15,15 @@ class Model(ABC):
         self.all_special_ids = set(tokenizer.all_special_ids)
         self.device = device

+        # see `decode_token` method
+        self.tokenizer.add_special_tokens(
+            {"additional_special_tokens": ["<decode-token>"]}
+        )
+        self.special_decode_token_id = self.tokenizer.convert_tokens_to_ids(
+            "<decode-token>"
+        )
+        self.special_decode_token_length = len("<decode-token>")
+
     @property
     @abstractmethod
     def batch_type(self) -> Type[B]:
@@ -23,3 +32,12 @@ class Model(ABC):
     @abstractmethod
     def generate_token(self, batch: B) -> Tuple[List[GeneratedText], Optional[B]]:
         raise NotImplementedError
+
+    def decode_token(self, token_id: int) -> str:
+        """Hack to hopefully support generate_stream for the maximum number of tokenizers"""
+        # append token to special decode token and decode both
+        result = self.tokenizer.decode(
+            [self.special_decode_token_id, token_id], skip_special_tokens=False
+        )
+        # slice to remove special decode token
+        return result[self.special_decode_token_length :]
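
The new `decode_token` works around the issue by decoding each token together with a known sentinel placed in front of it: the sentinel guarantees the real token is never sequence-initial, so its leading space survives, and the sentinel's fixed text is then sliced off. A standalone sketch of the same trick, reusing the names from the diff but outside the `Model` class (checkpoint name again illustrative):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("t5-small")  # illustrative
tokenizer.add_special_tokens({"additional_special_tokens": ["<decode-token>"]})
special_decode_token_id = tokenizer.convert_tokens_to_ids("<decode-token>")
special_decode_token_length = len("<decode-token>")

def decode_token(token_id: int) -> str:
    # decode the sentinel plus the real token, keeping special tokens so the
    # sentinel text is present and can be sliced off afterwards
    result = tokenizer.decode([special_decode_token_id, token_id], skip_special_tokens=False)
    return result[special_decode_token_length:]

ids = tokenizer("hello world", add_special_tokens=False)["input_ids"]
print("".join(decode_token(i) for i in ids))  # inter-word spaces survive (output may carry a leading space)

Whether the sliced string starts with exactly one space still depends on each tokenizer's whitespace handling, which is why the docstring calls this a hack rather than a guarantee.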


@@ -342,7 +342,9 @@ class Seq2SeqLM(Model):
         return Seq2SeqLMBatch

     def decode(self, decoder_ids: List[int]) -> str:
-        return self.tokenizer.decode(decoder_ids, skip_special_tokens=True)
+        return self.tokenizer.decode(
+            decoder_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
+        )

     def forward(
         self,
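
The `decode` change above also pins `clean_up_tokenization_spaces=False`, so the full-sequence decode stays byte-for-byte consistent with the concatenated streamed pieces; the cleanup pass rewrites patterns such as " ." into "." and would make the two disagree. A small illustration of the flag, with the tokenizer name again chosen for illustration:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # illustrative
ids = tokenizer("a , b .")["input_ids"]

print(tokenizer.decode(ids, clean_up_tokenization_spaces=True))   # "a, b."  (spacing rewritten)
print(tokenizer.decode(ids, clean_up_tokenization_spaces=False))  # "a , b ." (left as generated)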
@@ -457,10 +459,8 @@ class Seq2SeqLM(Model):
             # Generated token
             next_token_logprob = logprobs[-1, next_token_id]
             next_token_id_squeezed = next_token_id.squeeze()
-            next_token_text = self.tokenizer.decode(
+            next_token_text = self.decode_token(
                 next_token_id_squeezed,
-                clean_up_tokenization_spaces=False,
-                skip_special_tokens=False,
             )

             # Evaluate stopping criteria