Easiest fix.

Nicolas Patry 2022-12-20 15:24:42 +01:00
parent 611e21cb13
commit d8e1ce669b


@@ -354,7 +354,8 @@ class CausalLM(Model):
         if stop:
             # Decode all tokens
             output_text = self.tokenizer.decode(
-                all_input_ids.squeeze(-1), skip_special_tokens=True
+                all_input_ids.squeeze(-1), skip_special_tokens=True,
+                clean_up_tokenization_spaces=False
             )
             # Slice with input_length to remove padding
             token_ids = all_input_ids[-new_input_length:]
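
The added keyword stops the tokenizer from normalizing spaces around punctuation during decoding, so the returned text matches the raw token stream instead of a prettified version. Note the parameter is spelled clean_up_tokenization_spaces in the transformers decode API. Below is a minimal sketch of the difference; the gpt2 tokenizer and the sample string are purely illustrative and not taken from this commit:

from transformers import AutoTokenizer

# Illustrative tokenizer; most tokenizers show the same effect.
tokenizer = AutoTokenizer.from_pretrained("gpt2")

ids = tokenizer("Hello , world !")["input_ids"]

# Default decoding applies clean_up_tokenization_spaces (True by default in
# transformers at the time of this commit), which rewrites spacing around
# punctuation:
print(tokenizer.decode(ids, skip_special_tokens=True))
# -> "Hello, world!"

# Disabling the cleanup preserves the text exactly as it was tokenized,
# which is what the decoding path above relies on:
print(tokenizer.decode(ids, skip_special_tokens=True,
                       clean_up_tokenization_spaces=False))
# -> "Hello , world !"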