mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 12:24:53 +00:00
Moving logic closer to use.
This commit is contained in:
parent
f33fccfb13
commit
f6243fc8ad
@ -143,11 +143,21 @@ class StopSequenceCriteria:
|
||||
class StoppingCriteria:
|
||||
def __init__(
|
||||
self,
|
||||
eos_token_ids: Set[int],
|
||||
eos_token_ids: Optional[Set[int]],
|
||||
stop_sequence_criterias: List[StopSequenceCriteria],
|
||||
max_new_tokens: int = 20,
|
||||
ignore_eos_token: bool = False,
|
||||
):
|
||||
if eos_token_ids is None:
|
||||
eos_token_ids = set()
|
||||
elif isinstance(eos_token_ids, int):
|
||||
eos_token_ids = set(eos_token_ids)
|
||||
elif isinstance(eos_token_ids, set):
|
||||
eos_token_ids = eos_token_ids
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f"eos_token_ids is of invalid type {type(eos_token_ids)}, expected int, None or set[int]"
|
||||
)
|
||||
self.eos_token_ids = eos_token_ids
|
||||
self.stop_sequence_criterias = stop_sequence_criterias
|
||||
self.max_new_tokens = max_new_tokens
|
||||
@ -185,11 +195,8 @@ class StoppingCriteria:
|
||||
StopSequenceCriteria(sequence) for sequence in pb.stop_sequences
|
||||
]
|
||||
eos_token_id = tokenizer.eos_token_id
|
||||
|
||||
if eos_token_id is None:
|
||||
eos_token_ids = {}
|
||||
elif isinstance(eos_token_id, set):
|
||||
eos_token_ids = eos_token_id
|
||||
else:
|
||||
eos_token_ids = eos_token_id
|
||||
return StoppingCriteria(
|
||||
|
Loading…
Reference in New Issue
Block a user