diff --git a/server/text_generation_server/utils/tokens.py b/server/text_generation_server/utils/tokens.py
index a661eed9..34ed9b58 100644
--- a/server/text_generation_server/utils/tokens.py
+++ b/server/text_generation_server/utils/tokens.py
@@ -143,11 +143,21 @@ class StopSequenceCriteria:
 class StoppingCriteria:
     def __init__(
         self,
-        eos_token_ids: Set[int],
+        eos_token_ids: Optional[Set[int]],
         stop_sequence_criterias: List[StopSequenceCriteria],
         max_new_tokens: int = 20,
         ignore_eos_token: bool = False,
     ):
+        if eos_token_ids is None:
+            eos_token_ids = set()
+        elif isinstance(eos_token_ids, int):
+            eos_token_ids = set([eos_token_ids])
+        elif isinstance(eos_token_ids, set):
+            eos_token_ids = eos_token_ids
+        else:
+            raise RuntimeError(
+                f"eos_token_ids is of invalid type {type(eos_token_ids)}, expected int, None or set[int]"
+            )
         self.eos_token_ids = eos_token_ids
         self.stop_sequence_criterias = stop_sequence_criterias
         self.max_new_tokens = max_new_tokens
@@ -185,11 +195,8 @@ class StoppingCriteria:
             StopSequenceCriteria(sequence) for sequence in pb.stop_sequences
         ]
         eos_token_id = tokenizer.eos_token_id
-        if eos_token_id is None:
-            eos_token_ids = {}
-        elif isinstance(eos_token_id, set):
-            eos_token_ids = eos_token_id
-        else:
-            eos_token_ids = eos_token_id
+        eos_token_ids = eos_token_id
         return StoppingCriteria(
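
For reference, here is a minimal standalone sketch of the normalization that the updated `__init__` performs on `eos_token_ids`. The helper name `normalize_eos_token_ids` is illustrative only and is not part of the patch:

```python
from typing import Optional, Set, Union


def normalize_eos_token_ids(eos_token_ids: Optional[Union[Set[int], int]]) -> Set[int]:
    """Coerce None / a single token id / a set of ids into a set of ids."""
    if eos_token_ids is None:
        return set()
    if isinstance(eos_token_ids, int):
        # Equivalent to set([eos_token_ids]) in the patch.
        return {eos_token_ids}
    if isinstance(eos_token_ids, set):
        return eos_token_ids
    raise RuntimeError(
        f"eos_token_ids is of invalid type {type(eos_token_ids)}, expected int, None or set[int]"
    )


print(normalize_eos_token_ids(None))    # set()
print(normalize_eos_token_ids(2))       # {2}
print(normalize_eos_token_ids({2, 3}))  # {2, 3}
```

With the validation living in `__init__`, `from_pb` can pass `tokenizer.eos_token_id` through unchanged, whether the tokenizer exposes a single id, a set of ids, or no eos token at all.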