Easiest fix.

This commit is contained in:
Nicolas Patry 2022-12-20 15:24:42 +01:00
parent 611e21cb13
commit d8e1ce669b

View File

@@ -354,7 +354,8 @@ class CausalLM(Model):
         if stop:
             # Decode all tokens
             output_text = self.tokenizer.decode(
-                all_input_ids.squeeze(-1), skip_special_tokens=True
+                all_input_ids.squeeze(-1), skip_special_tokens=True,
+                cleanup_tokenization_spaces=False
             )
             # Slice with input_length to remove padding
             token_ids = all_input_ids[-new_input_length:]