Mirror of https://github.com/huggingface/text-generation-inference.git
add clean_up_tokenization_spaces
This commit is contained in:
parent 6a56f945c0
commit 56d23753bb
@@ -36,6 +36,8 @@ class Model(ABC):
     def decode_token(self, token_id: int) -> str:
         """Hack to hopefully support generate_stream for the maximum number of tokenizers"""
         # append token to special decode token and decode both
-        result = self.tokenizer.decode([self.special_decode_token_id, token_id], skip_special_tokens=False)
+        result = self.tokenizer.decode(
+            [self.special_decode_token_id, token_id], skip_special_tokens=False
+        )
         # slice to remove special decode token
-        return result[self.special_decode_token_length:]
+        return result[self.special_decode_token_length :]
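The decode_token hack streams text one token at a time: it decodes the new token behind a known special token and then slices off that special token's surface form, so tokenizers that would otherwise drop a leading space still return it. Below is a minimal, self-contained sketch of the same idea, assuming a transformers tokenizer; the t5-small checkpoint and the use of the pad token as the prefix are illustrative stand-ins for the repository's special_decode_token_id / special_decode_token_length setup, not its exact configuration.

# Sketch only: checkpoint and prefix token are illustrative assumptions.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("t5-small")

special_decode_token_id = tokenizer.pad_token_id
special_decode_token_length = len(
    tokenizer.decode([special_decode_token_id], skip_special_tokens=False)
)

def decode_token(token_id: int) -> str:
    # Decoding a lone SentencePiece token tends to drop its leading space marker;
    # decoding it behind a special token keeps the word boundary intact.
    result = tokenizer.decode(
        [special_decode_token_id, token_id], skip_special_tokens=False
    )
    # Strip the special token's text, keeping only the new token's piece.
    return result[special_decode_token_length:]

for token_id in tokenizer("Hello world", add_special_tokens=False)["input_ids"]:
    print(repr(decode_token(token_id)))

Printing repr() of each piece makes any preserved leading spaces visible, which is exactly what generate_stream needs when it concatenates pieces client-side.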
@@ -342,7 +342,9 @@ class Seq2SeqLM(Model):
         return Seq2SeqLMBatch

     def decode(self, decoder_ids: List[int]) -> str:
-        return self.tokenizer.decode(decoder_ids, skip_special_tokens=True)
+        return self.tokenizer.decode(
+            decoder_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
+        )

     def forward(
         self,
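Passing clean_up_tokenization_spaces=False stops transformers from post-processing the decoded string (for example, removing the space before punctuation and around contractions), so Seq2SeqLM.decode returns the raw detokenization. A small sketch of the difference, assuming the transformers library and using t5-small purely as an example checkpoint:

# Sketch only: "t5-small" is an example checkpoint, not something this commit pins down.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("t5-small")
ids = tokenizer("Hello , world !", add_special_tokens=False)["input_ids"]

# Cleanup enabled: transformers post-processes the decoded string, e.g. collapsing
# the space before "," and "!", so the text may not round-trip exactly.
print(tokenizer.decode(ids, skip_special_tokens=True, clean_up_tokenization_spaces=True))

# Cleanup disabled (this commit's choice): the raw detokenized string is returned.
print(tokenizer.decode(ids, skip_special_tokens=True, clean_up_tokenization_spaces=False))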