Damned tensor equality.

This commit is contained in:
Nicolas Patry 2024-04-25 15:04:53 +00:00
parent 650f45ce77
commit fd705ef292

View File

@ -1,5 +1,5 @@
import re
from typing import List, Optional, Tuple, Set
from typing import List, Optional, Tuple, Set, Union
import math
import torch
@ -143,7 +143,7 @@ class StopSequenceCriteria:
class StoppingCriteria:
def __init__(
self,
eos_token_ids: Optional[Set[int]],
eos_token_ids: Optional[Union[Set[int], int]],
stop_sequence_criterias: List[StopSequenceCriteria],
max_new_tokens: int = 20,
ignore_eos_token: bool = False,
@ -170,6 +170,9 @@ class StoppingCriteria:
if self.current_tokens >= self.max_new_tokens:
return True, FinishReason.FINISH_REASON_LENGTH
if isinstance(last_token, torch.Tensor):
last_token = last_token.item()
if not self.ignore_eos_token and last_token in self.eos_token_ids:
return True, FinishReason.FINISH_REASON_EOS_TOKEN
@ -194,13 +197,8 @@ class StoppingCriteria:
stop_sequence_criterias = [
StopSequenceCriteria(sequence) for sequence in pb.stop_sequences
]
eos_token_id = tokenizer.eos_token_id
if eos_token_id is None:
eos_token_ids = {}
else:
eos_token_ids = eos_token_id
return StoppingCriteria(
eos_token_ids,
tokenizer.eos_token_id,
stop_sequence_criterias,
pb.max_new_tokens,
pb.ignore_eos_token,