mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 20:34:54 +00:00
Moving logic closer to use.
This commit is contained in:
parent
f33fccfb13
commit
f6243fc8ad
@ -143,11 +143,21 @@ class StopSequenceCriteria:
|
|||||||
class StoppingCriteria:
|
class StoppingCriteria:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
eos_token_ids: Set[int],
|
eos_token_ids: Optional[Set[int]],
|
||||||
stop_sequence_criterias: List[StopSequenceCriteria],
|
stop_sequence_criterias: List[StopSequenceCriteria],
|
||||||
max_new_tokens: int = 20,
|
max_new_tokens: int = 20,
|
||||||
ignore_eos_token: bool = False,
|
ignore_eos_token: bool = False,
|
||||||
):
|
):
|
||||||
|
if eos_token_ids is None:
|
||||||
|
eos_token_ids = set()
|
||||||
|
elif isinstance(eos_token_ids, int):
|
||||||
|
eos_token_ids = set(eos_token_ids)
|
||||||
|
elif isinstance(eos_token_ids, set):
|
||||||
|
eos_token_ids = eos_token_ids
|
||||||
|
else:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"eos_token_ids is of invalid type {type(eos_token_ids)}, expected int, None or set[int]"
|
||||||
|
)
|
||||||
self.eos_token_ids = eos_token_ids
|
self.eos_token_ids = eos_token_ids
|
||||||
self.stop_sequence_criterias = stop_sequence_criterias
|
self.stop_sequence_criterias = stop_sequence_criterias
|
||||||
self.max_new_tokens = max_new_tokens
|
self.max_new_tokens = max_new_tokens
|
||||||
@ -185,11 +195,8 @@ class StoppingCriteria:
|
|||||||
StopSequenceCriteria(sequence) for sequence in pb.stop_sequences
|
StopSequenceCriteria(sequence) for sequence in pb.stop_sequences
|
||||||
]
|
]
|
||||||
eos_token_id = tokenizer.eos_token_id
|
eos_token_id = tokenizer.eos_token_id
|
||||||
|
|
||||||
if eos_token_id is None:
|
if eos_token_id is None:
|
||||||
eos_token_ids = {}
|
eos_token_ids = {}
|
||||||
elif isinstance(eos_token_id, set):
|
|
||||||
eos_token_ids = eos_token_id
|
|
||||||
else:
|
else:
|
||||||
eos_token_ids = eos_token_id
|
eos_token_ids = eos_token_id
|
||||||
return StoppingCriteria(
|
return StoppingCriteria(
|
||||||
|
Loading…
Reference in New Issue
Block a user