diff --git a/server/text_generation_server/utils/tokens.py b/server/text_generation_server/utils/tokens.py
index 9155fd54..49ef2d3b 100644
--- a/server/text_generation_server/utils/tokens.py
+++ b/server/text_generation_server/utils/tokens.py
@@ -1,5 +1,5 @@
 import re
-from typing import List, Optional, Tuple, Set
+from typing import List, Optional, Tuple, Set, Union
 
 import math
 import torch
@@ -143,7 +143,7 @@ class StopSequenceCriteria:
 class StoppingCriteria:
     def __init__(
         self,
-        eos_token_ids: Optional[Set[int]],
+        eos_token_ids: Optional[Union[Set[int], int]],
         stop_sequence_criterias: List[StopSequenceCriteria],
         max_new_tokens: int = 20,
         ignore_eos_token: bool = False,
@@ -170,6 +170,9 @@ class StoppingCriteria:
         if self.current_tokens >= self.max_new_tokens:
             return True, FinishReason.FINISH_REASON_LENGTH
 
+        if isinstance(last_token, torch.Tensor):
+            last_token = last_token.item()
+
         if not self.ignore_eos_token and last_token in self.eos_token_ids:
             return True, FinishReason.FINISH_REASON_EOS_TOKEN
 
@@ -194,13 +197,8 @@ class StoppingCriteria:
         stop_sequence_criterias = [
             StopSequenceCriteria(sequence) for sequence in pb.stop_sequences
         ]
-        eos_token_id = tokenizer.eos_token_id
-        if eos_token_id is None:
-            eos_token_ids = {}
-        else:
-            eos_token_ids = eos_token_id
         return StoppingCriteria(
-            eos_token_ids,
+            tokenizer.eos_token_id,
             stop_sequence_criterias,
             pb.max_new_tokens,
             pb.ignore_eos_token,
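
For reference, a minimal standalone sketch of the behaviour this patch enables: StoppingCriteria can now be handed tokenizer.eos_token_id directly (an int, a set of ints, or None), and the EOS check tolerates last_token arriving as a 0-d torch.Tensor. The normalize_eos and is_eos helpers below are illustrative stand-ins, not code from this file; they only mirror the type handling the diff introduces (the actual normalization is assumed to happen inside StoppingCriteria.__init__, which this hunk does not show).

from typing import Optional, Set, Union

import torch


def normalize_eos(eos: Optional[Union[Set[int], int]]) -> Set[int]:
    # Hypothetical helper mirroring what StoppingCriteria.__init__ is
    # assumed to do with the widened eos_token_ids argument.
    if eos is None:
        return set()
    if isinstance(eos, int):
        return {eos}
    return set(eos)


def is_eos(last_token: Union[int, torch.Tensor], eos_token_ids: Set[int]) -> bool:
    # Same unwrapping as the patched __call__: a 0-d tensor is converted
    # to a plain Python int before the membership test.
    if isinstance(last_token, torch.Tensor):
        last_token = last_token.item()
    return last_token in eos_token_ids


eos_ids = normalize_eos(2)                 # e.g. tokenizer.eos_token_id == 2
assert is_eos(torch.tensor(2), eos_ids)    # tensor input now matches
assert is_eos(2, eos_ids)                  # plain int still works
assert not is_eos(3, normalize_eos(None))  # no EOS configured -> never stops on EOS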