mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 12:24:53 +00:00
Damned tensor equality.
This commit is contained in:
parent
650f45ce77
commit
fd705ef292
@ -1,5 +1,5 @@
|
||||
import re
|
||||
from typing import List, Optional, Tuple, Set
|
||||
from typing import List, Optional, Tuple, Set, Union
|
||||
|
||||
import math
|
||||
import torch
|
||||
@ -143,7 +143,7 @@ class StopSequenceCriteria:
|
||||
class StoppingCriteria:
|
||||
def __init__(
|
||||
self,
|
||||
eos_token_ids: Optional[Set[int]],
|
||||
eos_token_ids: Optional[Union[Set[int], int]],
|
||||
stop_sequence_criterias: List[StopSequenceCriteria],
|
||||
max_new_tokens: int = 20,
|
||||
ignore_eos_token: bool = False,
|
||||
@ -170,6 +170,9 @@ class StoppingCriteria:
|
||||
if self.current_tokens >= self.max_new_tokens:
|
||||
return True, FinishReason.FINISH_REASON_LENGTH
|
||||
|
||||
if isinstance(last_token, torch.Tensor):
|
||||
last_token = last_token.item()
|
||||
|
||||
if not self.ignore_eos_token and last_token in self.eos_token_ids:
|
||||
return True, FinishReason.FINISH_REASON_EOS_TOKEN
|
||||
|
||||
@ -194,13 +197,8 @@ class StoppingCriteria:
|
||||
stop_sequence_criterias = [
|
||||
StopSequenceCriteria(sequence) for sequence in pb.stop_sequences
|
||||
]
|
||||
eos_token_id = tokenizer.eos_token_id
|
||||
if eos_token_id is None:
|
||||
eos_token_ids = {}
|
||||
else:
|
||||
eos_token_ids = eos_token_id
|
||||
return StoppingCriteria(
|
||||
eos_token_ids,
|
||||
tokenizer.eos_token_id,
|
||||
stop_sequence_criterias,
|
||||
pb.max_new_tokens,
|
||||
pb.ignore_eos_token,
|
||||
|
Loading…
Reference in New Issue
Block a user