mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-11 20:34:54 +00:00)
Damned tensor equality.
parent 650f45ce77
commit fd705ef292
@@ -1,5 +1,5 @@
 import re
-from typing import List, Optional, Tuple, Set
+from typing import List, Optional, Tuple, Set, Union
 
 import math
 import torch
@@ -143,7 +143,7 @@ class StopSequenceCriteria:
 class StoppingCriteria:
     def __init__(
         self,
-        eos_token_ids: Optional[Set[int]],
+        eos_token_ids: Optional[Union[Set[int], int]],
         stop_sequence_criterias: List[StopSequenceCriteria],
         max_new_tokens: int = 20,
         ignore_eos_token: bool = False,
@@ -170,6 +170,9 @@ class StoppingCriteria:
         if self.current_tokens >= self.max_new_tokens:
             return True, FinishReason.FINISH_REASON_LENGTH
 
+        if isinstance(last_token, torch.Tensor):
+            last_token = last_token.item()
+
         if not self.ignore_eos_token and last_token in self.eos_token_ids:
             return True, FinishReason.FINISH_REASON_EOS_TOKEN
 
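
The isinstance/.item() guard added above works around a set-membership pitfall: torch.Tensor hashes by object identity rather than by value, so testing a 0-dim tensor against a set of Python ints can come back False even when the values match. A minimal standalone sketch of the behaviour (the token ids used here are made up for illustration):

import torch

eos_token_ids = {0, 2}        # plain Python ints, as stored on the stopping criteria
last_token = torch.tensor(2)  # 0-dim tensor straight out of the sampling step

# The tensor hashes by identity, so the set lookup typically misses the matching value.
print(last_token in eos_token_ids)         # almost always False
print(last_token.item() in eos_token_ids)  # True once converted to a Python int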
@@ -194,13 +197,8 @@ class StoppingCriteria:
         stop_sequence_criterias = [
             StopSequenceCriteria(sequence) for sequence in pb.stop_sequences
         ]
-        eos_token_id = tokenizer.eos_token_id
-        if eos_token_id is None:
-            eos_token_ids = {}
-        else:
-            eos_token_ids = eos_token_id
         return StoppingCriteria(
-            eos_token_ids,
+            tokenizer.eos_token_id,
             stop_sequence_criterias,
             pb.max_new_tokens,
             pb.ignore_eos_token,
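
With the widened type hint (Optional[Union[Set[int], int]]), from_pb now passes tokenizer.eos_token_id through unchanged, which can be a single int or None for tokenizers without an EOS token; the old pre-normalisation into eos_token_ids is dropped. How __init__ reconciles that with the set-membership test in __call__ is not shown in this diff; a hypothetical normalisation helper could look like the following sketch (name and placement are assumptions):

from typing import Optional, Set, Union

def normalize_eos(eos_token_ids: Optional[Union[Set[int], int]]) -> Set[int]:
    # Hypothetical helper, not taken from the diff: always return a set so
    # membership tests like `last_token in eos_token_ids` keep working.
    if eos_token_ids is None:
        return set()               # tokenizer defines no EOS token
    if isinstance(eos_token_ids, int):
        return {eos_token_ids}     # single id from tokenizer.eos_token_id
    return set(eos_token_ids)      # already a collection of ids

print(normalize_eos(None))    # set()
print(normalize_eos(2))       # {2}
print(normalize_eos({0, 2}))  # {0, 2}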