Moving logic closer to use.

This commit is contained in:
Nicolas Patry 2024-04-25 16:02:11 +02:00
parent f33fccfb13
commit f6243fc8ad

View File

@@ -143,11 +143,21 @@ class StopSequenceCriteria:
class StoppingCriteria: class StoppingCriteria:
def __init__( def __init__(
self, self,
eos_token_ids: Set[int], eos_token_ids: Optional[Set[int]],
stop_sequence_criterias: List[StopSequenceCriteria], stop_sequence_criterias: List[StopSequenceCriteria],
max_new_tokens: int = 20, max_new_tokens: int = 20,
ignore_eos_token: bool = False, ignore_eos_token: bool = False,
): ):
if eos_token_ids is None:
eos_token_ids = set()
elif isinstance(eos_token_ids, int):
eos_token_ids = set(eos_token_ids)
elif isinstance(eos_token_ids, set):
eos_token_ids = eos_token_ids
else:
raise RuntimeError(
f"eos_token_ids is of invalid type {type(eos_token_ids)}, expected int, None or set[int]"
)
self.eos_token_ids = eos_token_ids self.eos_token_ids = eos_token_ids
self.stop_sequence_criterias = stop_sequence_criterias self.stop_sequence_criterias = stop_sequence_criterias
self.max_new_tokens = max_new_tokens self.max_new_tokens = max_new_tokens
@@ -185,11 +195,8 @@ class StoppingCriteria:
StopSequenceCriteria(sequence) for sequence in pb.stop_sequences StopSequenceCriteria(sequence) for sequence in pb.stop_sequences
] ]
eos_token_id = tokenizer.eos_token_id eos_token_id = tokenizer.eos_token_id
if eos_token_id is None: if eos_token_id is None:
eos_token_ids = {} eos_token_ids = {}
elif isinstance(eos_token_id, set):
eos_token_ids = eos_token_id
else: else:
eos_token_ids = eos_token_id eos_token_ids = eos_token_id
return StoppingCriteria( return StoppingCriteria(